Commit 5160a9f5 by xuchen

optimize the code structure and support simultaneous speech translation

parent 30aed6f9
@@ -6,7 +6,7 @@ import torch
from functools import reduce
import torch.nn as nn

-from fairseq import checkpoint_utils
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
@@ -21,6 +21,7 @@ from fairseq.modules import (
    PositionalEmbedding,
    PyramidTransformerEncoderLayer,
    MultiheadAttention,
    DownSampleConvolutionModule
)

logger = logging.getLogger(__name__)

@@ -33,6 +34,18 @@ def lengths_to_padding_mask_with_maxlen(lens, max_length):
    return mask


class Permute_120(nn.Module):
    def forward(self, x):
        return x.permute(1, 2, 0)


class Permute_201(nn.Module):
    def forward(self, x):
        return x.permute(2, 0, 1)
class ReducedEmbed(nn.Module):
    # Reduced embedding for Pyramid Transformer
    def __init__(
@@ -131,52 +144,32 @@ class ReducedEmbed(nn.Module):


class BlockFuse(nn.Module):

-    def __init__(self, embed_dim, prev_embed_dim, final_stage, fuse_way="add"):
    def __init__(self, embed_dim, prev_embed_dim, num_head, dropout):
        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
-            nn.BatchNorm1d(embed_dim),
-            nn.ReLU()
-        )
        self.pre_layer_norm = LayerNorm(prev_embed_dim)
-        self.post_layer_norm = LayerNorm(embed_dim)
-        self.final_layer_norm = LayerNorm(embed_dim)
-
-        self.fuse_way = fuse_way
-        self.final_stage = final_stage
-
-        if self.fuse_way == "gated":
-            self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
-            # self.gate = nn.GRUCell(embed_dim, embed_dim)
        self.out_layer_norm = LayerNorm(embed_dim)
        self.attn = MultiheadAttention(
            embed_dim,
            num_head,
            kdim=prev_embed_dim,
            vdim=prev_embed_dim,
            dropout=dropout,
            encoder_decoder_attention=True,
        )

    def forward(self, x, state, padding):
-        seq_len, bsz, dim = x.size()
        state = self.pre_layer_norm(state)
-        state = state.permute(1, 2, 0)  # bsz, dim, seq_len
-        if state.size(-1) != seq_len:
-            state = nn.functional.adaptive_avg_pool1d(state, seq_len)
-        state = self.conv(state)
-        state = state.permute(2, 0, 1)  # seq_len, bsz, dim
-        state = self.post_layer_norm(state)
-
-        if self.fuse_way == "gated":
-            coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
-            x = coef * x + (1 - coef) * state
-        else:
-            x = x + state
-
-        # if not self.final_stage:
-        x = self.final_layer_norm(x)
-
-        mask_pad = padding.unsqueeze(2)
-        # mask batch padding
-        if mask_pad is not None:
-            x = x.transpose(0, 1)
-            x.masked_fill_(mask_pad, 0.0)
-            x = x.transpose(0, 1)
-
-        return x
        state, attn = self.attn(
            query=x,
            key=state,
            value=state,
            key_padding_mask=padding,
            static_kv=True,
        )
        state = self.out_layer_norm(state)

        return state
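As a reading aid (not part of the commit), here is a minimal sketch of the call pattern the rewritten BlockFuse relies on: the current stage is the query and an earlier, longer stage is the key/value memory. The shapes are invented, and plain torch.nn.MultiheadAttention stands in for fairseq's MultiheadAttention, whose extra arguments (encoder_decoder_attention, static_kv) are omitted here.

```python
# Illustrative only: cross-attention fusion of two pyramid stages with dummy shapes.
import torch
import torch.nn as nn

T_cur, T_prev, B, C, C_prev, heads = 10, 40, 2, 256, 128, 4

x = torch.randn(T_cur, B, C)             # current-stage features (T x B x C)
state = torch.randn(T_prev, B, C_prev)   # earlier-stage features, longer sequence
padding = torch.zeros(B, T_prev, dtype=torch.bool)  # key padding mask of the earlier stage

attn = nn.MultiheadAttention(C, heads, kdim=C_prev, vdim=C_prev)
pre_norm, out_norm = nn.LayerNorm(C_prev), nn.LayerNorm(C)

state = pre_norm(state)
fused, _ = attn(query=x, key=state, value=state, key_padding_mask=padding)
fused = out_norm(fused)
print(fused.shape)  # torch.Size([10, 2, 256]) -- same length as the query x
```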
@register_model("pys2t_transformer")
@@ -193,7 +186,258 @@ class PYS2TTransformerModel(S2TTransformerModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
-        S2TTransformerModel.add_args(parser)
        # input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# pyramid setting
        parser.add_argument(
            "--pyramid-stages",
            type=int,
@@ -261,11 +505,15 @@ class PYS2TTransformerModel(S2TTransformerModel):
            help="the number of the attention heads",
        )
        parser.add_argument(
-            "--pyramid-use-ppm",
            "--pyramid-fuse",
            action="store_true",
-            help="use ppm",
            help="fuse the features in multiple stages",
        )
        parser.add_argument(
            "--pyramid-dropout",
            type=float,
            help="dropout of the pyramid transformer",
        )
        parser.add_argument(
            "--ctc-layer",
            type=int,
@@ -300,11 +548,18 @@ class PyS2TTransformerEncoder(FairseqEncoder):
        self.dropout = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.pyramid_dropout = FairseqDropout(
            p=getattr(args, "pyramid_dropout", args.dropout), module_name=self.__class__.__name__
        )

        self.pyramid_stages = getattr(args, "pyramid_stages", 4)
        self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
        self.pyramid_sr_ratios = [int(n) for n in args.pyramid_sr_ratios.split("_")]
-        self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
        if self.attn_type == "reduced":
            self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
        else:
            self.pyramid_attn_sample_ratios = None
        self.pyramid_embed_dims = [int(n) for n in args.pyramid_embed_dims.split("_")]
        self.pyramid_position_embed = [int(n) for n in args.pyramid_position_embed.split("_")]
        self.pyramid_kernel_sizes = [int(n) for n in args.pyramid_kernel_sizes.split("_")]
@@ -314,8 +569,27 @@ class PyS2TTransformerEncoder(FairseqEncoder):
        self.pyramid_embed_norm = args.pyramid_embed_norm
        self.pyramid_block_attn = getattr(args, "pyramid_block_attn", False)

-        self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "add")
-        self.use_ppm = getattr(args, "pyramid_use_ppm", False)
        self.fuse = getattr(args, "pyramid_fuse", False)
        self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "all_conv")
        self.pyramid_fuse_transform = "conv"
        if len(self.pyramid_fuse_way.split("_")) == 2:
            items = self.pyramid_fuse_way.split("_")
            self.pyramid_fuse_way = items[0]
            self.pyramid_fuse_transform = items[1]

        fuse_stages_num = 0
        if self.fuse:
            if self.pyramid_fuse_way == "all":
                fuse_stages_num = self.pyramid_stages
            elif self.pyramid_fuse_way == "same":
                for dim in self.pyramid_embed_dims:
                    if dim == self.embed_dim:
                        fuse_stages_num += 1
            else:
                logger.error("Unsupported fusion!")
            if fuse_stages_num == 1:
                fuse_stages_num = 0
        self.fuse_stages_num = fuse_stages_num
        for i in range(self.pyramid_stages):
            num_layers = self.pyramid_layers[i]
@@ -327,9 +601,11 @@ class PyS2TTransformerEncoder(FairseqEncoder):
            num_head = self.pyramid_heads[i]
            use_pos_embed = self.pyramid_position_embed[i]

            logger.info("The stage {}: layer {}, sample ratio {}, attention sample ratio {}, embed dim {}, "
-                        "kernel size {}, ffn ratio {}, num head {}, position embed {}".
                        "kernel size {}, ffn ratio {}, num head {}, position embed {}, "
                        "fuse {}, fuse way {}, transformer {}.".
                        format(i, num_layers, sr_ratio, attn_sample_ratio,
-                               embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed))
                               embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed,
                               self.fuse, self.pyramid_fuse_way, self.pyramid_fuse_transform))

            if i == 0:
                self.embed_scale = math.sqrt(embed_dim)
@@ -359,41 +635,70 @@ class PyS2TTransformerEncoder(FairseqEncoder):
            block_fuse = None
            if self.pyramid_block_attn:
-                if i != 0:
-                    block_fuse = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
-                                           final_stage=True if i == self.pyramid_stages - 1 else False,
-                                           fuse_way=self.pyramid_fuse_way)
                if i != self.pyramid_stages - 1:
                    block_fuse = BlockFuse(self.embed_dim, embed_dim,
                                           self.pyramid_heads[-1], dropout=args.dropout
                                           )

-            if self.use_ppm:
-                ppm_pre_layer_norm = LayerNorm(embed_dim)
-                ppm_post_layer_norm = LayerNorm(self.embed_dim)
-                ppm = nn.Sequential(
-                    nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
-                    nn.BatchNorm1d(self.embed_dim),
-                    nn.ReLU(),
-                )
-            else:
-                ppm_pre_layer_norm = None
-                ppm_post_layer_norm = None
-                ppm = None
            fuse_pre_layer_norm = None
            fuse_post_layer_norm = None
            down_sample = None
            if fuse_stages_num != 0:
                if self.pyramid_fuse_way == "all" or (
                        self.pyramid_fuse_way == "same" and self.embed_dim == embed_dim
                ):
                    if i != self.pyramid_stages - 1:
                        shrink_size = reduce(lambda a, b: a * b, self.pyramid_sr_ratios[i + 1:])
                    else:
                        shrink_size = 1

                    fuse_pre_layer_norm = LayerNorm(embed_dim)
                    fuse_post_layer_norm = LayerNorm(self.embed_dim)
                    if self.pyramid_fuse_transform == "conv":
                        down_sample = nn.Sequential(
                            Permute_120(),
                            nn.Conv1d(embed_dim, self.embed_dim,
                                      kernel_size=shrink_size,
                                      stride=shrink_size),
                            nn.BatchNorm1d(self.embed_dim),
                            nn.ReLU(),
                            Permute_201(),
                        )
                    elif self.pyramid_fuse_transform == "pool":
                        down_sample = nn.Sequential(
                            nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
                            nn.BatchNorm1d(self.embed_dim),
                            nn.ReLU(),
                            Permute_201(),
                        )
                    elif self.pyramid_fuse_transform == "conv2":
                        down_sample = DownSampleConvolutionModule(
                            self.embed_dim,
                            kernel_size=shrink_size,
                            stride=shrink_size,
                        )
                    else:
                        logger.error("Unsupported fusion transform!")

            setattr(self, f"reduced_embed{i + 1}", reduced_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"block_fuse{i + 1}", block_fuse)
-            setattr(self, f"ppm{i + 1}", ppm)
-            setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
-            setattr(self, f"ppm_post_layer_norm{i + 1}", ppm_post_layer_norm)
            setattr(self, f"down_sample{i + 1}", down_sample)
            setattr(self, f"fuse_pre_layer_norm{i + 1}", fuse_pre_layer_norm)
            setattr(self, f"fuse_post_layer_norm{i + 1}", fuse_post_layer_norm)

        if self.pyramid_block_attn:
            self.block_layer_norm = LayerNorm(self.embed_dim)

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(self.embed_dim)
        else:
            self.layer_norm = None

-        if self.use_ppm:
-            self.ppm_weight = nn.Parameter(torch.Tensor(self.pyramid_stages).fill_(1.0))
-            self.ppm_weight.data = self.ppm_weight.data / self.ppm_weight.data.sum(0, keepdim=True)
        if self.fuse_stages_num != 0 or self.pyramid_block_attn:
            self.fuse_weight = nn.Parameter(torch.Tensor(fuse_stages_num).fill_(1.0))
            self.fuse_weight.data = self.fuse_weight.data / self.fuse_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \ self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and (("ctc" in getattr(args, "criterion", False)) and
@@ -452,24 +757,23 @@ class PyS2TTransformerEncoder(FairseqEncoder):
            reduced_embed = getattr(self, f"reduced_embed{i + 1}")
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            block_fuse = getattr(self, f"block_fuse{i + 1}")

            x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)

-            # add the position encoding and dropout
            if block_fuse is not None:
                x = block_fuse(x, prev_state[-1], encoder_padding_mask)

            # add the position encoding and dropout
            if pos_embed:
                positions = pos_embed(encoder_padding_mask).transpose(0, 1)
-                if self.attn_type != "rel_selfattn":
-                    x += positions
                #if self.attn_type != "rel_selfattn":
                # x += positions
                x += positions
                positions = self.dropout(positions)
            else:
                positions = None

            if i == 0:
                x = self.dropout(x)
            else:
                x = self.pyramid_dropout(x)

            for layer in block:
                x = layer(x, encoder_padding_mask, pos_emb=positions)
@@ -481,26 +785,32 @@ class PyS2TTransformerEncoder(FairseqEncoder):
            prev_state.append(x)
            prev_padding.append(encoder_padding_mask)

-        if self.use_ppm:
-            pool_state = []
-            seq_len, bsz, dim = x.size()
        if self.fuse_stages_num != 0:
            fuse_state = []
            i = -1
            seq_len = x.size(0)
            for state in prev_state:
                i += 1
-                ppm = getattr(self, f"ppm{i + 1}")
-                ppm_pre_layer_norm = getattr(self, f"ppm_pre_layer_norm{i + 1}")
-                ppm_post_layer_norm = getattr(self, f"ppm_post_layer_norm{i + 1}")
-                state = ppm_pre_layer_norm(state)
-                state = state.permute(1, 2, 0)  # bsz, dim, seq_len
-                if i != self.pyramid_stages - 1:
-                    state = nn.functional.adaptive_avg_pool1d(state, seq_len)
-                state = ppm(state)
-                state = state.permute(2, 0, 1)
-                state = ppm_post_layer_norm(state)
-                pool_state.append(state)
-            ppm_weight = self.ppm_weight
-            x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
                down_sample = getattr(self, f"down_sample{i + 1}")
                fuse_pre_layer_norm = getattr(self, f"fuse_pre_layer_norm{i + 1}")
                fuse_post_layer_norm = getattr(self, f"fuse_post_layer_norm{i + 1}")
                if fuse_pre_layer_norm is not None or fuse_pre_layer_norm is not None:
                    state = fuse_pre_layer_norm(state)

                    if self.pyramid_fuse_transform == "pool":
                        state = state.permute(1, 2, 0)  # bsz, dim, seq_len
                        if i != self.pyramid_stages - 1:
                            state = nn.functional.adaptive_max_pool1d(state, seq_len)
                        state = down_sample(state)
                    elif self.pyramid_fuse_transform == "conv":
                        state = down_sample(state)
                    elif self.pyramid_fuse_transform == "conv2":
                        state = down_sample(state, prev_padding[i])
                    state = fuse_post_layer_norm(state)
                fuse_state.append(state)
            fuse_weight = self.fuse_weight
            x = (torch.stack(fuse_state, dim=0) * fuse_weight.view(-1, 1, 1, 1)).sum(0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
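The final fusion step above is a learnable weighted sum over the per-stage states. A tiny standalone sketch of that one line, with made-up shapes (each entry of fuse_state is assumed to already have the final stage's length and width):

```python
# Sketch of the stage-fusion step: stack the stages and combine them with
# normalized learnable weights (here just a constant tensor for illustration).
import torch

T, B, C, num_stages = 25, 2, 256, 4
fuse_state = [torch.randn(T, B, C) for _ in range(num_stages)]

fuse_weight = torch.full((num_stages,), 1.0 / num_stages)  # initialized to sum to 1, learned in the model
x = (torch.stack(fuse_state, dim=0) * fuse_weight.view(-1, 1, 1, 1)).sum(0)
print(x.shape)  # torch.Size([25, 2, 256])
```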
@@ -570,21 +880,6 @@ def base_architecture(args):
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
    args.conv_channels = getattr(args, "conv_channels", 1024)

-    # Pyramid
-    args.pyramid_stages = getattr(args, "pyramid_stages", None)
-    args.pyramid_layers = getattr(args, "pyramid_layers", None)
-    args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
-    args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
-    args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
-    args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
-    args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
-    args.pyramid_heads = getattr(args, "pyramid_heads", None)
-    args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
-    args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
-    args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
-    args.ctc_layer = getattr(args, "ctc_layer", -1)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -609,6 +904,7 @@ def base_architecture(args):
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
@@ -625,6 +921,27 @@ def base_architecture(args):
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.pyramid_dropout = getattr(args, "pyramid_dropout", args.dropout)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s") @register_model_architecture("pys2t_transformer", "pys2t_transformer_s")
def pys2t_transformer_s(args): def pys2t_transformer_s(args):
......
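Before moving on to the next file, a toy walk-through of the new fusion options introduced above: `--pyramid-fuse-way` carries both a scope ("all" or "same") and a transform ("conv", "pool", "conv2") joined by an underscore, and fusing a single stage is disabled. The dimensions below are invented for illustration and are not defaults of the model.

```python
# Toy interpretation of the fusion configuration (values are made up).
pyramid_fuse_way = "all_conv"
pyramid_fuse_transform = "conv"
if len(pyramid_fuse_way.split("_")) == 2:
    pyramid_fuse_way, pyramid_fuse_transform = pyramid_fuse_way.split("_")

pyramid_stages, embed_dim = 4, 512
pyramid_embed_dims = [160, 256, 384, 512]

fuse_stages_num = 0
if pyramid_fuse_way == "all":
    fuse_stages_num = pyramid_stages
elif pyramid_fuse_way == "same":
    fuse_stages_num = sum(1 for d in pyramid_embed_dims if d == embed_dim)
if fuse_stages_num == 1:
    fuse_stages_num = 0  # fusing only one stage would be a no-op, so it is disabled

print(pyramid_fuse_way, pyramid_fuse_transform, fuse_stages_num)  # all conv 4
```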
@@ -3,7 +3,7 @@
import logging
import torch.nn as nn

-from fairseq import checkpoint_utils
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    register_model,
@@ -31,8 +31,220 @@ class S2TConformerModel(S2TTransformerModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
-        S2TTransformerModel.add_args(parser)
        # input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
        # Conformer setting
        parser.add_argument(
            "--macaron-style",
            default=False,
@@ -44,7 +256,7 @@ class S2TConformerModel(S2TTransformerModel):
            "--zero-triu",
            default=False,
            type=bool,
-            help="If true, zero the uppper triangular part of attention matrix.",
            help="If true, zero the upper triangular part of attention matrix.",
        )
        # Relative positional encoding
        parser.add_argument(
......
@@ -47,9 +47,9 @@ class S2TSATEModel(S2TTransformerModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
-        S2TConformerModel.add_args(parser)
        PYS2TTransformerModel.add_args(parser)
        # sate setting
        parser.add_argument(
            "--text-encoder-layers",
            default=6,
@@ -127,14 +127,6 @@ class Adapter(nn.Module):
        super().__init__()

        attention_dim = args.encoder_embed_dim
-        self.embed_scale = math.sqrt(attention_dim)
-        if args.no_scale_embedding:
-            self.embed_scale = 1.0
-        self.padding_idx = dictionary.pad_index
-        self.dropout_module = FairseqDropout(
-            p=args.dropout, module_name=self.__class__.__name__
-        )

        adapter_type = getattr(args, "adapter", "league")
        self.adapter_type = adapter_type
@@ -143,13 +135,13 @@ class Adapter(nn.Module):
            self.linear_adapter = nn.Sequential(
                nn.Linear(attention_dim, attention_dim),
                LayerNorm(args.encoder_embed_dim),
-                self.dropout_module,
                # self.dropout_module,
                nn.ReLU(),
            )
        elif adapter_type == "linear2":
            self.linear_adapter = nn.Sequential(
                nn.Linear(attention_dim, attention_dim),
-                self.dropout_module,
                # self.dropout_module,
            )
        elif adapter_type == "subsample":
            self.subsample_adaptor = Conv1dSubsampler(
@@ -172,22 +164,19 @@ class Adapter(nn.Module):
            self.gate_linear1 = nn.Linear(attention_dim, attention_dim)
            self.gate_linear2 = nn.Linear(attention_dim, attention_dim)
-        attn_type = getattr(args, "text_encoder_attention_type", "selfattn")
-        self.embed_positions = PositionalEmbedding(
-            args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=attn_type
-        )

    def forward(self, x, padding):
        representation, distribution = x
-        batch, seq_len, embed_dim = distribution.size()
        batch, seq_len, embed_dim = representation.size()
        if distribution is not None:
            distribution = distribution.view(-1, distribution.size(-1))

        lengths = (~padding).long().sum(-1)

        if self.adapter_type == "linear":
            out = self.linear_adapter(representation)

        elif self.adapter_type == "context":
-            out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
            out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)

        elif self.adapter_type == "subsample":
            representation = representation.transpose(0, 1)
@@ -196,12 +185,12 @@ class Adapter(nn.Module):
elif self.adapter_type == "league": elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
out = linear_out + soft_out out = linear_out + soft_out
elif self.adapter_type == "gated_league": elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid() coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
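A rough illustration of the soft_out term used by the "context" and "league" adapters above, under assumed sizes: the CTC distribution, flattened to (batch*seq_len, vocab), is mapped through an embedding table of shape (vocab, dim) and reshaped back to (batch, seq_len, dim).

```python
# Dummy-shape sketch of mapping a CTC posterior through an embedding matrix.
import torch

batch, seq_len, vocab, dim = 2, 25, 1000, 256
distribution = torch.softmax(torch.randn(batch * seq_len, vocab), dim=-1)
embed_weight = torch.randn(vocab, dim)  # e.g. the weight of an nn.Embedding(vocab, dim)

soft_out = torch.mm(distribution, embed_weight).view(batch, seq_len, -1)
print(soft_out.shape)  # torch.Size([2, 25, 256])
```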
@@ -212,22 +201,29 @@ class Adapter(nn.Module):
            out = None
            logging.error("Unsupported adapter type: {}.".format(self.adapter_type))

-        out = self.embed_scale * out
-        positions = self.embed_positions(padding).transpose(0, 1)
-        out = positions + out
-        out = self.dropout_module(out)
-
-        return out, positions, padding
        return out, padding


class TextEncoder(FairseqEncoder):

-    def __init__(self, args, embed_tokens=None):
    def __init__(self, args, dictionary):

        super().__init__(None)

-        self.embed_tokens = embed_tokens
        self.embed_tokens = None

        attention_dim = args.encoder_embed_dim
        self.embed_scale = math.sqrt(attention_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = dictionary.pad_index
        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
@@ -239,6 +235,11 @@ class TextEncoder(FairseqEncoder):
    def forward(self, x, encoder_padding_mask=None, positions=None, history=None):

        x = self.embed_scale * x
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        x = positions + x
        x = self.dropout_module(x)

        for layer in self.layers:
            if history is not None:
                x = history.pop()
@@ -290,7 +291,7 @@ class S2TSATEEncoder(FairseqEncoder):
        # logger.info("Force self attention for text encoder.")
        # text encoder
-        self.text_encoder = TextEncoder(args, embed_tokens)
        self.text_encoder = TextEncoder(args, task.source_dictionary)

        args.encoder_attention_type = acoustic_encoder_attention_type
@@ -310,11 +311,15 @@ class S2TSATEEncoder(FairseqEncoder):
        encoder_out = acoustic_encoder_out["encoder_out"][0]
        encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]

-        ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
-        ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
        if self.acoustic_encoder.use_ctc:
            ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
            ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
        else:
            ctc_logit = None
            ctc_prob = None

        x = (encoder_out, ctc_prob)
-        x, positions, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
        x, encoder_padding_mask = self.adapter(x, encoder_padding_mask)

        if self.history is not None:
            acoustic_history = self.acoustic_encoder.history
@@ -332,7 +337,7 @@ class S2TSATEEncoder(FairseqEncoder):
        # x, input_lengths = self.length_adapter(x, src_lengths)
        # encoder_padding_mask = lengths_to_padding_mask(input_lengths)

-        x = self.text_encoder(x, encoder_padding_mask, positions, self.history)
        x = self.text_encoder(x, encoder_padding_mask, self.history)
        return {
            "ctc_logit": [ctc_logit],  # T x B x C
@@ -344,7 +349,15 @@ class S2TSATEEncoder(FairseqEncoder):
            "src_lengths": [],
        }

    def compute_ctc_logit(self, encoder_out):
        return encoder_out["ctc_logit"][0]

    def reorder_encoder_out(self, encoder_out, new_order):
        new_ctc_logit = (
            [] if len(encoder_out["ctc_logit"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"]]
        )
        new_encoder_out = (
            [] if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
@@ -366,6 +379,7 @@ class S2TSATEEncoder(FairseqEncoder):
            encoder_states[idx] = state.index_select(1, new_order)

        return {
            "ctc_logit": new_ctc_logit,
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
@@ -381,14 +395,6 @@ def base_architecture(args):
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)

-    # Conformer
-    args.macaron_style = getattr(args, "macaron_style", False)
-    args.use_cnn_module = getattr(args, "use_cnn_module", False)
-    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
-
-    # Pyramid
-    args.pyramid_layers = getattr(args, "pyramid_layers", None)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -416,6 +422,7 @@ def base_architecture(args):
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
@@ -432,6 +439,25 @@ def base_architecture(args):
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.pyramid_dropout = getattr(args, "pyramid_dropout", args.dropout)
@register_model_architecture("s2t_sate", "s2t_sate_s") @register_model_architecture("s2t_sate", "s2t_sate_s")
def s2t_sate_s(args): def s2t_sate_s(args):
......
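The reorder_encoder_out extension above treats the new "ctc_logit" entry like the other T x B x C tensors: during beam search, dimension 1 (the batch/beam axis) is re-indexed. A dummy-tensor sketch of that single operation:

```python
# Beam-search reordering of a T x B x C tensor along its batch/beam dimension.
import torch

T, B, C = 25, 4, 100
ctc_logit = torch.randn(T, B, C)
new_order = torch.tensor([2, 2, 0, 1])  # e.g. beams duplicated/reordered during search

reordered = ctc_logit.index_select(1, new_order)
print(reordered.shape)  # torch.Size([25, 4, 100])
```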
@@ -19,6 +19,7 @@ from fairseq.modules import (
    LayerNorm,
    PositionalEmbedding,
    TransformerEncoderLayer,
    ConformerEncoderLayer,
    CreateLayerHistory,
)
from torch import Tensor
@@ -303,6 +304,51 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            help='initialized weight for local mask'
        )
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
        pass

    @classmethod
@@ -321,7 +367,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
-        decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
        if getattr(args, "simul", False):
            from examples.simultaneous_translation.models.transformer_monotonic_attention import (
                TransformerMonotonicDecoder,
            )

            decoder = TransformerMonotonicDecoder(args, task.target_dictionary, embed_tokens)
        else:
            decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            logger.info(
                f"loaded pretrained decoder from: "
@@ -506,8 +561,6 @@ class S2TTransformerEncoder(FairseqEncoder):
            "src_lengths": [],
        }

-        # "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [],  # B x T

    def compute_ctc_logit(self, encoder_out):
        assert self.use_ctc, "CTC is not available!"
@@ -602,12 +655,17 @@ class TransformerDecoderScriptable(TransformerDecoder):
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)


@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
    # Convolutional subsampler
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", True)
args.use_cnn_module = getattr(args, "use_cnn_module", True)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
@@ -637,6 +695,7 @@ def base_architecture(args):
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
......
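The new `--simul` option in this file is what routes build_decoder to the monotonic decoder branch. A standalone way to sanity-check the flag's behaviour (plain argparse here; in fairseq the option is registered through add_args, so this is only an illustration):

```python
# Minimal argparse sketch of the new flag; mirrors the option defined above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--simul", default=False, action="store_true",
                    help="Simultaneous speech translation or not")

args = parser.parse_args(["--simul"])
print(getattr(args, "simul", False))  # True -> the TransformerMonotonicDecoder branch is taken
```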
@@ -15,6 +15,7 @@ from fairseq.modules import (
    RelPositionMultiheadAttention,
    RelativeMultiheadAttention,
    LocalMultiheadAttention,
    ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
@@ -59,6 +60,43 @@ class PyramidTransformerEncoderLayer(nn.Module):
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
if args.macaron_style:
self.macaron_fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_norm = LayerNorm(self.embed_dim)
self.ffn_scale = 0.5
else:
self.macaron_fc1 = None
self.macaron_fc2 = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim)
self.conv_module = ConvolutionModule(
self.embed_dim,
args.cnn_module_kernel)
self.final_norm = LayerNorm(self.embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
@@ -191,6 +229,16 @@ class PyramidTransformerEncoderLayer(nn.Module):
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
x = residual + self.ffn_scale * self.dropout_module(x)
if not self.normalize_before:
x = self.macaron_norm(x)
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
@@ -219,6 +267,17 @@ class PyramidTransformerEncoderLayer(nn.Module):
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)
# convolution module
if self.conv_module is not None:
x = x.transpose(0, 1)
residual = x
if self.normalize_before:
x = self.conv_norm(x)
x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
if not self.normalize_before:
x = self.conv_norm(x)
x = x.transpose(0, 1)
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
@@ -226,9 +285,13 @@ class PyramidTransformerEncoderLayer(nn.Module):
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
-        x = self.residual_connection(x, residual)
        x = self.residual_connection(self.ffn_scale * x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.conv_module is not None:
            x = self.final_norm(x)

        return x
......
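With `--macaron-style`, the layer wraps its self-attention between two half-weighted feed-forward blocks (ffn_scale = 0.5), following the Conformer design. A minimal sketch of the macaron residual added in this file, assuming the pre-norm path and placeholder sizes:

```python
# Macaron-style FFN residual: x <- x + 0.5 * FFN(LayerNorm(x)), applied twice per layer.
import torch
import torch.nn as nn

dim, ffn_dim, T, B = 256, 1024, 10, 2
macaron_fc1, macaron_fc2 = nn.Linear(dim, ffn_dim), nn.Linear(ffn_dim, dim)
macaron_norm = nn.LayerNorm(dim)
ffn_scale = 0.5  # each of the two FFNs contributes half of a normal FFN residual

x = torch.randn(T, B, dim)
residual = x
x = macaron_norm(x)                              # pre-norm branch
x = macaron_fc2(torch.relu(macaron_fc1(x)))      # position-wise feed-forward
x = residual + ffn_scale * x
print(x.shape)  # torch.Size([10, 2, 256])
```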