Commit ebd1be88 by xuchen
parent 0c7e71c7

fix the implementation of the relative position encoding in conformer, optimize the code of ctc loss
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -13,7 +13,7 @@ label_smoothing: 0.1
subsampling-type: conv1d
subsmapling-layers: 2
-subsampling-filter: 1024
+subsampling-filter: 512
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
...@@ -32,3 +32,9 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
-encoder-attention-type: rel_selfattn
+encoder-attention-type: rel_pos
#encoder-attention-type: relative
#max-encoder-relative-length: 100
...@@ -4,14 +4,20 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
-lr: 2e-3
+lr: 0.0015
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
-conv-kernel-sizes: 5,5
+subsampling-type: conv1d
-conv-channels: 704
+subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
...@@ -22,4 +28,5 @@ encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
-encoder-attention-type: rel_selfattn
+encoder-activation-fn: swish
+encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
+subsampling-type: conv1d
-conv-channels: 1024
+subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
+subsampling-type: conv1d
-conv-channels: 1024
+subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
-intermedia-ctc-weight: 0.15
+intermedia-ctc-weight: 0.1
-ctc-self-distill-weight: 1
+ctc-self-distill-weight: 0
+post-process: sentencepiece
\ No newline at end of file
...@@ -5,21 +5,28 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
+subsampling-type: conv1d
-conv-channels: 1024
+subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
+#load-pretrained-encoder-from:
\ No newline at end of file
...@@ -6,7 +6,6 @@ gpu_num=8
update_freq=1
max_tokens=40000
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
...@@ -14,12 +13,12 @@ extra_parameter=
exp_tag=
-#config_list=(base)
+config_list=(base ctc)
-#config_list=(ctc)
+config_list=(purectc)
#config_list=(base conformer)
#config_list=(pds_base_16)
-config_list=(pds_base_16 conformer rpr)
+#config_list=(pds_base_16 conformer rpr)
# exp full name
exp_name=
......
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
+subsampling-type: conv1d
-conv-channels: 1024
+subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
-intermedia-ctc-weight: 0.15
+intermedia-ctc-weight: 0.1
-ctc-self-distill-weight: 1
+ctc-self-distill-weight: 0
+post-process: sentencepiece
\ No newline at end of file
...@@ -85,34 +85,45 @@ class InterAdapter(nn.Module):
        if self.adapter_type == "shrink":
            self.ctc_compress = getattr(CTCCompressStrategy, strategy)
            logger.info("CTC Compress Strategy: %s" % strategy)
        elif self.adapter_type == "league":
            self.distribution_cutoff = strategy
            if self.distribution_cutoff != -1:
                logger.info("Distribution cutoff: %d" % int(strategy))

    def forward(self, x, padding):
        representation, distribution = x
        dim1, dim2, dim = representation.size()
        org_distribution = distribution
-        if distribution is not None:
-            distribution = distribution.view(-1, distribution.size(-1))
        lengths = (~padding).long().sum(-1)

        if self.adapter_type == "linear":
            out = self.linear_adapter(representation)
        elif self.adapter_type == "context":
            distribution = distribution.view(-1, distribution.size(-1))
            out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
        elif self.adapter_type == "league":
            linear_out = self.linear_adapter(representation)
            if self.distribution_cutoff != -1:
                cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
                threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
                distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
            distribution = distribution.view(-1, distribution.size(-1))
            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
            out = linear_out + soft_out
        elif self.adapter_type == "gated_league":
            linear_out = self.linear_adapter(representation)
            distribution = distribution.view(-1, distribution.size(-1))
            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
            coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
            out = coef * linear_out + (1 - coef) * soft_out
        elif self.adapter_type == "inter_league":
            distribution = distribution.view(-1, distribution.size(-1))
            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
            out = representation + soft_out
......
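For reference, a minimal standalone sketch (dummy tensors and made-up sizes, not the repository's classes) of what the new distribution cutoff in the "league" adapter computes: keep only the top-cutoff CTC probabilities per frame, then embed the truncated posterior through the adapter's embedding matrix.

import torch

# assumed toy sizes: T frames, B sentences, V vocabulary entries, dim adapter width
T, B, V, dim, cutoff = 5, 2, 10, 8, 3
distribution = torch.softmax(torch.randn(T, B, V), dim=-1)  # stand-in for the CTC posterior
embed_weight = torch.randn(V, dim)                          # stand-in for embed_adapter.weight

# keep the `cutoff` largest probabilities per position, zero out the rest
threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff + 1]
sparse = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))

# soft embedding: expected embedding under the truncated posterior, reshaped back to (T, B, dim)
soft_out = torch.mm(sparse.view(-1, V), embed_weight).view(T, B, dim)
print(soft_out.shape)  # torch.Size([5, 2, 8])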
...@@ -19,6 +19,7 @@ from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    RelPositionalEncoding,
    S2TTransformerEncoderLayer,
    DynamicLinearCombination,
)
...@@ -132,6 +133,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                "reduced",
                "rel_selfattn",
                "relative",
                "rel_pos",
                "rope",
                "abs"
            ],
            help="transformer encoder self-attention layer type"
        )
...@@ -299,6 +303,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
        # Conformer setting
        parser.add_argument(
            "--encoder-activation-fn",
            type=str,
            default="relu",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--macaron-style",
            default=False,
            type=bool,
...@@ -369,6 +380,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            type=str,
            help="type of intermedia adapter",
        )
        parser.add_argument(
            "--intermedia-distribution-cutoff",
            default=-1,
            type=int,
            help="cutoff of the distribution",
        )
        pass

    @classmethod
...@@ -477,8 +494,16 @@ class S2TTransformerEncoder(FairseqEncoder):
        self.subsample = subsampling(args)

        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
        if self.attn_type == "rel_pos":
            self.embed_positions = RelPositionalEncoding(
                args.max_source_positions, args.encoder_embed_dim
            )
        elif self.attn_type == "rope":
            self.embed_positions = None
        else:  # Use absolute positional embedding
            self.embed_positions = PositionalEmbedding(
-                args.max_source_positions, dim, self.padding_idx
+                args.max_source_positions, args.encoder_embed_dim, self.padding_idx
            )

        self.layers = nn.ModuleList(
...@@ -540,6 +565,8 @@ class S2TTransformerEncoder(FairseqEncoder):
            strategy = None
            if args.intermedia_adapter == "shrink":
                strategy = getattr(args, "ctc_compress_strategy", None)
            elif args.intermedia_adapter == "league":
                strategy = getattr(args, "intermedia_distribution_cutoff", -1)
            self.adapter = InterAdapter(dim, args.intermedia_adapter,
                                        task.source_dictionary, strategy=strategy)
...@@ -586,11 +613,20 @@
        # padding and position embedding
        encoder_padding_mask = lengths_to_padding_mask(input_lengths)

        if self.attn_type == "rel_pos":
            positions = self.embed_positions(x)
        elif self.attn_type == "rope":
            positions = None
        else:
            positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
            if self.attn_type != "rel_selfattn":
                x += positions
                positions = None

        x = self.dropout_module(x)
-        positions = self.dropout_module(positions)
+        # positions = self.dropout_module(positions)

        # add emb into history
        if self.history is not None:
...@@ -742,8 +778,13 @@ class TransformerDecoderScriptable(TransformerDecoder):
@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
    # Convolutional subsampler
-    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
-    args.conv_channels = getattr(args, "conv_channels", 1024)
+    args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
+    args.subsampling_layers = getattr(args, "subsampling_layers", 2)
    args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
    args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
    args.subsampling_stride = getattr(args, "subsampling_stride", 2)
    args.subsampling_norm = getattr(args, "subsampling_norm", "none")
    args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
...@@ -791,6 +832,7 @@ def base_architecture(args):
    args.ctc_layer = getattr(args, "ctc_layer", 0)

    # Conformer
    args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
    args.macaron_style = getattr(args, "macaron_style", False)
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
......
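A rough sketch of the positional-encoding contract the encoder now follows, with plain tensors standing in for the real modules (all sizes and the stand-in tensors below are illustrative assumptions, not repository code):

import torch

T, B, C = 6, 2, 16
x = torch.randn(T, B, C)                       # subsampled encoder input, T x B x C

attn_type = "rel_pos"                          # or "rope", "rel_selfattn", "selfattn", ...
if attn_type == "rel_pos":
    # relative scheme: a (2T-1) x 1 x C table is handed to the attention layers,
    # x itself is left untouched
    positions = torch.randn(2 * T - 1, 1, C)   # stand-in for RelPositionalEncoding(x)
elif attn_type == "rope":
    # rotary embeddings are applied inside the attention module, nothing extra is passed
    positions = None
else:
    # absolute scheme: positions are added to x and not passed any further
    positions = torch.randn(T, B, C)           # stand-in for PositionalEmbedding(...)
    if attn_type != "rel_selfattn":
        x = x + positions
        positions = None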
...@@ -10,7 +10,6 @@ from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
-from .convolution import ConvolutionModule
from .downsample_convolution import DownSampleConvolutionModule
from .conv_tbc import ConvTBC
from .cross_entropy import cross_entropy
...@@ -30,6 +29,7 @@ from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .local_multihead_attention import LocalMultiheadAttention
from .location_attention import LocationAttention
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .reduced_multihead_attention import ReducedMultiheadAttention
...@@ -44,6 +44,16 @@ from .transpose_last import TransposeLast
from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
from .espnet_multihead_attention import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)
from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import (
    RelPositionalEncoding,
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
...@@ -77,6 +87,7 @@ __all__ = [
    "LightweightConv",
    "LinearizedConvolution",
    "LocalMultiheadAttention",
    "MultiheadAttention",
    "PositionalEmbedding",
    "PDSTransformerEncoderLayer",
...@@ -96,4 +107,10 @@ __all__ = [
    "TransposeLast",
    "VGGBlock",
    "unfold1d",
    "ESPNETMultiheadedAttention",
    "PositionalEmbedding",
    "RelPositionMultiHeadedAttention",
    "RelPositionalEncoding",
    "RotaryPositionalEmbedding",
    "RotaryPositionMultiHeadedAttention",
]
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable


def get_activation_fn(activation: str) -> Callable:
    """ Returns the activation function corresponding to `activation` """
    from fairseq.modules import gelu, gelu_accurate

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        return torch.nn.SiLU
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


def get_activation_class(activation: str, dim=None):
......
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright 2021 Mobvoi Inc. All Rights Reserved.
-# Author: di.wu@mobvoi.com (DI WU)
-"""ConvolutionModule definition."""
-from typing import Optional, Tuple

import torch
from torch import nn
-from fairseq.modules.layer_norm import LayerNorm
from fairseq.modules.activations import get_activation_class


class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Conformer model."""
    """Convolution block used in the conformer block"""

    def __init__(
        self,
        embed_dim,
        channels,
        depthwise_kernel_size,
        dropout,
        activation_fn="swish",
        bias=False,
        export=False,
    ):
        """
        Args:
            embed_dim: Embedding dimension
            channels: Number of channels in depthwise conv layers
            depthwise_kernel_size: Depthwise conv layer kernel size
            dropout: dropout value
            activation_fn: Activation function to use after depthwise convolution kernel
            bias: If bias should be added to conv layers
            export: If layernorm should be exported to jit
        """
        super(ConvolutionModule, self).__init__()
        assert (
            depthwise_kernel_size - 1
        ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        self.pointwise_conv1 = torch.nn.Conv1d(
            embed_dim,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.glu = torch.nn.GLU(dim=1)
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            depthwise_kernel_size,
            stride=1,
            padding=(depthwise_kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.batch_norm = nn.BatchNorm1d(channels)
        self.activation = get_activation_class(activation_fn)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Input of shape B X T X C
        Returns:
            Tensor of shape B X T X C
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = self.glu(x)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)

        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)
#
# class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model."""
# def __init__(self,
# channels: int,
# kernel_size: int = 15,
# norm: str = "batch_norm",
# bias: bool = True):
# """Construct an ConvolutionModule object.
# Args:
# channels (int): The number of channels of conv layers.
# kernel_size (int): Kernel size of conv layers.
# causal (int): Whether use causal convolution or not
# """
# super().__init__()
#
# self.pointwise_conv1 = nn.Conv1d(
# channels,
# 2 * channels,
# kernel_size=1,
# stride=1,
# padding=0,
# bias=bias,
# )
#
# # kernel_size should be an odd number for none causal convolution
# assert (kernel_size - 1) % 2 == 0
# padding = (kernel_size - 1) // 2
#
# self.depthwise_conv = nn.Conv1d(
# channels,
# channels,
# kernel_size,
# stride=1,
# padding=padding,
# groups=channels,
# bias=bias,
# )
#
# assert norm in ['batch_norm', 'layer_norm']
# if norm == "batch_norm":
# self.use_layer_norm = False
# self.norm = nn.BatchNorm1d(channels)
# else:
# self.use_layer_norm = True
# self.norm = LayerNorm(channels)
#
# self.pointwise_conv2 = nn.Conv1d(
# channels,
# channels,
# kernel_size=1,
# stride=1,
# padding=0,
# bias=bias,
# )
# self.activation = get_activation_class("swish")
#
# def forward(
# self,
# x: torch.Tensor,
# mask_pad: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """Compute convolution module.
# Args:
# x (torch.Tensor): Input tensor (#batch, time, channels).
# mask_pad (torch.Tensor): used for batch padding
# Returns:
# torch.Tensor: Output tensor (#batch, time, channels).
# """
# # exchange the temporal dimension and the feature dimension
# x = x.transpose(1, 2)
#
# # zero_mask_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
# # # mask batch padding
# # if mask_pad is not None:
# # x.masked_fill_(zero_mask_pad, 0.0)
#
# # GLU mechanism
# x = self.pointwise_conv1(x) # (batch, 2*channel, time)
# x = nn.functional.glu(x, dim=1) # (batch, channel, time)
#
# # 1D Depthwise Conv
# x = self.depthwise_conv(x)
# if self.use_layer_norm:
# x = x.transpose(1, 2)
# x = self.activation(self.norm(x))
# if self.use_layer_norm:
# x = x.transpose(1, 2)
# x = self.pointwise_conv2(x)
#
# # # mask batch padding
# # if zero_mask_pad is not None:
# # x.masked_fill_(zero_mask_pad, 0.0)
#
# return x.transpose(1, 2)
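A usage sketch for the rewritten module, with sizes borrowed from the small conformer config above (the sizes and the standalone construction are assumptions for illustration; it presumes get_activation_class returns a ready-to-call activation module). Unlike the previous wenet-style module it no longer takes a mask_pad argument, so padding has to be handled outside.

import torch
from fairseq.modules.convolution import ConvolutionModule

module = ConvolutionModule(
    embed_dim=176,              # encoder-embed-dim in the small conformer config
    channels=176,
    depthwise_kernel_size=31,   # cnn-module-kernel
    dropout=0.1,
    activation_fn="swish",
)
x = torch.randn(4, 50, 176)     # (batch, time, dim)
out = module(x)                 # same shape: (4, 50, 176)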
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import torch
from torch import nn
from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)
class ESPNETMultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
"""
def __init__(self, n_feat, n_head, dropout):
"""Construct an MultiHeadedAttention object."""
super(ESPNETMultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward_qkv(self, query, key, value, **kwargs):
"""Transform query, key and value.
Args:
query: Query tensor B X T1 X C
key: Key tensor B X T2 X C
value: Value tensor B X T2 X C
Returns:
torch.Tensor: Transformed query tensor B X n_head X T1 X d_k
torch.Tensor: Transformed key tensor B X n_head X T2 X d_k
torch.Tensor: Transformed value tensor B X n_head X T2 X d_k
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value: Transformed value B X n_head X T2 X d_k.
scores: Attention score B X n_head X T1 X T2
mask: Mask T2 X B
Returns:
torch.Tensor: Transformed value B X T1 X d_model
weighted by the attention score B X T1 X T2
"""
n_batch = value.size(0)
if mask is not None:
scores = scores.masked_fill(
mask.unsqueeze(1).unsqueeze(2).to(bool),
float("-inf"), # (batch, head, time1, time2)
)
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor T X B X C
key (torch.Tensor): Key tensor T X B X C
value (torch.Tensor): Value tensor T X B X C
mask (torch.Tensor): Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X D.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x: Input tensor B X n_head X T X 2T-1
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor B X 2T-1 X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
def __init__(
self,
n_feat,
n_head,
dropout,
precision,
rotary_emd_base=10000,
):
"""Construct an RotaryPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
precision = torch.float
self.rotary_ndims = self.d_k # also try self.d_k//2
if precision == "fp16":
precision = torch.half
self.rotary_emb = RotaryPositionalEmbedding(
self.rotary_ndims, base=rotary_emd_base, precision=precision
)
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute rotary position attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X D.
Notes:
Assumes self attn
"""
T, B, C = value.size()
query = query.view(T, B, self.h, self.d_k)
key = key.view(T, B, self.h, self.d_k)
value = value.view(T, B, self.h, self.d_k)
cos, sin = self.rotary_emb(value, seq_len=T)
query, key = apply_rotary_pos_emb(
query, key, cos, sin, offset=0
) # offset is based on layer_past
query = query.view(T, B, self.h * self.d_k)
key = key.view(T, B, self.h * self.d_k)
value = value.view(T, B, self.h * self.d_k)
# TBD to BTD
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
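A quick shape check with made-up sizes, showing how RelPositionMultiHeadedAttention consumes the 2*T-1 relative table produced by the RelPositionalEncoding module defined further below:

import torch
from fairseq.modules import RelPositionalEncoding, RelPositionMultiHeadedAttention

T, B, C, heads = 20, 2, 64, 4
x = torch.randn(T, B, C)                          # encoder states, T x B x C

pos_enc = RelPositionalEncoding(max_len=100, d_model=C)
pos_emb = pos_enc(x)                              # (2*T-1) x 1 x C relative positions

attn = RelPositionMultiHeadedAttention(n_feat=C, n_head=heads, dropout=0.1)
out, _ = attn(x, x, x, pos_emb=pos_emb, key_padding_mask=None)
print(out.shape)                                  # torch.Size([20, 2, 64])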
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch
import torch.nn.functional as F
class LocationAttention(nn.Module):
"""
Attention-Based Models for Speech Recognition
https://arxiv.org/pdf/1506.07503.pdf
:param int encoder_dim: # projection-units of encoder
:param int decoder_dim: # units of decoder
:param int attn_dim: attention dimension
:param int conv_dim: # channels of attention convolution
:param int conv_kernel_size: filter size of attention convolution
"""
def __init__(
self,
attn_dim,
encoder_dim,
decoder_dim,
attn_state_kernel_size,
conv_dim,
conv_kernel_size,
scaling=2.0,
):
super(LocationAttention, self).__init__()
self.attn_dim = attn_dim
self.decoder_dim = decoder_dim
self.scaling = scaling
self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(
attn_state_kernel_size,
conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size,
bias=False,
)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))
self.proj_enc_out = None # cache
def clear_cache(self):
self.proj_enc_out = None
def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
"""
:param torch.Tensor encoder_out: padded encoder hidden state B x T x D
:param torch.Tensor encoder_padding_mask: encoder padding mask
:param torch.Tensor decoder_h: decoder hidden state B x D
:param torch.Tensor attn_prev: previous attention weight B x K x T
:return: attention weighted encoder state (B, D)
:rtype: torch.Tensor
:return: previous attention weights (B x T)
:rtype: torch.Tensor
"""
bsz, seq_len, _ = encoder_out.size()
if self.proj_enc_out is None:
self.proj_enc_out = self.proj_enc(encoder_out)
# B x K x T -> B x C x T
attn = self.conv(attn_state)
# B x C x T -> B x T x C -> B x T x D
attn = self.proj_attn(attn.transpose(1, 2))
if decoder_h is None:
decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)
out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)
out.masked_fill_(encoder_padding_mask, -float("inf"))
w = F.softmax(self.scaling * out, dim=1)
c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1)
return c, w
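An illustrative call with assumed sizes (a single previous-attention channel and no padded frames):

import torch
from fairseq.modules import LocationAttention

B, T, enc_dim, dec_dim = 2, 30, 64, 32
attn = LocationAttention(
    attn_dim=48, encoder_dim=enc_dim, decoder_dim=dec_dim,
    attn_state_kernel_size=1, conv_dim=10, conv_kernel_size=25,
)
encoder_out = torch.randn(B, T, enc_dim)
padding_mask = torch.zeros(B, T, dtype=torch.bool)   # no padded frames
decoder_h = torch.randn(B, dec_dim)
attn_state = torch.zeros(B, 1, T)                    # previous attention weights

context, weights = attn(encoder_out, padding_mask, decoder_h, attn_state)
print(context.shape, weights.shape)                  # torch.Size([2, 64]) torch.Size([2, 30])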
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import math
import torch
class PositionalEncoding(nn.Module):
"""Positional encoding.
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
reverse: Whether to reverse the input position.
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.reverse = reverse
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor B X T X C
Returns:
torch.Tensor: Encoded tensor B X T X C
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1)]
return self.dropout(x)
class RelPositionalEncoding(nn.Module):
"""Relative positional encoding module (new implementation).
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, max_len, d_model):
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x : Input tensor T X B X C.
Returns:
torch.Tensor: Encoded tensor T X B X C.
"""
x = x.transpose(0, 1) # Change TBC to BTC
self.extend_pe(x)
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
pos_emb = pos_emb.transpose(0, 1) # change to TBC
return pos_emb
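A small check with illustrative sizes of the property the rel_pos attention relies on: the module returns one vector per relative offset T-1, ..., 0, ..., -(T-1), i.e. 2*T-1 vectors centred on offset zero.

import torch
from fairseq.modules.positional_encoding import RelPositionalEncoding

T, C = 7, 16
enc = RelPositionalEncoding(max_len=50, d_model=C)
pos_emb = enc(torch.zeros(T, 1, C))   # the input is only inspected for its length and dtype
print(pos_emb.shape)                  # torch.Size([13, 1, 16]) == (2*T - 1, 1, C)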
import torch
class RotaryPositionalEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
precision: precision to use for numerical values
"""
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
def forward(self, x, seq_len=None):
"""
Args:
x: Input x with T X B X C
seq_len: Sequence length of input x
"""
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat(
(-x2, x1), dim=x1.ndim - 1
) # dim=-1 triggers a bug in earlier torch versions
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
cos, sin = (
cos[offset : q.shape[0] + offset, ...],
sin[offset : q.shape[0] + offset, ...],
)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
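A sketch with assumed head sizes: rotating queries and keys keeps their shapes, so no positional tensor has to be threaded through the encoder for the rope attention type.

import torch
from fairseq.modules.rotary_positional_embedding import (
    RotaryPositionalEmbedding,
    apply_rotary_pos_emb,
)

T, B, H, d_k = 10, 2, 4, 16
q = torch.randn(T, B, H, d_k)
k = torch.randn(T, B, H, d_k)

rope = RotaryPositionalEmbedding(dim=d_k, precision=torch.float)
cos, sin = rope(k, seq_len=T)              # each of shape (T, 1, 1, d_k)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, offset=0)
print(q_rot.shape, k_rot.shape)            # both torch.Size([10, 2, 4, 16])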
import torch
import torch.nn as nn
from typing import List

from fairseq.modules.activations import Swish
from fairseq.modules.layer_norm import LayerNorm
...@@ -46,6 +48,61 @@ def get_norm(norm_type, size, transpose=False):
        raise RuntimeError("normalization type {} not supported".format(norm_type))
class Conv1dSubsampler(nn.Module):
"""Convolutional subsampler: a stack of 1D convolution (along temporal
dimension) followed by non-linear activation via gated linear units
(https://arxiv.org/abs/1911.08460)
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(
self,
in_channels: int,
mid_channels: int,
out_channels: int,
kernel_sizes: List[int] = (3, 3),
):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
k,
stride=2,
padding=k // 2,
)
for i, k in enumerate(kernel_sizes)
)
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
for _ in range(self.n_layers):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out
def forward(self, src_tokens, src_lengths):
bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
inner_x = []
for conv in self.conv_layers:
x = conv(x)
x = nn.functional.glu(x, dim=1)
inner_x.append(x)
_, _, out_seq_len = x.size()
# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
out_inner_x = []
for x in inner_x:
out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
# fairseq style
class Conv1dSubsampling(nn.Module):
    """Conv1d Subsampling Block
...@@ -74,12 +131,14 @@ class Conv1dSubsampling(nn.Module):
        # Layers
        self.layers = nn.ModuleList([nn.Sequential(
-            nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1],
-                      filters[layer_id] * 2 if act == "glu" else filters[layer_id],
+            nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1] // 2 if act == "glu" else filters[layer_id - 1],
+                      filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
                      kernel_size,
                      stride=stride,
                      padding=(kernel_size - 1) // 2),
-            get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False),
+            get_norm(norm,
+                     filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
+                     transpose=True if norm == "layer" else False),
            get_activation_class(act, dim=1)
        ) for layer_id in range(num_layers)])
...@@ -126,12 +185,14 @@ class Conv2dSubsampling(nn.Module):
        # Conv 2D Subsampling Layers
        self.layers = nn.ModuleList([nn.Sequential(
-            nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1],
-                      filters[layer_id] * 2 if act =="glu" else filters[layer_id],
+            nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1] // 2 if act == "glu" else filters[layer_id - 1],
+                      filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
                      kernel_size,
                      stride=stride,
                      padding=(kernel_size - 1) // 2),
-            get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False),
+            get_norm(norm,
+                     filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
+                     transpose=True if norm == "layer" else False),
            get_activation_class(act, dim=1)
        ) for layer_id in range(num_layers)])
        self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1])
...@@ -139,7 +200,7 @@ class Conv2dSubsampling(nn.Module):
    def forward(self, x, x_len):
        # (B, T, D) -> (B, D, T) -> (B, 1, D, T)
-        x = x.tranpose(1, 2).unsqueeze(dim=1)
+        x = x.transpose(1, 2).unsqueeze(dim=1)

        # Layers
        for layer in self.layers:
......
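A toy check (not repository code) of the bookkeeping the corrected subsampling layers do: GLU over the channel dimension halves the channel count, which is why the next layer's input width becomes filters // 2 when act == "glu".

import torch
import torch.nn as nn

x = torch.randn(2, 64, 100)        # (batch, channels, time)
conv = nn.Conv1d(64, 2 * 32, kernel_size=5, stride=2, padding=2)
y = nn.functional.glu(conv(x), dim=1)
print(y.shape)                     # torch.Size([2, 32, 50]): 2*32 channels halved, time halved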
...@@ -17,7 +17,9 @@ from typing import Callable, Dict, List, Optional
import torch
import torch.nn.functional as F

from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor
...@@ -514,6 +516,8 @@ def get_activation_fn(activation: str) -> Callable:
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        return torch.nn.SiLU
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
...@@ -526,6 +530,7 @@ def get_available_activation_fns() -> List:
        "gelu_accurate",
        "tanh",
        "linear",
        "swish",
    ]
......