Commit aed36ae4 by xuchen

optimize the implementation of lang tag

parent a64cdfcc
......@@ -367,9 +367,11 @@ class CtcCriterion(FairseqCriterion):
         src_tokens = sample["net_input"]["src_tokens"]
         src_lengths = sample["net_input"]["src_lengths"]
+        src_lang_idx = sample["net_input"].get("src_lang_idx", None)
         tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)

         with torch.no_grad():
             encoder_out = model.encoder(src_tokens, src_lengths,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)

         ctc_logit = None
......
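The criterion reads the optional `src_lang_idx` and `tgt_lang_idx` keys out of `sample["net_input"]` with `.get(..., None)`, so batches built without language information keep working and the encoder simply skips the prepend step. A minimal sketch of what such a batch might look like; the shapes and tag indices below are invented for illustration, the real values come from the dataset/collater, which is not part of this commit:

```python
import torch

# Hypothetical batch in the layout the criteria expect; the *_lang_idx entries
# hold one embedding index per sentence, shape (B,).
sample = {
    "net_input": {
        "src_tokens": torch.randn(2, 50, 80),    # (B, T, feat) speech features
        "src_lengths": torch.tensor([50, 42]),
        "src_lang_idx": torch.tensor([5, 5]),    # e.g. index of an <en> tag in the embedding table
        "tgt_lang_idx": torch.tensor([9, 9]),    # e.g. index of a <de> tag
    }
}

# .get(..., None) keeps batches without language tags valid.
src_lang_idx = sample["net_input"].get("src_lang_idx", None)
tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
print(src_lang_idx, tgt_lang_idx)
```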
......@@ -82,6 +82,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         src_tokens = sample["net_input"]["src_tokens"]
         src_lengths = sample["net_input"]["src_lengths"]
         prev_output_tokens = sample["net_input"]["prev_output_tokens"]
+        src_lang_idx = sample["net_input"].get("src_lang_idx", None)
         tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)

         train_enc_only = False
......@@ -105,10 +106,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
             encoder_out = model.encoder(src_tokens, src_lengths,
                                         ctc_alignment_oracle=ctc_alignment_oracle,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)
         else:
             encoder_out = model.encoder(src_tokens=src_tokens,
                                         src_lengths=src_lengths,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)

         net_output = model.decoder(
......
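Both criteria forward the indices as plain keyword arguments, and the encoder hunks below read them back out of `**kwargs`, so every existing call site that never passes language tags keeps working unchanged. A toy sketch of that pattern, with made-up shapes and a stand-in encoder class:

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Stand-in for the S2T encoder; only the argument handling matters here."""

    def forward(self, src_tokens, src_lengths, **kwargs):
        # Optional language indices arrive as keyword arguments; callers that
        # do not pass them (older recipes, plain ASR setups) still work.
        src_lang_idx = kwargs.get("src_lang_idx", None)
        tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
        return src_lang_idx, tgt_lang_idx

enc = TinyEncoder()
feats, lens = torch.randn(2, 50, 80), torch.tensor([50, 42])
print(enc(feats, lens))                                        # (None, None)
print(enc(feats, lens, tgt_lang_idx=torch.tensor([9, 9])))     # (None, tensor([9, 9]))
```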
......@@ -1364,6 +1364,8 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.compression_stat = False
+        self.log_flag_dict = dict()

         # gather cosine similarity
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
         self.gather_cos_sim_dis = 2
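`log_flag_dict` is a small guard used further down so that messages tied to per-batch code paths (such as prepending the language tags) are logged once per process rather than on every forward pass. A stripped-down sketch of the pattern; the class and method names here are invented for illustration:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EncoderLike:
    def __init__(self):
        self.log_flag_dict = dict()   # one entry per already-logged message key

    def forward_step(self):
        # Runs on every batch, but the message is emitted only the first time.
        if "prepend_tgt_lang" not in self.log_flag_dict:
            self.log_flag_dict["prepend_tgt_lang"] = True
            logger.info("Prepend the target language tag into the encoder input.")

m = EncoderLike()
m.forward_step()   # logs the message
m.forward_step()   # silent
```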
......@@ -1775,27 +1777,15 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

+        src_lang_idx = kwargs.get("src_lang_idx", None)
         tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
-        has_add_lang_tag = False

         # (B, T, D) -> (T, B, D)
         x = src_tokens.transpose(0, 1)
         input_lengths = src_lengths

         org_bsz = x.size(1)
-        if (
-            self.mixup
-            and layer_idx == mixup_layer
-        ):
-            if tgt_lang_idx is not None:
-                assert self.embed_tokens is not None
-                tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
-                if mixup is not None:
-                    pass
-                x = torch.cat((tgt_lang_embed, x), 0)
-                input_lengths += 1
-                has_add_lang_tag = True
         if (
             (self.training or self.mixup_infer)
             and self.mixup
             and layer_idx == mixup_layer
......@@ -1815,14 +1805,25 @@ class S2TTransformerEncoder(FairseqEncoder):
         x, input_lengths = self.subsample(x, input_lengths)
         self.show_debug(x, "x after subsampling")

-        #if tgt_lang_idx is not None and False:
-        if tgt_lang_idx is not None and not has_add_lang_tag:
+        if src_lang_idx is not None:
+            assert self.embed_tokens is not None
+            src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
+            x = torch.cat((src_lang_embed, x), 0)
+            input_lengths += 1
+            if "prepend_src_lang" not in self.log_flag_dict:
+                self.log_flag_dict["prepend_src_lang"] = True
+                logger.info("Prepend the source language tag into the encoder input.")
+
+        if tgt_lang_idx is not None:
             assert self.embed_tokens is not None
             tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
             if mixup is not None:
                 pass
             x = torch.cat((tgt_lang_embed, x), 0)
             input_lengths += 1
if "prepend_tgt_lang" not in self.log_flag_dict:
self.log_flag_dict["prepend_tgt_lang"] = True
logger.info("Prepend the target language tag into the encoder input.")
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if encoder_padding_mask.size(1) < x.size(0):
......@@ -2248,12 +2249,12 @@ class S2TTransformerEncoder(FairseqEncoder):
         )

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder output", is_top=True)
+            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder CTC output", is_top=True)
         self.show_debug(x, "x after ctc")

         if self.use_xctc and xctc_logit is None:
             xctc_logit = self.xctc(
-                x, encoder_padding_mask, "Encoder output", is_top=True
+                x, encoder_padding_mask, "Encoder XCTC output", is_top=True
             )
         self.show_debug(x, "x after xctc")
......
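Taken together, the encoder-side change turns each language tag into one extra frame at the front of the subsampled sequence: the tag index is embedded, prepended along the time dimension, and the lengths (and therefore the padding mask) grow by one per prepended tag. A self-contained sketch under assumed shapes; `lengths_to_padding_mask` below is a stand-in for the fairseq utility of the same name, and `embed_tokens` plays the role of `self.embed_tokens` in the diff:

```python
import torch
import torch.nn as nn

def lengths_to_padding_mask(lengths):
    # Stand-in for the fairseq helper: True where a position is padding.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

# Assumed shapes: x is the subsampled encoder input (T, B, D); each *_lang_idx
# holds one embedding index per sentence, shape (B,).
T, B, D, vocab = 20, 3, 16, 100
embed_tokens = nn.Embedding(vocab, D)          # plays the role of self.embed_tokens
x = torch.randn(T, B, D)
input_lengths = torch.tensor([20, 18, 15])
tgt_lang_idx = torch.tensor([7, 9, 7])

if tgt_lang_idx is not None:
    tgt_lang_embed = embed_tokens(tgt_lang_idx).unsqueeze(0)   # (1, B, D)
    x = torch.cat((tgt_lang_embed, x), 0)                      # (T + 1, B, D)
    input_lengths = input_lengths + 1                          # every sentence gained a frame

encoder_padding_mask = lengths_to_padding_mask(input_lengths)
print(x.shape, encoder_padding_mask.shape)     # torch.Size([21, 3, 16]) torch.Size([3, 21])
```

Prepending the source tag works the same way, so a batch can grow by up to two positions; because both indices default to None on the criterion and encoder side, the whole path is a no-op for data without language tags.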