Commit 5160a9f5 by xuchen

optimize the code structure and support simultaneous speech translation

parent 30aed6f9
@@ -6,7 +6,7 @@ import torch
from functools import reduce
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
@@ -21,6 +21,7 @@ from fairseq.modules import (
PositionalEmbedding,
PyramidTransformerEncoderLayer,
MultiheadAttention,
DownSampleConvolutionModule
)
logger = logging.getLogger(__name__)
@@ -33,6 +34,18 @@ def lengths_to_padding_mask_with_maxlen(lens, max_length):
return mask
class Permute_120(nn.Module):
def forward(self, x):
return x.permute(1, 2, 0)
class Permute_201(nn.Module):
def forward(self, x):
return x.permute(2, 0, 1)
class ReducedEmbed(nn.Module):
# Reduced embedding for Pyramid Transformer
def __init__(
@@ -131,52 +144,32 @@ class ReducedEmbed(nn.Module):
class BlockFuse(nn.Module):
def __init__(self, embed_dim, prev_embed_dim, final_stage, fuse_way="add"):
def __init__(self, embed_dim, prev_embed_dim, num_head, dropout):
super().__init__()
self.conv = nn.Sequential(
nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
nn.BatchNorm1d(embed_dim),
nn.ReLU()
)
self.pre_layer_norm = LayerNorm(prev_embed_dim)
self.post_layer_norm = LayerNorm(embed_dim)
self.final_layer_norm = LayerNorm(embed_dim)
self.fuse_way = fuse_way
self.final_stage = final_stage
if self.fuse_way == "gated":
self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
# self.gate = nn.GRUCell(embed_dim, embed_dim)
self.out_layer_norm = LayerNorm(embed_dim)
self.attn = MultiheadAttention(
embed_dim,
num_head,
kdim=prev_embed_dim,
vdim=prev_embed_dim,
dropout=dropout,
encoder_decoder_attention=True,
)
def forward(self, x, state, padding):
seq_len, bsz, dim = x.size()
state = self.pre_layer_norm(state)
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if state.size(-1) != seq_len:
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = self.conv(state)
state = state.permute(2, 0, 1) # seq_len, bsz, dim
state = self.post_layer_norm(state)
if self.fuse_way == "gated":
coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
x = coef * x + (1 - coef) * state
else:
x = x + state
# if not self.final_stage:
x = self.final_layer_norm(x)
state, attn = self.attn(
query=x,
key=state,
value=state,
key_padding_mask=padding,
static_kv=True,
)
state = self.out_layer_norm(state)
mask_pad = padding.unsqueeze(2)
# mask batch padding
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
return x
return state
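# Editor's note: BlockFuse fuses a previous stage's states into the current features via
# cross-attention — the current features are the queries and the layer-normed previous states
# are keys/values (kdim/vdim = prev_embed_dim); the attention output is layer-normed and returned.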
@register_model("pys2t_transformer")
@@ -193,7 +186,258 @@ class PYS2TTransformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, don't scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl decoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# pyramid setting
parser.add_argument(
"--pyramid-stages",
type=int,
@@ -261,11 +505,15 @@ class PYS2TTransformerModel(S2TTransformerModel):
help="the number of the attention heads",
)
parser.add_argument(
"--pyramid-use-ppm",
"--pyramid-fuse",
action="store_true",
help="use ppm",
help="fuse the features in multiple stages",
)
parser.add_argument(
"--pyramid-dropout",
type=float,
help="dropout of the pyramid transformer",
)
parser.add_argument(
"--ctc-layer",
type=int,
@@ -300,11 +548,18 @@ class PyS2TTransformerEncoder(FairseqEncoder):
self.dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.pyramid_dropout = FairseqDropout(
p=getattr(args, "pyramid_dropout", args.dropout), module_name=self.__class__.__name__
)
self.pyramid_stages = getattr(args, "pyramid_stages", 4)
self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
self.pyramid_sr_ratios = [int(n) for n in args.pyramid_sr_ratios.split("_")]
self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
if self.attn_type == "reduced":
self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
else:
self.pyramid_attn_sample_ratios = None
self.pyramid_embed_dims = [int(n) for n in args.pyramid_embed_dims.split("_")]
self.pyramid_position_embed = [int(n) for n in args.pyramid_position_embed.split("_")]
self.pyramid_kernel_sizes = [int(n) for n in args.pyramid_kernel_sizes.split("_")]
@@ -314,8 +569,27 @@ class PyS2TTransformerEncoder(FairseqEncoder):
self.pyramid_embed_norm = args.pyramid_embed_norm
self.pyramid_block_attn = getattr(args, "pyramid_block_attn", False)
self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "add")
self.use_ppm = getattr(args, "pyramid_use_ppm", False)
self.fuse = getattr(args, "pyramid_fuse", False)
self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "all_conv")
self.pyramid_fuse_transform = "conv"
if len(self.pyramid_fuse_way.split("_")) == 2:
items = self.pyramid_fuse_way.split("_")
self.pyramid_fuse_way = items[0]
self.pyramid_fuse_transform = items[1]
fuse_stages_num = 0
if self.fuse:
if self.pyramid_fuse_way == "all":
fuse_stages_num = self.pyramid_stages
elif self.pyramid_fuse_way == "same":
for dim in self.pyramid_embed_dims:
if dim == self.embed_dim:
fuse_stages_num += 1
else:
logger.error("Unsupported fusion!")
if fuse_stages_num == 1:
fuse_stages_num = 0
self.fuse_stages_num = fuse_stages_num
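# Editor's note: with --pyramid-fuse, the fuse way decides which stages are combined — "all"
# fuses every stage, while "same" fuses only stages whose embedding dimension equals the final
# encoder dimension; if only one stage qualifies, fusion is effectively disabled (fuse_stages_num = 0).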
for i in range(self.pyramid_stages):
num_layers = self.pyramid_layers[i]
@@ -327,9 +601,11 @@ class PyS2TTransformerEncoder(FairseqEncoder):
num_head = self.pyramid_heads[i]
use_pos_embed = self.pyramid_position_embed[i]
logger.info("The stage {}: layer {}, sample ratio {}, attention sample ratio {}, embed dim {}, "
"kernel size {}, ffn ratio {}, num head {}, position embed {}".
"kernel size {}, ffn ratio {}, num head {}, position embed {}, "
"fuse {}, fuse way {}, fuse transform {}.".
format(i, num_layers, sr_ratio, attn_sample_ratio,
embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed))
embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed,
self.fuse, self.pyramid_fuse_way, self.pyramid_fuse_transform))
if i == 0:
self.embed_scale = math.sqrt(embed_dim)
@@ -359,41 +635,70 @@ class PyS2TTransformerEncoder(FairseqEncoder):
block_fuse = None
if self.pyramid_block_attn:
if i != 0:
block_fuse = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
final_stage=True if i == self.pyramid_stages - 1 else False,
fuse_way=self.pyramid_fuse_way)
if self.use_ppm:
ppm_pre_layer_norm = LayerNorm(embed_dim)
ppm_post_layer_norm = LayerNorm(self.embed_dim)
ppm = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
nn.BatchNorm1d(self.embed_dim),
nn.ReLU(),
)
else:
ppm_pre_layer_norm = None
ppm_post_layer_norm = None
ppm = None
if i != self.pyramid_stages - 1:
block_fuse = BlockFuse(self.embed_dim, embed_dim,
self.pyramid_heads[-1], dropout=args.dropout
)
fuse_pre_layer_norm = None
fuse_post_layer_norm = None
down_sample = None
if fuse_stages_num != 0:
if self.pyramid_fuse_way == "all" or (
self.pyramid_fuse_way == "same" and self.embed_dim == embed_dim
):
if i != self.pyramid_stages - 1:
shrink_size = reduce(lambda a, b: a * b, self.pyramid_sr_ratios[i + 1:])
else:
shrink_size = 1
fuse_pre_layer_norm = LayerNorm(embed_dim)
fuse_post_layer_norm = LayerNorm(self.embed_dim)
if self.pyramid_fuse_transform == "conv":
down_sample = nn.Sequential(
Permute_120(),
nn.Conv1d(embed_dim, self.embed_dim,
kernel_size=shrink_size,
stride=shrink_size),
nn.BatchNorm1d(self.embed_dim),
nn.ReLU(),
Permute_201(),
)
elif self.pyramid_fuse_transform == "pool":
down_sample = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
nn.BatchNorm1d(self.embed_dim),
nn.ReLU(),
Permute_201(),
)
elif self.pyramid_fuse_transform == "conv2":
down_sample = DownSampleConvolutionModule(
self.embed_dim,
kernel_size=shrink_size,
stride=shrink_size,
)
else:
logger.error("Unsupported fusion transform!")
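# Editor's note: three down-sampling transforms are supported for fusion — "conv" uses a strided
# Conv1d (stride = product of the remaining sampling ratios), "pool" uses a 1x1 Conv1d with
# adaptive max pooling applied later in forward(), and "conv2" uses DownSampleConvolutionModule.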
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"block{i + 1}", block)
setattr(self, f"block_fuse{i + 1}", block_fuse)
setattr(self, f"ppm{i + 1}", ppm)
setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
setattr(self, f"ppm_post_layer_norm{i + 1}", ppm_post_layer_norm)
setattr(self, f"down_sample{i + 1}", down_sample)
setattr(self, f"fuse_pre_layer_norm{i + 1}", fuse_pre_layer_norm)
setattr(self, f"fuse_post_layer_norm{i + 1}", fuse_post_layer_norm)
if self.pyramid_block_attn:
self.block_layer_norm = LayerNorm(self.embed_dim)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim)
else:
self.layer_norm = None
if self.use_ppm:
self.ppm_weight = nn.Parameter(torch.Tensor(self.pyramid_stages).fill_(1.0))
self.ppm_weight.data = self.ppm_weight.data / self.ppm_weight.data.sum(0, keepdim=True)
if self.fuse_stages_num != 0 or self.pyramid_block_attn:
self.fuse_weight = nn.Parameter(torch.Tensor(fuse_stages_num).fill_(1.0))
self.fuse_weight.data = self.fuse_weight.data / self.fuse_weight.data.sum(0, keepdim=True)
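# Editor's note: the fusion weights are learnable and initialized uniformly, i.e. each of the
# fuse_stages_num stages starts with weight 1 / fuse_stages_num.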
self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and
@@ -452,24 +757,23 @@ class PyS2TTransformerEncoder(FairseqEncoder):
reduced_embed = getattr(self, f"reduced_embed{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
block = getattr(self, f"block{i + 1}")
block_fuse = getattr(self, f"block_fuse{i + 1}")
x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
# add the position encoding and dropout
if block_fuse is not None:
x = block_fuse(x, prev_state[-1], encoder_padding_mask)
# add the position encoding and dropout
if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
#if self.attn_type != "rel_selfattn":
# x += positions
x += positions
positions = self.dropout(positions)
else:
positions = None
if i == 0:
x = self.dropout(x)
else:
x = self.pyramid_dropout(x)
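# Editor's note: the first stage uses the regular encoder dropout; later stages use the
# (optionally different) --pyramid-dropout value, which defaults to the encoder dropout.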
for layer in block:
x = layer(x, encoder_padding_mask, pos_emb=positions)
@@ -481,26 +785,32 @@ class PyS2TTransformerEncoder(FairseqEncoder):
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
if self.use_ppm:
pool_state = []
seq_len, bsz, dim = x.size()
if self.fuse_stages_num != 0:
fuse_state = []
i = -1
seq_len = x.size(0)
for state in prev_state:
i += 1
ppm = getattr(self, f"ppm{i + 1}")
ppm_pre_layer_norm = getattr(self, f"ppm_pre_layer_norm{i + 1}")
ppm_post_layer_norm = getattr(self, f"ppm_post_layer_norm{i + 1}")
state = ppm_pre_layer_norm(state)
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if i != self.pyramid_stages - 1:
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = ppm(state)
state = state.permute(2, 0, 1)
state = ppm_post_layer_norm(state)
pool_state.append(state)
ppm_weight = self.ppm_weight
x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
down_sample = getattr(self, f"down_sample{i + 1}")
fuse_pre_layer_norm = getattr(self, f"fuse_pre_layer_norm{i + 1}")
fuse_post_layer_norm = getattr(self, f"fuse_post_layer_norm{i + 1}")
if fuse_pre_layer_norm is not None or fuse_post_layer_norm is not None:
state = fuse_pre_layer_norm(state)
if self.pyramid_fuse_transform == "pool":
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if i != self.pyramid_stages - 1:
state = nn.functional.adaptive_max_pool1d(state, seq_len)
state = down_sample(state)
elif self.pyramid_fuse_transform == "conv":
state = down_sample(state)
elif self.pyramid_fuse_transform == "conv2":
state = down_sample(state, prev_padding[i])
state = fuse_post_layer_norm(state)
fuse_state.append(state)
fuse_weight = self.fuse_weight
x = (torch.stack(fuse_state, dim=0) * fuse_weight.view(-1, 1, 1, 1)).sum(0)
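# Editor's note: the per-stage states (each down-sampled to the final length and embedding
# dimension) are stacked along a new stage axis and reduced by a weighted sum with fuse_weight,
# yielding a tensor of shape (seq_len, batch, embed_dim).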
if self.layer_norm is not None:
x = self.layer_norm(x)
@@ -570,21 +880,6 @@ def base_architecture(args):
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -609,6 +904,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
@@ -625,6 +921,27 @@ def base_architecture(args):
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.pyramid_dropout = getattr(args, "pyramid_dropout", args.dropout)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s")
def pys2t_transformer_s(args):
......
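For reference, a minimal self-contained sketch of the weighted stage fusion used above (an editorial illustration, not code from this commit; the class name and shapes are assumptions — the per-stage states are taken to be already down-sampled to a common T x B x C shape):

import torch
import torch.nn as nn

class WeightedStageFusion(nn.Module):
    """Combine per-stage encoder states with learnable, uniformly initialized weights."""
    def __init__(self, num_stages):
        super().__init__()
        # uniform initialization, mirroring fuse_weight above
        self.weight = nn.Parameter(torch.full((num_stages,), 1.0 / num_stages))

    def forward(self, states):
        # states: list of tensors, each shaped (T, B, C)
        stacked = torch.stack(states, dim=0)                      # (S, T, B, C)
        return (stacked * self.weight.view(-1, 1, 1, 1)).sum(0)   # (T, B, C)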
@@ -3,7 +3,7 @@
import logging
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
@@ -31,8 +31,220 @@ class S2TConformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, don't scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl decoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
@@ -44,7 +256,7 @@ class S2TConformerModel(S2TTransformerModel):
"--zero-triu",
default=False,
type=bool,
help="If true, zero the uppper triangular part of attention matrix.",
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
......
@@ -47,9 +47,9 @@ class S2TSATEModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TConformerModel.add_args(parser)
PYS2TTransformerModel.add_args(parser)
# sate setting
parser.add_argument(
"--text-encoder-layers",
default=6,
@@ -127,14 +127,6 @@ class Adapter(nn.Module):
super().__init__()
attention_dim = args.encoder_embed_dim
self.embed_scale = math.sqrt(attention_dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = dictionary.pad_index
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
adapter_type = getattr(args, "adapter", "league")
self.adapter_type = adapter_type
@@ -143,13 +135,13 @@ class Adapter(nn.Module):
self.linear_adapter = nn.Sequential(
nn.Linear(attention_dim, attention_dim),
LayerNorm(args.encoder_embed_dim),
self.dropout_module,
# self.dropout_module,
nn.ReLU(),
)
elif adapter_type == "linear2":
self.linear_adapter = nn.Sequential(
nn.Linear(attention_dim, attention_dim),
self.dropout_module,
# self.dropout_module,
)
elif adapter_type == "subsample":
self.subsample_adaptor = Conv1dSubsampler(
@@ -172,22 +164,19 @@ class Adapter(nn.Module):
self.gate_linear1 = nn.Linear(attention_dim, attention_dim)
self.gate_linear2 = nn.Linear(attention_dim, attention_dim)
attn_type = getattr(args, "text_encoder_attention_type", "selfattn")
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=attn_type
)
def forward(self, x, padding):
representation, distribution = x
batch, seq_len, embed_dim = distribution.size()
batch, seq_len, embed_dim = representation.size()
if distribution is not None:
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
elif self.adapter_type == "subsample":
representation = representation.transpose(0, 1)
@@ -196,12 +185,12 @@ class Adapter(nn.Module):
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
@@ -212,22 +201,29 @@ class Adapter(nn.Module):
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
out = self.embed_scale * out
return out, padding
positions = self.embed_positions(padding).transpose(0, 1)
out = positions + out
out = self.dropout_module(out)
class TextEncoder(FairseqEncoder):
def __init__(self, args, dictionary):
return out, positions, padding
super().__init__(None)
self.embed_tokens = None
class TextEncoder(FairseqEncoder):
def __init__(self, args, embed_tokens=None):
attention_dim = args.encoder_embed_dim
self.embed_scale = math.sqrt(attention_dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = dictionary.pad_index
super().__init__(None)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_tokens = embed_tokens
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
@@ -239,6 +235,11 @@ class TextEncoder(FairseqEncoder):
def forward(self, x, encoder_padding_mask=None, positions=None, history=None):
x = self.embed_scale * x
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
x = positions + x
x = self.dropout_module(x)
for layer in self.layers:
if history is not None:
x = history.pop()
@@ -290,7 +291,7 @@ class S2TSATEEncoder(FairseqEncoder):
# logger.info("Force self attention for text encoder.")
# text encoder
self.text_encoder = TextEncoder(args, embed_tokens)
self.text_encoder = TextEncoder(args, task.source_dictionary)
args.encoder_attention_type = acoustic_encoder_attention_type
@@ -310,11 +311,15 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
if self.acoustic_encoder.use_ctc:
ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
else:
ctc_logit = None
ctc_prob = None
x = (encoder_out, ctc_prob)
x, positions, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
x, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
if self.history is not None:
acoustic_history = self.acoustic_encoder.history
@@ -332,7 +337,7 @@ class S2TSATEEncoder(FairseqEncoder):
# x, input_lengths = self.length_adapter(x, src_lengths)
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
x = self.text_encoder(x, encoder_padding_mask, positions, self.history)
x = self.text_encoder(x, encoder_padding_mask, self.history)
return {
"ctc_logit": [ctc_logit], # T x B x C
@@ -344,7 +349,15 @@ class S2TSATEEncoder(FairseqEncoder):
"src_lengths": [],
}
def compute_ctc_logit(self, encoder_out):
return encoder_out["ctc_logit"][0]
def reorder_encoder_out(self, encoder_out, new_order):
new_ctc_logit = (
[] if len(encoder_out["ctc_logit"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"]]
)
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
@@ -366,6 +379,7 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_states[idx] = state.index_select(1, new_order)
return {
"ctc_logit": new_ctc_logit,
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
@@ -381,14 +395,6 @@ def base_architecture(args):
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -416,6 +422,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
@@ -432,6 +439,25 @@ def base_architecture(args):
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.pyramid_dropout = getattr(args, "pyramid_dropout", args.dropout)
@register_model_architecture("s2t_sate", "s2t_sate_s")
def s2t_sate_s(args):
......
@@ -19,6 +19,7 @@ from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
ConformerEncoderLayer,
CreateLayerHistory,
)
from torch import Tensor
@@ -303,6 +304,51 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
pass
@classmethod
@@ -321,7 +367,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_decoder(cls, args, task, embed_tokens):
decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
if getattr(args, "simul", False):
from examples.simultaneous_translation.models.transformer_monotonic_attention import (
TransformerMonotonicDecoder,
)
decoder = TransformerMonotonicDecoder(args, task.target_dictionary, embed_tokens)
else:
decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
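# Editor's note: with --simul, the decoder is built as a TransformerMonotonicDecoder from
# fairseq's simultaneous_translation examples; otherwise the standard scriptable decoder is used.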
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
@@ -506,8 +561,6 @@ class S2TTransformerEncoder(FairseqEncoder):
"src_lengths": [],
}
# "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
@@ -602,12 +655,17 @@ class TransformerDecoderScriptable(TransformerDecoder):
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", True)
args.use_cnn_module = getattr(args, "use_cnn_module", True)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
@@ -637,6 +695,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
......
@@ -15,6 +15,7 @@ from fairseq.modules import (
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
LocalMultiheadAttention,
ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
@@ -59,6 +60,43 @@ class PyramidTransformerEncoderLayer(nn.Module):
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
if args.macaron_style:
self.macaron_fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_norm = LayerNorm(self.embed_dim)
self.ffn_scale = 0.5
else:
self.macaron_fc1 = None
self.macaron_fc2 = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim)
self.conv_module = ConvolutionModule(
self.embed_dim,
args.cnn_module_kernel)
self.final_norm = LayerNorm(self.embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
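# Editor's note: with macaron style enabled, ffn_scale is 0.5 so each feed-forward block
# contributes a half-step; with the convolution module enabled, an extra LayerNorm (final_norm)
# is applied at the end of the layer.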
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
@@ -191,6 +229,16 @@ class PyramidTransformerEncoderLayer(nn.Module):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
x = residual + self.ffn_scale * self.dropout_module(x)
if not self.normalize_before:
x = self.macaron_norm(x)
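# Editor's note: the macaron branch adds a half-step position-wise FFN (scaled by ffn_scale = 0.5)
# before self-attention, with pre- or post-LayerNorm chosen by encoder_normalize_before.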
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
@@ -219,6 +267,17 @@ class PyramidTransformerEncoderLayer(nn.Module):
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# convolution module
if self.conv_module is not None:
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.conv_norm(x)
x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
if not self.normalize_before:
x = self.conv_norm(x)
x = x.transpose(0, 1)
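# Editor's note: the convolution module operates on (batch, time, dim), hence the transposes
# around it; its output is added residually between self-attention and the final FFN.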
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
@@ -226,9 +285,13 @@ class PyramidTransformerEncoderLayer(nn.Module):
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.conv_module is not None:
x = self.final_norm(x)
return x
......
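As a quick reference for the layer above, a simplified sketch of the Conformer-style block ordering it implements (an editorial illustration, not code from this commit; all argument names are placeholders for the layer's sub-modules, shown in the pre-norm configuration):

def conformer_block(x, macaron_ffn, self_attn, conv_module, ffn,
                    macaron_norm, attn_norm, conv_norm, ffn_norm, final_norm):
    # half-step macaron FFN (ffn_scale = 0.5)
    x = x + 0.5 * macaron_ffn(macaron_norm(x))
    # multi-head self-attention
    x = x + self_attn(attn_norm(x))
    # convolution module
    x = x + conv_module(conv_norm(x))
    # second half-step FFN
    x = x + 0.5 * ffn(ffn_norm(x))
    # extra LayerNorm applied when the convolution module is enabled
    return final_norm(x)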