Commit 30aed6f9 by xuchen

optimize the modules (conformer, relative position encoding)

parent 876daed6
......@@ -9,6 +9,7 @@ from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
from .convolution import ConvolutionModule
from .downsample_convolution import DownSampleConvolutionModule
from .conv_tbc import ConvTBC
from .cross_entropy import cross_entropy
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
......@@ -54,6 +55,7 @@ __all__ = [
"ConvTBC",
"CreateLayerHistory",
"cross_entropy",
"DownSampleConvolutionModule",
"DownsampledMultiHeadAttention",
"DynamicConv1dTBC",
"DynamicConv",
......
......@@ -80,7 +80,9 @@ class ConformerEncoderLayer(nn.Module):
if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim)
self.conv_module = ConvolutionModule(self.embed_dim, args.cnn_module_kernel, self.activation_fn)
self.conv_module = ConvolutionModule(
self.embed_dim,
args.cnn_module_kernel)
self.final_norm = LayerNorm(self.embed_dim)
else:
self.conv_norm = None
......
......@@ -24,9 +24,7 @@ class ConvolutionModule(nn.Module):
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = Swish(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True):
"""Construct an ConvolutionModule object.
Args:
......@@ -44,18 +42,10 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
......@@ -89,39 +79,21 @@ class ConvolutionModule(nn.Module):
self,
x: torch.Tensor,
mask_pad: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding
cache (torch.Tensor): left context cache, it is only
used in causal convolution
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
new_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
zero_mask_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
# mask batch padding
if mask_pad is not None:
x.masked_fill_(~new_pad, 0.0)
if self.lorder > 0:
if cache is None:
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0)
assert cache.size(1) == x.size(1)
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
x.masked_fill_(zero_mask_pad, 0.0)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, time)
......@@ -136,7 +108,7 @@ class ConvolutionModule(nn.Module):
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if new_pad is not None:
x.masked_fill_(~new_pad, 0.0)
if zero_mask_pad is not None:
x.masked_fill_(zero_mask_pad, 0.0)
return x.transpose(1, 2)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
"""ConvolutionModule definition."""
from typing import Optional, Tuple
import torch
from torch import nn
from fairseq.modules.layer_norm import LayerNorm
class Swish(nn.Module):
    """Swish activation: every element is scaled by its own sigmoid."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x * sigmoid(x)`` element-wise and return the result."""
        gate = torch.sigmoid(x)
        return x * gate
class DownSampleConvolutionModule(nn.Module):
    """Strided depthwise-separable convolution block that shortens the time axis.

    Pipeline: pointwise conv -> depthwise conv (stride ``stride``, no padding)
    -> normalisation -> Swish -> pointwise conv. When a padding mask is given,
    padded frames are zeroed before the first and after the last convolution.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = Swish(),
                 norm: str = "batch_norm",
                 stride: int = 1,
                 causal: bool = False,
                 bias: bool = True):
        """Construct a DownSampleConvolutionModule object.

        Args:
            channels: Number of input and output channels of every conv layer.
            kernel_size: Kernel size of the depthwise convolution.
            activation: Accepted for interface compatibility only; the module
                always uses ``Swish`` regardless of this argument.
            norm: ``"batch_norm"`` (``nn.BatchNorm1d``) or ``"layer_norm"``
                (``LayerNorm``) applied after the depthwise convolution.
            stride: Stride of the depthwise convolution, i.e. the
                down-sampling factor along time.
            causal: Accepted for interface compatibility only; unused.
            bias: Whether the convolution layers carry a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        # The depthwise conv is deliberately unpadded, so the output has
        # floor((T - kernel_size) / stride) + 1 frames.
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=stride,
            padding=0,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = LayerNorm(channels)

        self.stride = stride

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = Swish()

    def _downsample_padding_mask(self,
                                 mask_pad: torch.Tensor,
                                 out_len: int) -> torch.Tensor:
        """Build the padding mask that matches the depthwise conv output.

        Valid lengths are taken as the count of ``False`` entries per row of
        ``mask_pad`` (padding is assumed to sit at the end of each sequence).
        An output frame is kept only if its whole receptive field lies inside
        the valid part of the input.
        """
        kernel_size = self.depthwise_conv.kernel_size[0]
        in_lengths = (~mask_pad).sum(-1)
        # floor((len - k) / s) + 1 == floor((len - k + s) / s); clamping at 0
        # covers sequences shorter than the kernel. When kernel_size == stride
        # this reduces to floor(len / stride).
        numerator = (in_lengths - kernel_size + self.stride).clamp(min=0)
        out_lengths = (numerator / self.stride).long()
        positions = torch.arange(out_len, device=mask_pad.device)
        return positions.unsqueeze(0) >= out_lengths.unsqueeze(1)

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the down-sampling convolution module.

        Args:
            x: Input tensor laid out as (time, batch, channels); it is
                permuted to (batch, channels, time) for the convolutions.
            mask_pad: Optional padding mask of shape (batch, time) with
                ``True`` on padded frames (presumably a bool tensor, since
                it is negated with ``~``). ``None`` disables all masking.
            cache: Accepted for interface compatibility only; unused.

        Returns:
            Output tensor laid out as (reduced_time, batch, channels).
        """
        # (time, batch, channels) -> (batch, channels, time)
        x = x.permute(1, 2, 0)

        if mask_pad is not None:
            # Out-of-place so the caller's tensor (x is a view of it) is not
            # overwritten; the mask broadcasts over the channel dimension.
            x = x.masked_fill(mask_pad.unsqueeze(1), 0.0)

        x = self.pointwise_conv1(x)  # (batch, channels, time)

        # Strided 1D depthwise conv shortens the time axis.
        x = self.depthwise_conv(x)

        # LayerNorm normalises the last dimension, so channels go last for it.
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)

        x = self.pointwise_conv2(x)

        if mask_pad is not None:
            # Re-derive the padding mask at the reduced time resolution.
            out_pad = self._downsample_padding_mask(mask_pad, x.size(-1))
            x = x.masked_fill(out_pad.unsqueeze(1), 0.0)

        # (batch, channels, reduced_time) -> (reduced_time, batch, channels)
        return x.permute(2, 0, 1)
......@@ -86,7 +86,11 @@ class ReducedMultiheadAttention(nn.Module):
self.add_zero_attn = add_zero_attn
self.sample_ratio = sample_ratio
if self.sample_ratio > 1:
self.sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
self.sr = nn.Conv1d(embed_dim, embed_dim,
kernel_size=sample_ratio,
stride=sample_ratio,
# padding=(sample_ratio - 1) // 2
)
self.norm = nn.LayerNorm(embed_dim)
self.reset_parameters()
......@@ -307,6 +311,16 @@ class ReducedMultiheadAttention(nn.Module):
key_padding_mask = None
if key_padding_mask is not None:
if self.sample_ratio > 1:
lengths = (~key_padding_mask).sum(-1)
lengths = (lengths / self.sample_ratio).long()
# lengths = ((lengths.float() - 1) / self.sample_ratio + 1).floor().long()
max_length = src_len
assert max_length >= max(lengths), (max_length, max(lengths))
mask = torch.arange(max_length).to(lengths.device).view(1, max_length)
key_padding_mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_length)
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
......
......@@ -300,7 +300,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]
return x
matrix_bd = rel_shift(matrix_bd)
# matrix_bd = rel_shift(matrix_bd)
attn_weights = (matrix_ac + matrix_bd) * self.scaling
......@@ -455,7 +455,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim: 2 * dim
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论