Commit cb0af126 by xuchen

Fix bugs in the Conformer convolution module: default the activation to Swish, define self.lorder on the non-causal path, set conv_norm to None instead of False, use fairseq's LayerNorm helper, and correct the shape comments in forward.

parent 31d0303e
@@ -83,7 +83,7 @@ class ConformerEncoderLayer(nn.Module):
             self.conv_module = ConvolutionModule(self.embed_dim, args.cnn_module_kernel, self.activation_fn)
             self.final_norm = LayerNorm(self.embed_dim)
         else:
-            self.conv_norm = False
+            self.conv_norm = None
             self.conv_module = None
             self.final_norm = None
......
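The False → None change matters because optional submodules are conventionally typed Optional[nn.Module]; a None check then works uniformly in eager mode and under TorchScript. A minimal sketch of the pattern with a hypothetical Block module (not from this repo):

```python
from typing import Optional

import torch
from torch import nn


class Block(nn.Module):
    """Hypothetical module illustrating the Optional-submodule pattern."""

    def __init__(self, use_norm: bool):
        super().__init__()
        # None (not False) keeps the attribute a consistent Optional[nn.Module].
        self.norm: Optional[nn.Module] = nn.LayerNorm(8) if use_norm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm is not None:
            x = self.norm(x)
        return x
```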
@@ -9,6 +9,14 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
+from fairseq.modules.layer_norm import LayerNorm
+
+
+class Swish(nn.Module):
+    """Construct a Swish object."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return the Swish activation x * sigmoid(x)."""
+        return x * torch.sigmoid(x)
 class ConvolutionModule(nn.Module):
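For reference, Swish(x) = x · sigmoid(x) is the same function PyTorch ships as torch.nn.SiLU (available since PyTorch 1.7); a quick numerical check of the class added above:

```python
import torch
from torch import nn

x = torch.randn(4, 8)
swish = x * torch.sigmoid(x)        # what the new Swish.forward computes
# Swish is numerically identical to PyTorch's built-in SiLU.
assert torch.allclose(swish, nn.SiLU()(x))
```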
@@ -16,7 +24,7 @@ class ConvolutionModule(nn.Module):
     def __init__(self,
                  channels: int,
                  kernel_size: int = 15,
-                 activation: nn.Module = nn.ReLU(),
+                 activation: nn.Module = Swish(),
                  norm: str = "batch_norm",
                  causal: bool = False,
                  bias: bool = True):
@@ -48,6 +56,7 @@ class ConvolutionModule(nn.Module):
             assert (kernel_size - 1) % 2 == 0
             padding = (kernel_size - 1) // 2
+            self.lorder = 0
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
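The added self.lorder = 0 matters because the forward hunk below slices the cache with x[:, :, -self.lorder:] and compares x.size(2) > self.lorder, so the attribute must exist on the non-causal path too. A sketch of the two padding regimes, assuming the causal branch sets lorder = kernel_size - 1 with padding = 0 (the standard causal-conv setup; that branch is not shown in this diff):

```python
import torch
import torch.nn.functional as F

kernel_size, causal = 15, False

if causal:
    padding = 0
    lorder = kernel_size - 1        # frames of left context to pad/cache
else:
    # Symmetric padding keeps the time dimension unchanged.
    assert (kernel_size - 1) % 2 == 0
    padding = (kernel_size - 1) // 2
    lorder = 0                      # the attribute this hunk adds

x = torch.randn(2, 64, 100)         # (batch, channel, time)
if lorder > 0:
    # Left-pad so the kernel never sees future frames.
    x = F.pad(x, (lorder, 0), value=0.0)
```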
@@ -64,7 +73,7 @@ class ConvolutionModule(nn.Module):
             self.norm = nn.BatchNorm1d(channels)
         else:
             self.use_layer_norm = True
-            self.norm = nn.LayerNorm(channels)
+            self.norm = LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
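Swapping nn.LayerNorm for fairseq's LayerNorm helper keeps the math identical but lets fairseq substitute apex's fused kernel when it is installed. A simplified sketch of that dispatch (paraphrased from fairseq, not a verbatim copy):

```python
import torch


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    # Prefer apex's fused CUDA kernel when available, as fairseq does;
    # otherwise fall back to the stock PyTorch implementation.
    try:
        from apex.normalization import FusedLayerNorm
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    except ImportError:
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
```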
@@ -74,7 +83,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.activation = activation
+        self.activation = Swish()

     def forward(
         self,
@@ -109,14 +118,14 @@ class ConvolutionModule(nn.Module):
             assert (x.size(2) > self.lorder)
             new_cache = x[:, :, -self.lorder:]
         else:
-            # It's better we just return None if no cache is requried,
+            # It's better we just return None if no cache is required,
             # However, for JIT export, here we just fake one tensor instead of
             # None.
             new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)

         # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
-        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
......
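The comment fix in the last hunk reflects that pointwise_conv1 operates on (batch, channel, time) tensors: it doubles the channel dimension, and nn.functional.glu halves it again by gating one half with the sigmoid of the other. A standalone shape check (channel count and lengths chosen arbitrarily):

```python
import torch
from torch import nn

channels = 64
pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)

x = torch.randn(2, channels, 100)    # (batch, channel, time)
x = pointwise_conv1(x)               # (batch, 2*channel, time)
assert x.shape == (2, 128, 100)
x = nn.functional.glu(x, dim=1)      # (batch, channel, time)
assert x.shape == (2, 64, 100)
```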