Commit d946bc3b by xuchen

I validated the results of embedding norm and no-scale embedding for the speech-to-text encoder.

Yeah, it is better.
parent 2de89089
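
For context, a minimal sketch of what the two new options control, based on the encoder changes below; the class and argument names here are illustrative, not the repository's:

```python
import math
import torch.nn as nn

class EmbedFrontendSketch(nn.Module):
    """Illustrative only: mirrors what encoder-embed-norm and
    encoder-no-scale-embedding toggle in the encoders changed below."""

    def __init__(self, dim, no_scale_embedding=True, embed_norm=True):
        super().__init__()
        # encoder-no-scale-embedding: True -> drop the usual sqrt(dim) scaling
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(dim)
        # encoder-embed-norm: True -> LayerNorm on the subsampled features
        self.embed_ln = nn.LayerNorm(dim) if embed_norm else None

    def forward(self, x):  # x: (B, T, C) features after subsampling
        if self.embed_ln is not None:
            x = self.embed_ln(x)
        return self.embed_scale * x
```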
......@@ -19,6 +19,9 @@ subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
encoder-embed-norm: True
encoder-no-scale-embedding: True
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
......@@ -22,6 +22,9 @@ subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
encoder-embed-norm: True
encoder-no-scale-embedding: True
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
......@@ -340,7 +340,7 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
torch.rand() < self.ctc_self_distill_prob:
torch.rand(1).uniform_() < self.ctc_self_distill_prob:
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
......
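A hedged illustration of the stochastic gate changed above: `torch.rand(1)` draws a single value from U[0, 1), so the comparison enables the self-distillation branch with probability `ctc_self_distill_prob` (the concrete value below is assumed for demonstration only):

```python
import torch

# Illustrative gate only: the branch runs with probability ctc_self_distill_prob.
# torch.rand(1) already samples from U[0, 1); the extra .uniform_() in the diff
# simply refills the one-element tensor with another uniform draw.
ctc_self_distill_prob = 0.5  # assumed value for this example
if torch.rand(1).uniform_() < ctc_self_distill_prob:
    print("run CTC self-distillation this update")
```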
......@@ -258,12 +258,12 @@ class TextualEncoder(FairseqEncoder):
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
if args.no_scale_embedding:
if args.encoder_no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = dictionary.pad_index
self.embed_norm = getattr(args, "embed_norm", False)
if self.embed_norm:
self.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
if self.encoder_embed_norm:
self.embed_ln = LayerNorm(embed_dim)
self.dropout_module = FairseqDropout(
......@@ -339,7 +339,7 @@ class TextualEncoder(FairseqEncoder):
def forward(self, x, encoder_padding_mask=None, history=None):
if self.embed_norm:
if self.encoder_embed_norm:
x = self.embed_ln(x)
x = self.embed_scale * x
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
......@@ -599,9 +599,11 @@ def base_architecture(args):
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.embed_linear = getattr(args, "embed_linear", False)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
......
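For reference, a small sketch of the `getattr` default pattern used in `base_architecture`: configs written before this commit never set the new flags, so they silently fall back to `False`. The empty `Namespace` stands in for such a config:

```python
from argparse import Namespace

# Hypothetical config with no encoder embedding options set.
args = Namespace()
args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
print(args.encoder_no_scale_embedding, args.encoder_embed_norm)  # False False
```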
......@@ -236,6 +236,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument(
"--encoder-no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings in encoder",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
......@@ -392,12 +397,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="Kernel size of convolution module.",
)
parser.add_argument(
"--embed-linear",
"--encoder-embed-linear",
action="store_true",
help="use linear transform after down-sampling",
)
parser.add_argument(
"--embed-norm",
"--encoder-embed-norm",
action="store_true",
help="use layer norm after down-sampling",
)
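A hypothetical standalone sketch of how these `store_true` flags parse; it only illustrates that argparse maps `--encoder-no-scale-embedding` to the `encoder_no_scale_embedding` attribute read elsewhere in the model code:

```python
import argparse

# Mirrors the add_argument calls above in a self-contained parser.
parser = argparse.ArgumentParser()
parser.add_argument("--encoder-no-scale-embedding", action="store_true",
                    help="if True, don't scale embeddings in encoder")
parser.add_argument("--encoder-embed-norm", action="store_true",
                    help="use layer norm after down-sampling")
args = parser.parse_args(["--encoder-embed-norm"])
print(args.encoder_no_scale_embedding, args.encoder_embed_norm)  # False True
```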
......@@ -590,16 +595,16 @@ class S2TTransformerEncoder(FairseqEncoder):
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_scale = math.sqrt(dim)
if args.no_scale_embedding:
if args.encoder_no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = 1
self.subsample = subsampling(args)
self.embed_linear = getattr(args, "embed_linear", False)
self.embed_norm = getattr(args, "embed_norm", False)
if self.embed_linear:
self.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
self.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
if self.encoder_embed_linear:
self.linear = nn.Linear(dim, dim)
if self.embed_norm:
if self.encoder_embed_norm:
self.embed_ln = LayerNorm(dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
......@@ -814,7 +819,7 @@ class S2TTransformerEncoder(FairseqEncoder):
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.transpose(0, 1).unsqueeze(-1).type_as(x))
if self.embed_norm:
if self.encoder_embed_norm:
x = self.embed_ln(x)
self.show_debug(x, "x after embed norm")
......@@ -835,7 +840,7 @@ class S2TTransformerEncoder(FairseqEncoder):
positions = None
self.show_debug(x, "x after position embedding")
if self.embed_linear:
if self.encoder_embed_linear:
x = self.linear(x)
self.show_debug(x, "x after embed linear")
......@@ -1061,10 +1066,11 @@ def base_architecture(args):
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.embed_linear = getattr(args, "embed_linear", False)
args.embed_norm = getattr(args, "embed_norm", False)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
......
......@@ -713,18 +713,28 @@ class TransformerCTCEncoder(FairseqEncoder):
if ratio <= 1:
return x
if len(x.size()) == 3:
bsz, seq_len, dim = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
else:
bsz, seq_len = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
output_length = int(seq_len * ratio * 2/3)
select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
select_matrix[:, 1::ratio] = 1
threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
select_matrix = (select_matrix > threshold)
assert all(select_matrix.sum(dim=-1).eq(output_length))
out_x = up_x[select_matrix, :].reshape(bsz, -1, dim).contiguous()
out_padding = up_padding[select_matrix].reshape(bsz, -1).contiguous()
# output_length = int(seq_len * ratio * 2/3)
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
# select_matrix[:, 1::ratio] = 1
# mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
# mask = mask.sort(dim=-1)[0]
#
# if len(x.size()) == 3:
# out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
# else:
# out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
# out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
out_x = up_x
out_padding = up_padding
return out_x, out_padding
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
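To make the retained behaviour concrete, a small sketch of the repetition-based upsampling kept by this commit (the random two-thirds subselection is now commented out): every time step is repeated `ratio` times along the sequence dimension, and the padding mask is expanded the same way. Shapes follow the code above; the concrete sizes are made up for the example:

```python
import torch

bsz, seq_len, dim, ratio = 2, 4, 8, 3
x = torch.randn(bsz, seq_len, dim)
padding = torch.zeros(bsz, seq_len, dtype=torch.bool)

up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)      # (B, T*ratio, C)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)  # (B, T*ratio)
print(up_x.shape, up_padding.shape)  # torch.Size([2, 12, 8]) torch.Size([2, 12])
```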
......@@ -773,17 +783,17 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.history is not None:
self.history.clean()
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
# B x T x C -> T x B x C
x = x.transpose(0, 1)
......
......@@ -67,6 +67,7 @@ def main(cfg: FairseqConfig) -> None:
# Print args
logger.info(cfg)
if distributed_utils.is_master(cfg.distributed_training):
with open(os.path.join(cfg.checkpoint.save_dir, "config.yaml"), 'w') as f:
f.write("%s" % OmegaConf.to_yaml(cfg))
......
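A hedged usage sketch for the dumped config: since the master process now writes `config.yaml` into the checkpoint save directory, it can be reloaded with OmegaConf to inspect which encoder options a run used (the directory below is illustrative):

```python
import os
from omegaconf import OmegaConf

# Reload the configuration written at startup by the hunk above.
save_dir = "checkpoints"  # illustrative; whatever cfg.checkpoint.save_dir was
cfg = OmegaConf.load(os.path.join(save_dir, "config.yaml"))
print(OmegaConf.to_yaml(cfg))
```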