Commit cb0af126 by xuchen

fix the bug of the conformer

parent 31d0303e
@@ -83,7 +83,7 @@ class ConformerEncoderLayer(nn.Module):
             self.conv_module = ConvolutionModule(self.embed_dim, args.cnn_module_kernel, self.activation_fn)
             self.final_norm = LayerNorm(self.embed_dim)
         else:
-            self.conv_norm = False
+            self.conv_norm = None
             self.conv_module = None
             self.final_norm = None
...
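Note on the hunk above: when the convolution branch is disabled, self.conv_norm is now set to None instead of False, matching conv_module and final_norm. A minimal sketch of why the None sentinel is the consistent choice follows; TinyLayer and its dimensions are illustrative only, under the assumption that the encoder layer gates the conv path with an `is not None` check rather than a truthiness test.

import torch
from torch import nn

class TinyLayer(nn.Module):
    """Toy layer showing None (not False) as the sentinel for an absent submodule."""
    def __init__(self, use_conv_norm: bool, dim: int = 8):
        super().__init__()
        if use_conv_norm:
            self.conv_norm = nn.LayerNorm(dim)
        else:
            # None keeps the attribute consistent with conv_module / final_norm,
            # which are also set to None when the conv branch is disabled.
            self.conv_norm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.conv_norm is not None:  # a False sentinel would need a different test
            x = self.conv_norm(x)
        return x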
@@ -9,6 +9,14 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
+from fairseq.modules.layer_norm import LayerNorm
+
+
+class Swish(nn.Module):
+    """Construct an Swish object."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return Swish activation function."""
+        return x * torch.sigmoid(x)


 class ConvolutionModule(nn.Module):
@@ -16,7 +24,7 @@ class ConvolutionModule(nn.Module):
     def __init__(self,
                  channels: int,
                  kernel_size: int = 15,
-                 activation: nn.Module = nn.ReLU(),
+                 activation: nn.Module = Swish(),
                  norm: str = "batch_norm",
                  causal: bool = False,
                  bias: bool = True):
@@ -48,6 +56,7 @@ class ConvolutionModule(nn.Module):
             assert (kernel_size - 1) % 2 == 0
             padding = (kernel_size - 1) // 2
             self.lorder = 0
+
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
@@ -64,7 +73,7 @@ class ConvolutionModule(nn.Module):
             self.norm = nn.BatchNorm1d(channels)
         else:
             self.use_layer_norm = True
-            self.norm = nn.LayerNorm(channels)
+            self.norm = LayerNorm(channels)

         self.pointwise_conv2 = nn.Conv1d(
             channels,
@@ -74,7 +83,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.activation = activation
+        self.activation = Swish()

     def forward(
         self,
@@ -109,14 +118,14 @@ class ConvolutionModule(nn.Module):
             assert (x.size(2) > self.lorder)
             new_cache = x[:, :, -self.lorder:]
         else:
-            # It's better we just return None if no cache is requried,
+            # It's better we just return None if no cache is required,
             # However, for JIT export, here we just fake one tensor instead of
             # None.
             new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)

         # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
-        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
...
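For context on the comment changes above: the pointwise convolution doubles the channel count, GLU halves it back along dim=1, the depthwise convolution then mixes along the time axis, and Swish is x * sigmoid(x). The sketch below reproduces that data flow in isolation; demo_conv_block and its sizes are illustrative and do not mirror the repository's exact ConvolutionModule (no caching, causal padding, or norm selection).

import torch
from torch import nn

class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

def demo_conv_block(x: torch.Tensor, channels: int = 16, kernel_size: int = 15) -> torch.Tensor:
    """Pointwise conv -> GLU -> depthwise conv -> norm -> Swish -> pointwise conv."""
    padding = (kernel_size - 1) // 2
    pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
    depthwise_conv = nn.Conv1d(channels, channels, kernel_size, padding=padding, groups=channels)
    norm = nn.BatchNorm1d(channels)
    pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

    x = pointwise_conv1(x)              # (batch, 2*channel, time)
    x = nn.functional.glu(x, dim=1)     # (batch, channel, time)
    x = depthwise_conv(x)               # (batch, channel, time)
    x = Swish()(norm(x))                # (batch, channel, time)
    return pointwise_conv2(x)           # (batch, channel, time)

out = demo_conv_block(torch.randn(2, 16, 50))
print(out.shape)  # torch.Size([2, 16, 50])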