Commit b984a889 by xuchen

support setting the position of the language tag

parent aed36ae4
@@ -771,6 +771,15 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="add the position embedding after compression",
         )
+        # inserted position of language tag
+        parser.add_argument(
+            "--enc-lang-tag-pos",
+            default="input",
+            choices=["input", "predict"],
+            type=str,
+            help="position to insert the language tag, input or before prediction",
+        )
         pass

     @classmethod
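
For orientation, the two choices differ as follows: with "input", the language-tag embedding becomes an extra frame of the (subsampled) encoder input, so the sequence and its input lengths grow by one; with "predict", the encoder sequence is left untouched and the tag embedding is only folded into the features right before the (X)CTC projection. A runnable toy sketch of the two placements (shapes and variable names are illustrative, not the repository's API):

import torch

T, B, C, V = 6, 2, 4, 10
x = torch.randn(T, B, C)             # encoder features, time-major (T x B x C)
lang_embed = torch.randn(B, C)       # embedding of the language-tag token
ctc_proj = torch.nn.Linear(C, V)     # stand-in for the CTC output projection

# --enc-lang-tag-pos input: prepend the tag as one extra frame
x_input = torch.cat((lang_embed.unsqueeze(0), x), dim=0)
assert x_input.shape == (T + 1, B, C)

# --enc-lang-tag-pos predict: keep the length, add the tag to every frame
# just before the CTC projection
logits_predict = ctc_proj(lang_embed.unsqueeze(0) + x)
assert logits_predict.shape == (T, B, V)
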
@@ -1364,6 +1373,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.compression_stat = False
+        self.enc_lang_tag_pos = getattr(args, "enc_lang_tag_pos", "input")
         self.log_flag_dict = dict()

         # gather cosine similarity
@@ -1805,25 +1815,26 @@ class S2TTransformerEncoder(FairseqEncoder):
         x, input_lengths = self.subsample(x, input_lengths)
         self.show_debug(x, "x after subsampling")

-        if src_lang_idx is not None:
-            assert self.embed_tokens is not None
-            src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
-            x = torch.cat((src_lang_embed, x), 0)
-            input_lengths += 1
+        if self.enc_lang_tag_pos == "input":
+            if src_lang_idx is not None:
+                assert self.embed_tokens is not None
+                src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
+                x = torch.cat((src_lang_embed, x), 0)
+                input_lengths += 1

-            if "prepend_src_lang" not in self.log_flag_dict:
-                self.log_flag_dict["prepend_src_lang"] = True
-                logger.info("Prepend the source language tag into the encoder input.")
+                if "prepend_src_lang" not in self.log_flag_dict:
+                    self.log_flag_dict["prepend_src_lang"] = True
+                    logger.info("Prepend the source language tag into the encoder input.")

-        if tgt_lang_idx is not None:
-            assert self.embed_tokens is not None
-            tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
-            x = torch.cat((tgt_lang_embed, x), 0)
-            input_lengths += 1
+            if tgt_lang_idx is not None:
+                assert self.embed_tokens is not None
+                tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
+                x = torch.cat((tgt_lang_embed, x), 0)
+                input_lengths += 1

-            if "prepend_tgt_lang" not in self.log_flag_dict:
-                self.log_flag_dict["prepend_tgt_lang"] = True
-                logger.info("Prepend the target language tag into the encoder input.")
+                if "prepend_tgt_lang" not in self.log_flag_dict:
+                    self.log_flag_dict["prepend_tgt_lang"] = True
+                    logger.info("Prepend the target language tag into the encoder input.")

         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
         if encoder_padding_mask.size(1) < x.size(0):
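
A side note on why input_lengths is incremented in the "input" branch: the padding mask is recomputed from the lengths right afterwards, so the prepended tag frame has to be counted for the mask width to match the new sequence length. A small runnable illustration (the lengths below are made up):

import torch
from fairseq.data.data_utils import lengths_to_padding_mask

lengths = torch.tensor([3, 5])
print(lengths_to_padding_mask(lengths))      # B x 5 bool mask, True marks padding
print(lengths_to_padding_mask(lengths + 1))  # B x 6 mask after prepending a tag frame
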
@@ -1983,11 +1994,12 @@ class S2TTransformerEncoder(FairseqEncoder):
                 else:
                     layer_norm = getattr(self, "ctc_norm%d" % layer_idx)

                 norm_x = layer_norm(x)
-                logit = inter_ctc(
-                    norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx
-                )
-                inter_logit = [logit, encoder_padding_mask]
+                logit, logit_list = inter_ctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                        "CTC Layer %d" % layer_idx, "source",
+                        self.embed_tokens(src_lang_idx)
+                        if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)

                 if self.ctc_pae_ground_truth_ratio > 0:
                     ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
                     if (
@@ -2021,7 +2033,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                             ~ctc_oracle_mask, -1
                         )
-                    inter_logit = [logit, None, ctc_force_emit]
+                    # logit_list = [logit, None, ctc_force_emit]

                 pae_input = x if self.pae_unnorm_input else norm_x
                 if pae.adapter_type != "none":
@@ -2030,14 +2042,14 @@ class S2TTransformerEncoder(FairseqEncoder):
                     )
                 self.show_debug(x, "x after pae")

-                inter_ctc_logits.append(inter_logit)
+                inter_ctc_logits.append(logit_list)

                 if not self.training and self.early_exit_layer == layer_idx:
-                    ctc_logit = inter_logit[0]
+                    ctc_logit = logit_list[0]
                     break

                 if not self.training and self.early_exit_count != 0:
-                    predicts = inter_ctc.predict(inter_logit[0], encoder_padding_mask)
+                    predicts = inter_ctc.predict(logit_list[0], encoder_padding_mask)
                     if len(inter_ctc_logits) < self.early_exit_count:
                         for i in range(x.size(1)):
                             inter_ctc_logits_history[i].append(predicts[i])
@@ -2045,7 +2057,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                     if org_bsz == 1:
                         early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
                         if early_exit_flag:
-                            ctc_logit = inter_logit[0]
+                            ctc_logit = logit_list[0]
                             self.early_exit_layer_record.append(layer_idx)
                             break
                     else:
@@ -2056,7 +2068,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                             real_idx = batch_idx_dict[i]
                             early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[real_idx], predicts[i], self.early_exit_count)
                             if early_exit_flag:
-                                final_ctc_logits[real_idx] = inter_logit[0][:, i, :]
+                                final_ctc_logits[real_idx] = logit_list[0][:, i, :]
                                 final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
                                 early_exit_layer[real_idx] = layer_idx
                             else:
@@ -2177,11 +2189,11 @@ class S2TTransformerEncoder(FairseqEncoder):
                     norm = getattr(self, "xctc_norm%d" % layer_idx)

                 norm_x = norm(x)
-                logit = self.xctc(
-                    norm_x, encoder_padding_mask, "Inter XCTC layer %d" % layer_idx
-                )
+                logit, logit_list = self.xctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                        "XCTC Layer %d" % layer_idx, "target",
+                        self.embed_tokens(tgt_lang_idx)
+                        if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)
-                inter_logit = logit

                 # CTC alignment
                 if self.xctc_pae_ground_truth_ratio > 0:
                     ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
@@ -2214,7 +2226,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                         xctc_force_emit = best_aligns_pad.masked_fill(
                             ~xctc_oracle_mask, -1
                         )
-                    inter_logit = [logit, None, xctc_force_emit]
+                    # logit_list = [logit, None, xctc_force_emit]

                 pae_input = x if self.pae_unnorm_input else norm_x
                 if self.xctc_pae.adapter_type != "none":
@@ -2225,7 +2237,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                         xctc_oracle_mask,
                     )
-                inter_xctc_logits.append(inter_logit)
+                inter_xctc_logits.append(logit_list)

             if self.history is not None:
                 self.history.push(x)
@@ -2249,13 +2261,19 @@ class S2TTransformerEncoder(FairseqEncoder):
             )

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder CTC output", is_top=True)
+            _, ctc_logit = self.ctc.forward_with_lang_tag(x, encoder_padding_mask,
+                    "Encoder CTC output", "source",
+                    self.embed_tokens(src_lang_idx)
+                    if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None,
+                    is_top=True)
             self.show_debug(x, "x after ctc")

         if self.use_xctc and xctc_logit is None:
-            xctc_logit = self.xctc(
-                x, encoder_padding_mask, "Encoder XCTC output", is_top=True
-            )
+            _, xctc_logit = self.xctc.forward_with_lang_tag(x, encoder_padding_mask,
+                    "Encoder XCTC output", "target",
+                    self.embed_tokens(tgt_lang_idx)
+                    if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None,
+                    is_top=True)
             self.show_debug(x, "x after xctc")

         if not self.training and self.early_exit_count != 0 and org_bsz != 1:
@@ -2277,12 +2295,6 @@ class S2TTransformerEncoder(FairseqEncoder):
             encoder_padding_mask = torch.stack(output_encoder_padding_mask, dim=0)
             self.early_exit_layer_record.extend(output_layers)

-        if ctc_force_emit is not None:
-            ctc_logit = [ctc_logit, None, ctc_force_emit]
-        if xctc_force_emit is not None:
-            xctc_logit = [xctc_logit, None, xctc_force_emit]

         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
@@ -2308,7 +2320,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             []
             if len(encoder_out["ctc_logit"]) == 0
             else [
-                x.index_select(1, new_order)
+                [x[0].index_select(1, new_order)] + x[1:] if isinstance(x, list) else x.index_select(1, new_order)
                 for x in encoder_out["ctc_logit"]
                 if x is not None
             ]
@@ -2316,7 +2328,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         new_xctc_logit = (
             []
             if len(encoder_out["xctc_logit"]) == 0
-            else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]]
+            else [
+                [x[0].index_select(1, new_order)] + x[1:] if isinstance(x, list) else x.index_select(1, new_order)
+                for x in encoder_out["xctc_logit"] if x is not None
+            ]
         )

         new_inter_ctc_logits = (
             []
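
Reordering a list-shaped entry has to keep the extra elements (the padding mask, force-emit information, and so on) attached to the reordered logits. Plain list concatenation does that, whereas building the entry with [ ... ].extend(...) would evaluate to None, because list.extend mutates its receiver and returns None, and would leave None entries in the reordered output. A minimal illustration with toy shapes:

import torch

new_order = torch.tensor([1, 0])
entry = [torch.randn(5, 2, 16), torch.zeros(2, 5, dtype=torch.bool)]  # [logits, padding_mask]

reordered = [entry[0].index_select(1, new_order)] + entry[1:]         # logits reordered, mask kept
assert len(reordered) == 2

# By contrast, extend() returns None, so the comprehension would collect None
assert [entry[0].index_select(1, new_order)].extend(entry[1:]) is None
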
......
@@ -9,11 +9,10 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
 )
-from fairseq.data.data_utils import post_process
+from fairseq.data.data_utils import post_process, lengths_to_padding_mask

 logger = logging.getLogger(__name__)


 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,

@@ -39,6 +38,7 @@ class CTC(nn.Module):
         self.post_process = "sentencepiece"
         self.blank_idx = 0
+        self.log_flag_dict = dict()

         self.path = None
         self.save_stream = None
@@ -65,6 +65,32 @@ class CTC(nn.Module):
         return x

+    def forward_with_lang_tag(self, x, encoder_padding_mask, tag, lang, lang_tok, is_top=False):
+        if lang_tok is None:
+            logit = self(x, encoder_padding_mask, tag, is_top)
+            return logit, [logit, encoder_padding_mask]
+        else:
+            if False:
+                lx = torch.cat((lang_tok.unsqueeze(0), x), 0)
+                new_encoder_padding_mask = torch.cat((torch.zeros(encoder_padding_mask.size(0), 1, dtype=torch.bool).to(x.device), encoder_padding_mask), 1)
+                logit = self(lx, new_encoder_padding_mask, tag, is_top)
+
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Prepend the %s language tag into the logit before CTC in %s." % (lang, tag))
+
+                logit_remove = logit[1:, :, :]
+                return logit_remove, [logit, new_encoder_padding_mask]
+            else:
+                lx = lang_tok.unsqueeze(0) + x
+                logit = self(lx, encoder_padding_mask, tag, is_top)
+
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Prepend the %s language tag into the logit before CTC in %s." % (lang, tag))
+
+                return logit, [logit, encoder_padding_mask]
+
     def softmax(self, x, temperature=1.0):
         return F.softmax(self.ctc_projection(x) / temperature, dim=-1, dtype=torch.float32)
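
To make the contract of the new method concrete: when a language embedding is given and the disabled if False: branch stays off, the embedding is broadcast-added to every frame (despite the log message saying "prepend"), the logit shape is unchanged, and the second return value bundles the logits with the padding mask for the loss; the dead branch is the true prepending variant, which also has to extend the padding mask and strip the extra frame again. A self-contained toy mirroring the live branch (the class below is an illustrative stand-in, not the repository's CTC module):

import torch
import torch.nn as nn

class ToyCTC(nn.Module):
    """Illustrative stand-in for the live (non-prepending) branch above."""
    def __init__(self, dim, vocab):
        super().__init__()
        self.proj = nn.Linear(dim, vocab)   # stands in for ctc_projection

    def forward_with_lang_tag(self, x, encoder_padding_mask, tag, lang, lang_tok, is_top=False):
        if lang_tok is not None:
            x = lang_tok.unsqueeze(0) + x   # broadcast-add the tag to every frame
        logit = self.proj(x)
        return logit, [logit, encoder_padding_mask]

T, B, C, V = 5, 2, 8, 16
ctc = ToyCTC(C, V)
feats = torch.randn(T, B, C)
mask = torch.zeros(B, T, dtype=torch.bool)
lang_tok = torch.randn(B, C)                # e.g. embed_tokens(src_lang_idx)

logit, logit_list = ctc.forward_with_lang_tag(feats, mask, "Encoder CTC output", "source", lang_tok)
assert logit.shape == (T, B, V) and logit_list[0] is logit and logit_list[1] is mask
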
......
@@ -552,6 +552,11 @@ class SpeechToTextTask(LegacyFairseqTask):
     def _inference(self, generator, sample, model, remove_bpe):
         def decode(toks, escape_unk=False):
+            symbols_to_strip_from_output = None
+            if hasattr(generator, "symbols_to_strip_from_output"):
+                symbols_to_strip_from_output = generator.symbols_to_strip_from_output
+            else:
+                symbols_to_strip_from_output = {generator.eos}
             s = self.tgt_dict.string(
                 toks.int().cpu(),
                 remove_bpe,
@@ -561,6 +566,7 @@ class SpeechToTextTask(LegacyFairseqTask):
                 # alternative that is unlikely to appear in the real
                 # reference, but doesn't get split into multiple tokens.
                 unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
+                extra_symbols_to_ignore=symbols_to_strip_from_output
             )
             if self.tokenizer:
                 s = self.tokenizer.decode(s)
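
The point of passing extra_symbols_to_ignore is to keep generator-specific special tokens, such as appended language tags or EOS, out of the detokenized hypotheses used for validation BLEU. A runnable sketch with a throwaway dictionary (the tag symbol is only an example):

import torch
from fairseq.data import Dictionary

d = Dictionary()
lang_idx = d.add_symbol("<lang:de>")
hello_idx = d.add_symbol("hello")

toks = torch.tensor([lang_idx, hello_idx, d.eos()])
print(d.string(toks))                                               # the tag stays in the string
print(d.string(toks, extra_symbols_to_ignore={lang_idx, d.eos()}))  # -> "hello"
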
......