Commit 5160a9f5 by xuchen

optimize the code structure and support the simultaneous speech translation

parent 30aed6f9
......@@ -3,7 +3,7 @@
import logging
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
......@@ -31,8 +31,220 @@ class S2TConformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
......@@ -44,7 +256,7 @@ class S2TConformerModel(S2TTransformerModel):
"--zero-triu",
default=False,
type=bool,
help="If true, zero the uppper triangular part of attention matrix.",
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
......
......@@ -19,6 +19,7 @@ from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
ConformerEncoderLayer,
CreateLayerHistory,
)
from torch import Tensor
......@@ -303,6 +304,51 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
pass
@classmethod
......@@ -321,7 +367,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_decoder(cls, args, task, embed_tokens):
if getattr(args, "simul", False):
from examples.simultaneous_translation.models.transformer_monotonic_attention import (
TransformerMonotonicDecoder,
)
decoder = TransformerMonotonicDecoder(args, task.target_dictionary, embed_tokens)
else:
decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
......@@ -506,8 +561,6 @@ class S2TTransformerEncoder(FairseqEncoder):
"src_lengths": [],
}
# "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
......@@ -602,12 +655,17 @@ class TransformerDecoderScriptable(TransformerDecoder):
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", True)
args.use_cnn_module = getattr(args, "use_cnn_module", True)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
......@@ -637,6 +695,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
......
......@@ -15,6 +15,7 @@ from fairseq.modules import (
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
LocalMultiheadAttention,
ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
......@@ -59,6 +60,43 @@ class PyramidTransformerEncoderLayer(nn.Module):
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
if args.macaron_style:
self.macaron_fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_norm = LayerNorm(self.embed_dim)
self.ffn_scale = 0.5
else:
self.macaron_fc1 = None
self.macaron_fc2 = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim)
self.conv_module = ConvolutionModule(
self.embed_dim,
args.cnn_module_kernel)
self.final_norm = LayerNorm(self.embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
......@@ -191,6 +229,16 @@ class PyramidTransformerEncoderLayer(nn.Module):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
x = residual + self.ffn_scale * self.dropout_module(x)
if not self.normalize_before:
x = self.macaron_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
......@@ -219,6 +267,17 @@ class PyramidTransformerEncoderLayer(nn.Module):
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# convolution module
if self.conv_module is not None:
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.conv_norm(x)
x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
if not self.normalize_before:
x = self.conv_norm(x)
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
......@@ -226,9 +285,13 @@ class PyramidTransformerEncoderLayer(nn.Module):
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.conv_module is not None:
x = self.final_norm(x)
return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论