Commit 4fbd2ef6 by xuchen

fix the bug of conformer (success)

parent 9ac7a1aa
@@ -96,6 +96,8 @@ class S2TConformerEncoder(S2TTransformerEncoder):
             [ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )
+        del self.transformer_layers
+
     def forward(self, src_tokens, src_lengths):
         x, input_lengths = self.subsample(src_tokens, src_lengths)
         x = self.embed_scale * x
...
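Note on the hunk above: S2TConformerEncoder inherits from S2TTransformerEncoder, whose constructor already builds a Transformer layer stack; after the Conformer layers are built, `del self.transformer_layers` drops the inherited stack so its parameters are no longer registered on the module. A minimal sketch of that pattern (hypothetical Base/Derived names, not repository code):

import torch.nn as nn

class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

class Derived(Base):
    def __init__(self):
        super().__init__()
        # replace the inherited stack, then delete it so its parameters
        # are not kept around (and trained or synced) for nothing
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))
        del self.transformer_layers

print(sum(p.numel() for p in Derived().parameters()))  # 40: only self.layers remains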
@@ -215,7 +215,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.conv_norm(x)
-        x = residual + self.dropout_module(self.conv_module(x))
+        x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
         if not self.normalize_before:
             x = self.conv_norm(x)
         x = x.transpose(0, 1)
...
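This call-site change is the actual bug fix: the convolution module now receives encoder_padding_mask so padded frames can be zeroed before the depthwise convolution. Without it, the kernel mixes padding values into the last real frames of the shorter sequences in a batch. A self-contained illustration of the leakage (illustrative code, not from this repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1, groups=4)  # depthwise, 'SAME'

x = torch.randn(1, 4, 10)            # (batch, channels, time)
x[:, :, 8:] = 5.0                    # frames 8-9 are junk "padding"
mask = torch.zeros(1, 1, 10, dtype=torch.bool)
mask[:, :, :8] = True                # frames 0-7 are real

masked = conv(x.masked_fill(~mask, 0.0))
unmasked = conv(x)
print(torch.allclose(masked[:, :, :7], unmasked[:, :, :7]))  # True: interior agrees
print(torch.allclose(masked[:, :, 7], unmasked[:, :, 7]))    # False: frame 7 saw frame 8

The rest of the diff replaces the ESPnet-style ConvolutionModule with one that supports padding masks, causal convolution, and a left-context cache, apparently adapted from Mobvoi's WeNet implementation (per the new copyright header).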
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
-#           Northwestern Polytechnical University (Pengcheng Guo)
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright 2021 Mobvoi Inc. All Rights Reserved.
+# Author: di.wu@mobvoi.com (DI WU)
 """ConvolutionModule definition."""
 
+from typing import Optional
+
+import torch
 from torch import nn
 
 
 class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Conformer model.
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernel size of conv layers.
-
-    """
-
-    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
-        """Construct an ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be an odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
+    """ConvolutionModule in Conformer model."""
+    def __init__(self,
+                 channels: int,
+                 kernel_size: int = 15,
+                 activation: nn.Module = nn.ReLU(),
+                 norm: str = "batch_norm",
+                 causal: bool = False,
+                 bias: bool = True):
+        """Construct a ConvolutionModule object.
+        Args:
+            channels (int): The number of channels of conv layers.
+            kernel_size (int): Kernel size of conv layers.
+            causal (bool): Whether to use causal convolution or not.
+        """
+        super().__init__()
 
         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -33,16 +36,36 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
+
+        # self.lorder distinguishes causal from symmetric convolution:
+        # if self.lorder > 0 the convolution is causal, and the input is
+        # padded with self.lorder frames on the left in forward();
+        # otherwise it is a symmetric ('SAME') convolution.
+        if causal:
+            padding = 0
+            self.lorder = kernel_size - 1
+        else:
+            # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
+            padding = (kernel_size - 1) // 2
+            self.lorder = 0
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        assert norm in ['batch_norm', 'layer_norm']
+        if norm == "batch_norm":
+            self.use_layer_norm = False
+            self.norm = nn.BatchNorm1d(channels)
+        else:
+            self.use_layer_norm = True
+            self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
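The causal branch pads the input on the left with lorder = kernel_size - 1 frames instead of symmetric 'SAME' padding, so output frame t only sees inputs up to t while the sequence length is preserved. A small self-contained check of that property (illustrative code, not from the repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
k = 5
conv = nn.Conv1d(2, 2, k, padding=0, groups=2)

def causal_conv(x):                  # x: (batch, channels, time)
    # lorder = k - 1 zeros on the left, nothing on the right
    return conv(nn.functional.pad(x, (k - 1, 0)))

x = torch.randn(1, 2, 12)
y1 = causal_conv(x)
x2 = x.clone()
x2[:, :, 6:] = 0.0                   # perturb only "future" frames t >= 6
y2 = causal_conv(x2)
print(y1.shape)                                    # torch.Size([1, 2, 12]): length preserved
print(torch.allclose(y1[:, :, :6], y2[:, :, :6]))  # True: t < 6 never sees t >= 6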
@@ -53,27 +76,58 @@ class ConvolutionModule(nn.Module):
         )
         self.activation = activation
 
-    def forward(self, x):
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask_pad: Optional[torch.Tensor] = None,
+        cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         """Compute convolution module.
         Args:
             x (torch.Tensor): Input tensor (#batch, time, channels).
+            mask_pad (torch.Tensor): mask for batch padding (#batch, time),
+                True at non-padded positions.
+            cache (torch.Tensor): left-context cache; only used for
+                causal convolution.
         Returns:
             torch.Tensor: Output tensor (#batch, time, channels).
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)
+
+        # mask batch padding; expand (batch, time) to (batch, channels, time)
+        if mask_pad is not None:
+            new_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
+            x.masked_fill_(~new_pad, 0.0)
+        else:
+            new_pad = None
+
+        # left-context handling for causal convolution; new_cache is computed
+        # but not returned by this forward() (reserved for streaming use)
+        if self.lorder > 0:
+            if cache is None:
+                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+            else:
+                assert cache.size(0) == x.size(0)
+                assert cache.size(1) == x.size(1)
+                x = torch.cat((cache, x), dim=2)
+            assert (x.size(2) > self.lorder)
+            new_cache = x[:, :, -self.lorder:]
+        else:
+            # It would be better to return None when no cache is required,
+            # but for JIT export we fake a tensor instead of None.
+            new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
+
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
         x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
+        # nn.LayerNorm normalizes the last dim, so transpose around it
+        if self.use_layer_norm:
+            x = x.transpose(1, 2)
         x = self.activation(self.norm(x))
+        if self.use_layer_norm:
+            x = x.transpose(1, 2)
         x = self.pointwise_conv2(x)
+
+        # mask batch padding again on the convolution output
+        if new_pad is not None:
+            x.masked_fill_(~new_pad, 0.0)
+
         return x.transpose(1, 2)
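For reference, a minimal usage sketch of the new module, assuming it is importable as ConvolutionModule and using the default batch_norm, non-causal configuration; the mask marks real (non-padded) frames as True:

import torch

m = ConvolutionModule(channels=256, kernel_size=15)
x = torch.randn(8, 100, 256)                               # (batch, time, channels)
lengths = torch.randint(50, 101, (8,))                     # true length per utterance
mask_pad = torch.arange(100)[None, :] < lengths[:, None]   # (batch, time), True = real frame
y = m(x, mask_pad)
print(y.shape)                                             # torch.Size([8, 100, 256])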