Commit b984a889 by xuchen

Support setting the position of the language tag

parent aed36ae4
@@ -771,6 +771,15 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="add the position embedding after compression",
         )
+        # position at which to insert the language tag
+        parser.add_argument(
+            "--enc-lang-tag-pos",
+            default="input",
+            choices=["input", "predict"],
+            type=str,
+            help="where to insert the language tag: at the encoder input or before CTC prediction",
+        )
         pass
 
     @classmethod
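For reference, a minimal standalone sketch of how the new option parses (the parser below is illustrative, not part of the patch); omitting the flag keeps the previous behaviour because the encoder reads it with getattr(args, "enc_lang_tag_pos", "input"):

    import argparse

    parser = argparse.ArgumentParser()
    # same definition as the --enc-lang-tag-pos argument added above
    parser.add_argument(
        "--enc-lang-tag-pos",
        default="input",
        choices=["input", "predict"],
        type=str,
        help="where to insert the language tag: at the encoder input or before CTC prediction",
    )

    args = parser.parse_args(["--enc-lang-tag-pos", "predict"])
    assert args.enc_lang_tag_pos == "predict"
    # "input": the tag embedding is combined with the subsampled features (default)
    # "predict": the tag embedding is handed to CTC.forward_with_lang_tag before prediction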
@@ -1364,6 +1373,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.compression_stat = False
+        self.enc_lang_tag_pos = getattr(args, "enc_lang_tag_pos", "input")
         self.log_flag_dict = dict()

         # gather cosine similarity
@@ -1805,6 +1815,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         x, input_lengths = self.subsample(x, input_lengths)
         self.show_debug(x, "x after subsampling")
+        if self.enc_lang_tag_pos == "input":
             if src_lang_idx is not None:
                 assert self.embed_tokens is not None
                 src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
@@ -1983,11 +1994,12 @@ class S2TTransformerEncoder(FairseqEncoder):
                 else:
                     layer_norm = getattr(self, "ctc_norm%d" % layer_idx)
                 norm_x = layer_norm(x)
-                logit = inter_ctc(
-                    norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx
-                )
-                inter_logit = [logit, encoder_padding_mask]
+                logit, logit_list = inter_ctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                    "CTC Layer %d" % layer_idx, "source",
+                    self.embed_tokens(src_lang_idx)
+                    if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)

                 if self.ctc_pae_ground_truth_ratio > 0:
                     ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
                     if (
@@ -2021,7 +2033,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                             ~ctc_oracle_mask, -1
                         )
-                        inter_logit = [logit, None, ctc_force_emit]
+                        # logit_list = [logit, None, ctc_force_emit]

                 pae_input = x if self.pae_unnorm_input else norm_x
                 if pae.adapter_type != "none":
@@ -2030,14 +2042,14 @@ class S2TTransformerEncoder(FairseqEncoder):
                     )
                     self.show_debug(x, "x after pae")

-                inter_ctc_logits.append(inter_logit)
+                inter_ctc_logits.append(logit_list)

                 if not self.training and self.early_exit_layer == layer_idx:
-                    ctc_logit = inter_logit[0]
+                    ctc_logit = logit_list[0]
                     break

                 if not self.training and self.early_exit_count != 0:
-                    predicts = inter_ctc.predict(inter_logit[0], encoder_padding_mask)
+                    predicts = inter_ctc.predict(logit_list[0], encoder_padding_mask)
                     if len(inter_ctc_logits) < self.early_exit_count:
                         for i in range(x.size(1)):
                             inter_ctc_logits_history[i].append(predicts[i])
@@ -2045,7 +2057,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                     if org_bsz == 1:
                         early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
                         if early_exit_flag:
-                            ctc_logit = inter_logit[0]
+                            ctc_logit = logit_list[0]
                             self.early_exit_layer_record.append(layer_idx)
                             break
                     else:
@@ -2056,7 +2068,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                             real_idx = batch_idx_dict[i]
                             early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[real_idx], predicts[i], self.early_exit_count)
                             if early_exit_flag:
-                                final_ctc_logits[real_idx] = inter_logit[0][:, i, :]
+                                final_ctc_logits[real_idx] = logit_list[0][:, i, :]
                                 final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
                                 early_exit_layer[real_idx] = layer_idx
                             else:
@@ -2177,11 +2189,11 @@ class S2TTransformerEncoder(FairseqEncoder):
                     norm = getattr(self, "xctc_norm%d" % layer_idx)
                     norm_x = norm(x)
-                    logit = self.xctc(
-                        norm_x, encoder_padding_mask, "Inter XCTC layer %d" % layer_idx
-                    )
-                    inter_logit = logit
+                    logit, logit_list = self.xctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                        "XCTC Layer %d" % layer_idx, "target",
+                        self.embed_tokens(tgt_lang_idx)
+                        if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)

                     # CTC alignment
                     if self.xctc_pae_ground_truth_ratio > 0:
                         ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
@@ -2214,7 +2226,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                             xctc_force_emit = best_aligns_pad.masked_fill(
                                 ~xctc_oracle_mask, -1
                             )
-                            # inter_logit = [logit, None, xctc_force_emit]
+                            # logit_list = [logit, None, xctc_force_emit]
                         pae_input = x if self.pae_unnorm_input else norm_x
                         if self.xctc_pae.adapter_type != "none":
@@ -2225,7 +2237,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                                 xctc_oracle_mask,
                             )
-                    inter_xctc_logits.append(inter_logit)
+                    inter_xctc_logits.append(logit_list)

             if self.history is not None:
                 self.history.push(x)
@@ -2249,13 +2261,19 @@ class S2TTransformerEncoder(FairseqEncoder):
             )

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder CTC output", is_top=True)
+            _, ctc_logit = self.ctc.forward_with_lang_tag(x, encoder_padding_mask,
+                "Encoder CTC output", "source",
+                self.embed_tokens(src_lang_idx)
+                if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None,
+                is_top=True)
             self.show_debug(x, "x after ctc")

         if self.use_xctc and xctc_logit is None:
-            xctc_logit = self.xctc(
-                x, encoder_padding_mask, "Encoder XCTC output", is_top=True
-            )
+            _, xctc_logit = self.xctc.forward_with_lang_tag(x, encoder_padding_mask,
+                "Encoder XCTC output", "target",
+                self.embed_tokens(tgt_lang_idx)
+                if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None,
+                is_top=True)
             self.show_debug(x, "x after xctc")

         if not self.training and self.early_exit_count != 0 and org_bsz != 1:
@@ -2277,12 +2295,6 @@ class S2TTransformerEncoder(FairseqEncoder):
             encoder_padding_mask = torch.stack(output_encoder_padding_mask, dim=0)
             self.early_exit_layer_record.extend(output_layers)

-        if ctc_force_emit is not None:
-            ctc_logit = [ctc_logit, None, ctc_force_emit]
-        if xctc_force_emit is not None:
-            xctc_logit = [xctc_logit, None, xctc_force_emit]

         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
@@ -2308,7 +2320,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             []
             if len(encoder_out["ctc_logit"]) == 0
             else [
-                x.index_select(1, new_order)
+                # entries may now be lists such as [logit, encoder_padding_mask];
+                # list.extend returns None, so concatenate instead of extending in place
+                [x[0].index_select(1, new_order)] + x[1:] if isinstance(x, list) else x.index_select(1, new_order)
                 for x in encoder_out["ctc_logit"]
                 if x is not None
             ]
@@ -2316,7 +2328,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         new_xctc_logit = (
             []
             if len(encoder_out["xctc_logit"]) == 0
-            else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]]
+            else [
+                [x[0].index_select(1, new_order)] + x[1:] if isinstance(x, list) else x.index_select(1, new_order)
+                for x in encoder_out["xctc_logit"] if x is not None
+            ]
         )
         new_inter_ctc_logits = (
             []
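Since inter-CTC entries are now stored as lists such as [logit, encoder_padding_mask], only their first element is a T x B x C tensor that has to be reordered on the batch dimension. A small sketch of the pattern used above (the helper name reorder_logit is hypothetical); note that [a].extend(b) would evaluate to None, which is why list concatenation is used instead:

    import torch

    def reorder_logit(x, new_order):
        # x is either a raw logit tensor (T x B x C) or a list whose first
        # element is the logit; only the logit is reordered on the batch dim
        if isinstance(x, list):
            return [x[0].index_select(1, new_order)] + x[1:]
        return x.index_select(1, new_order)

    logit = torch.randn(4, 3, 8)                # T x B x C
    mask = torch.zeros(3, 4, dtype=torch.bool)  # B x T
    new_order = torch.tensor([2, 0, 1])
    out = reorder_logit([logit, mask], new_order)
    assert out[0].shape == (4, 3, 8) and out[1] is mask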
...
@@ -9,11 +9,10 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
 )
-from fairseq.data.data_utils import post_process
+from fairseq.data.data_utils import post_process, lengths_to_padding_mask

 logger = logging.getLogger(__name__)


 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,
@@ -39,6 +38,7 @@ class CTC(nn.Module):
         self.post_process = "sentencepiece"
         self.blank_idx = 0
+        self.log_flag_dict = dict()

         self.path = None
         self.save_stream = None
@@ -65,6 +65,32 @@ class CTC(nn.Module):
         return x

+    def forward_with_lang_tag(self, x, encoder_padding_mask, tag, lang, lang_tok, is_top=False):
+        if lang_tok is None:
+            logit = self(x, encoder_padding_mask, tag, is_top)
+            return logit, [logit, encoder_padding_mask]
+        else:
+            if False:
+                # disabled variant: prepend the tag embedding as an extra frame
+                lx = torch.cat((lang_tok.unsqueeze(0), x), 0)
+                new_encoder_padding_mask = torch.cat(
+                    (torch.zeros(encoder_padding_mask.size(0), 1, dtype=torch.bool).to(x.device),
+                     encoder_padding_mask), 1)
+                logit = self(lx, new_encoder_padding_mask, tag, is_top)
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Prepend the %s language tag to the input before CTC in %s." % (lang, tag))
+                logit_remove = logit[1:, :, :]
+                return logit_remove, [logit, new_encoder_padding_mask]
+            else:
+                # active variant: add the tag embedding to every frame
+                lx = lang_tok.unsqueeze(0) + x
+                logit = self(lx, encoder_padding_mask, tag, is_top)
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Add the %s language tag embedding to the input before CTC in %s." % (lang, tag))
+                return logit, [logit, encoder_padding_mask]

     def softmax(self, x, temperature=1.0):
         return F.softmax(self.ctc_projection(x) / temperature, dim=-1, dtype=torch.float32)
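A rough usage sketch of the new helper in "predict" mode, assuming a CTC-like module whose forward projects T x B x C features to vocabulary logits; DummyCTC and the shapes below are illustrative stand-ins, not the real fairseq module:

    import torch
    import torch.nn as nn

    class DummyCTC(nn.Module):
        # stand-in for the CTC module above: projects T x B x C features to vocabulary logits
        def __init__(self, embed_dim, vocab_size):
            super().__init__()
            self.proj = nn.Linear(embed_dim, vocab_size)

        def forward(self, x, encoder_padding_mask=None, tag=None, is_top=False):
            return self.proj(x)

        def forward_with_lang_tag(self, x, encoder_padding_mask, tag, lang, lang_tok, is_top=False):
            # same idea as the active branch above: add the tag embedding to every frame
            if lang_tok is None:
                logit = self(x, encoder_padding_mask, tag, is_top)
                return logit, [logit, encoder_padding_mask]
            lx = lang_tok.unsqueeze(0) + x
            logit = self(lx, encoder_padding_mask, tag, is_top)
            return logit, [logit, encoder_padding_mask]

    T, B, C, V = 6, 2, 16, 100
    x = torch.randn(T, B, C)                            # subsampled encoder states
    padding_mask = torch.zeros(B, T, dtype=torch.bool)
    lang_embed = torch.randn(B, C)                      # plays the role of embed_tokens(src_lang_idx)
    ctc = DummyCTC(C, V)
    logit, logit_list = ctc.forward_with_lang_tag(x, padding_mask, "CTC Layer 0", "source", lang_embed)
    assert logit.shape == (T, B, V) and logit_list[1] is padding_mask

Because the tag embedding is broadcast-added over the time dimension, the logit keeps its original length; only the disabled concatenation branch would prepend an extra frame and then strip it from the returned logit.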
...
@@ -552,6 +552,11 @@ class SpeechToTextTask(LegacyFairseqTask):
     def _inference(self, generator, sample, model, remove_bpe):
         def decode(toks, escape_unk=False):
+            if hasattr(generator, "symbols_to_strip_from_output"):
+                symbols_to_strip_from_output = generator.symbols_to_strip_from_output
+            else:
+                # wrap eos in a set: extra_symbols_to_ignore expects a collection of token ids
+                symbols_to_strip_from_output = {generator.eos}
             s = self.tgt_dict.string(
                 toks.int().cpu(),
                 remove_bpe,
@@ -561,6 +566,7 @@ class SpeechToTextTask(LegacyFairseqTask):
                 # alternative that is unlikely to appear in the real
                 # reference, but doesn't get split into multiple tokens.
                 unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
+                extra_symbols_to_ignore=symbols_to_strip_from_output,
             )
             if self.tokenizer:
                 s = self.tokenizer.decode(s)
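The fallback mirrors the pattern fairseq uses elsewhere for stripping special symbols: extra_symbols_to_ignore is treated as a collection of token ids, so the eos fallback is wrapped in a set rather than passed as a bare int. A tiny illustration (the names _Gen and get_symbols_to_strip are hypothetical):

    # illustration only: _Gen stands in for any sequence generator without the attribute
    class _Gen:
        eos = 2

    def get_symbols_to_strip(generator):
        # prefer the generator-provided set; otherwise fall back to a one-element set with eos
        if hasattr(generator, "symbols_to_strip_from_output"):
            return generator.symbols_to_strip_from_output
        return {generator.eos}

    assert get_symbols_to_strip(_Gen()) == {2}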
...