Commit 5160a9f5 by xuchen

optimize the code structure and add support for simultaneous speech translation

parent 30aed6f9
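In brief: S2TConformerModel.add_args now spells out its full argument set instead of delegating to S2TTransformerModel.add_args, the base S2TTransformerModel gains the Conformer options plus a new --simul flag, and the encoder layer gains macaron-style and convolution-module branches. The --simul flag switches build_decoder to fairseq's monotonic-attention decoder. The following minimal sketch of that selection logic uses the class name and import path from the diff below; fallback_decoder_cls is a hypothetical stand-in for TransformerDecoderScriptable, and the snippet is only meaningful inside this repository:

def build_decoder_sketch(args, task, embed_tokens, fallback_decoder_cls):
    # fallback_decoder_cls stands in for TransformerDecoderScriptable,
    # which s2t_transformer.py uses when --simul is not set.
    if getattr(args, "simul", False):
        # monotonic-attention decoder from fairseq's simultaneous_translation examples
        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder,
        )
        return TransformerMonotonicDecoder(args, task.target_dictionary, embed_tokens)
    return fallback_decoder_cls(args, task.target_dictionary, embed_tokens)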
@@ -3,7 +3,7 @@
import logging
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
@@ -31,8 +31,220 @@ class S2TConformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
@@ -44,7 +256,7 @@ class S2TConformerModel(S2TTransformerModel):
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
...
@@ -19,6 +19,7 @@ from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
ConformerEncoderLayer,
CreateLayerHistory,
)
from torch import Tensor
@@ -303,6 +304,51 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
pass
@classmethod
@@ -321,7 +367,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_decoder(cls, args, task, embed_tokens):
if getattr(args, "simul", False):
from examples.simultaneous_translation.models.transformer_monotonic_attention import (
TransformerMonotonicDecoder,
)
decoder = TransformerMonotonicDecoder(args, task.target_dictionary, embed_tokens)
else:
decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None): if getattr(args, "load_pretrained_decoder_from", None):
logger.info( logger.info(
f"loaded pretrained decoder from: " f"loaded pretrained decoder from: "
@@ -506,8 +561,6 @@ class S2TTransformerEncoder(FairseqEncoder):
"src_lengths": [],
}
# "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
@@ -602,12 +655,17 @@ class TransformerDecoderScriptable(TransformerDecoder):
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", True)
args.use_cnn_module = getattr(args, "use_cnn_module", True)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
@@ -637,6 +695,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
...
@@ -15,6 +15,7 @@ from fairseq.modules import (
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
LocalMultiheadAttention,
ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
@@ -59,6 +60,43 @@ class PyramidTransformerEncoderLayer(nn.Module):
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
if args.macaron_style:
self.macaron_fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_norm = LayerNorm(self.embed_dim)
self.ffn_scale = 0.5
else:
self.macaron_fc1 = None
self.macaron_fc2 = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim)
self.conv_module = ConvolutionModule(
self.embed_dim,
args.cnn_module_kernel)
self.final_norm = LayerNorm(self.embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
@@ -191,6 +229,16 @@ class PyramidTransformerEncoderLayer(nn.Module):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
x = residual + self.ffn_scale * self.dropout_module(x)
if not self.normalize_before:
x = self.macaron_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
@@ -219,6 +267,17 @@ class PyramidTransformerEncoderLayer(nn.Module):
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# convolution module
if self.conv_module is not None:
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.conv_norm(x)
x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
if not self.normalize_before:
x = self.conv_norm(x)
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
@@ -226,9 +285,13 @@ class PyramidTransformerEncoderLayer(nn.Module):
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.conv_module is not None:
x = self.final_norm(x)
return x
...
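For reference, the encoder-layer changes above implement the standard Conformer block ordering: a half-scaled macaron FFN, self-attention, a convolution module, a second half-scaled FFN, and a closing LayerNorm when the convolution module is enabled. Below is a minimal, self-contained sketch of that ordering; the modules are plain PyTorch stand-ins (not fairseq's ConvolutionModule or relative-attention variants), and only the ordering, the pre-norm placement, and the 0.5 FFN scaling follow the diff.

import torch
import torch.nn as nn


class ConformerBlockSketch(nn.Module):
    def __init__(self, dim=256, ffn_dim=1024, heads=4, kernel=31, dropout=0.1):
        super().__init__()
        self.macaron_norm = nn.LayerNorm(dim)
        self.macaron_ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, dim))
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout)
        self.conv_norm = nn.LayerNorm(dim)
        # depthwise conv as a stand-in for the ConvolutionModule used in the diff
        self.conv = nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, dim))
        self.final_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.ffn_scale = 0.5  # macaron style halves both FFN residual branches

    def forward(self, x):  # x: T x B x C, as in the encoder layer
        # 1) macaron half-FFN (pre-norm), residual scaled by ffn_scale
        x = x + self.ffn_scale * self.dropout(self.macaron_ffn(self.macaron_norm(x)))
        # 2) self-attention
        y = self.attn_norm(x)
        x = x + self.dropout(self.attn(y, y, y, need_weights=False)[0])
        # 3) convolution module (the diff runs it on B x T x C between transposes)
        y = self.conv_norm(x).permute(1, 2, 0)  # T x B x C -> B x C x T
        x = x + self.dropout(self.conv(y).permute(2, 0, 1))
        # 4) second half-FFN, scaled by ffn_scale before the residual add
        x = x + self.ffn_scale * self.dropout(self.ffn(self.ffn_norm(x)))
        # 5) closing LayerNorm (the diff applies it only when use_cnn_module is set)
        return self.final_norm(x)


x = torch.randn(50, 2, 256)             # (time, batch, channels)
print(ConformerBlockSketch()(x).shape)  # torch.Size([50, 2, 256])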