Commit b984a889 by xuchen

Support setting the position of the language tag

parent aed36ae4
@@ -9,11 +9,10 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
 )
-from fairseq.data.data_utils import post_process
+from fairseq.data.data_utils import post_process, lengths_to_padding_mask
 
 logger = logging.getLogger(__name__)
 
 
 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,
@@ -39,6 +38,7 @@ class CTC(nn.Module):
         self.post_process = "sentencepiece"
         self.blank_idx = 0
 
+        self.log_flag_dict = dict()
 
         self.path = None
         self.save_stream = None
@@ -65,6 +65,32 @@ class CTC(nn.Module):
 
         return x
 
+    def forward_with_lang_tag(self, x, encoder_padding_mask, tag, lang, lang_tok, is_top=False):
+        if lang_tok is None:
+            logit = self(x, encoder_padding_mask, tag, is_top)
+            return logit, [logit, encoder_padding_mask]
+        else:
+            if False:
+                # Disabled variant: prepend the tag as an extra frame, so the
+                # padding mask also gains one non-padded column at the front.
+                lx = torch.cat((lang_tok.unsqueeze(0), x), 0)
+                new_encoder_padding_mask = torch.cat(
+                    (torch.zeros(encoder_padding_mask.size(0), 1, dtype=torch.bool).to(x.device),
+                     encoder_padding_mask), 1)
+                logit = self(lx, new_encoder_padding_mask, tag, is_top)
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Prepend the %s language tag to the input before CTC in %s." % (lang, tag))
+                # Drop the tag frame from the logits passed to the CTC loss.
+                logit_remove = logit[1:, :, :]
+                return logit_remove, [logit, new_encoder_padding_mask]
+            else:
+                # Active variant: add the tag embedding to every frame, leaving
+                # sequence length and padding mask unchanged.
+                lx = lang_tok.unsqueeze(0) + x
+                logit = self(lx, encoder_padding_mask, tag, is_top)
+                if lang not in self.log_flag_dict:
+                    self.log_flag_dict[lang] = True
+                    logger.info("Add the %s language tag to the input before CTC in %s." % (lang, tag))
+                return logit, [logit, encoder_padding_mask]
 
     def softmax(self, x, temperature=1.0):
         return F.softmax(self.ctc_projection(x) / temperature, dim=-1, dtype=torch.float32)
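For reference, a minimal standalone sketch of the two injection modes that `forward_with_lang_tag` switches between (the helper name `inject_lang_tag` and the toy shapes are illustrative, not part of this commit): prepending the tag as an extra frame grows the sequence and the padding mask by one, and the tag frame must later be dropped from the logits before the CTC loss, while adding the tag element-wise leaves all shapes unchanged.

```python
import torch

def inject_lang_tag(x, padding_mask, lang_tok, prepend=False):
    # x: (T, B, C) encoder states; padding_mask: (B, T) bool, True = padded;
    # lang_tok: (B, C) language-tag embedding.
    if prepend:
        # Prepend the tag as an extra frame; the mask gains one
        # non-padded column at the front.
        x = torch.cat((lang_tok.unsqueeze(0), x), dim=0)
        pad_col = torch.zeros(padding_mask.size(0), 1, dtype=torch.bool,
                              device=x.device)
        padding_mask = torch.cat((pad_col, padding_mask), dim=1)
    else:
        # Add the tag to every frame; shapes are unchanged.
        x = lang_tok.unsqueeze(0) + x
    return x, padding_mask

x = torch.randn(4, 2, 8)                    # 4 frames, batch 2, 8 channels
mask = torch.zeros(2, 4, dtype=torch.bool)
tag = torch.randn(2, 8)
print(inject_lang_tag(x, mask, tag, prepend=False)[0].shape)  # (4, 2, 8)
print(inject_lang_tag(x, mask, tag, prepend=True)[0].shape)   # (5, 2, 8)
```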
@@ -552,6 +552,11 @@ class SpeechToTextTask(LegacyFairseqTask):
 
     def _inference(self, generator, sample, model, remove_bpe):
         def decode(toks, escape_unk=False):
+            symbols_to_strip_from_output = None
+            if hasattr(generator, "symbols_to_strip_from_output"):
+                symbols_to_strip_from_output = generator.symbols_to_strip_from_output
+            else:
+                # Wrap the EOS index in a set: extra_symbols_to_ignore
+                # expects a collection of indices, not a bare int.
+                symbols_to_strip_from_output = {generator.eos}
             s = self.tgt_dict.string(
                 toks.int().cpu(),
                 remove_bpe,
@@ -561,6 +566,7 @@ class SpeechToTextTask(LegacyFairseqTask):
                 # alternative that is unlikely to appear in the real
                 # reference, but doesn't get split into multiple tokens.
                 unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
+                extra_symbols_to_ignore=symbols_to_strip_from_output,
             )
             if self.tokenizer:
                 s = self.tokenizer.decode(s)
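The decode() change above boils down to one fallback: prefer the generator's own set of special output symbols (e.g. language tags) and otherwise strip only EOS. Below is a minimal standalone sketch, with `SimpleNamespace` objects standing in for real fairseq generators; the returned set is what gets passed to `tgt_dict.string(..., extra_symbols_to_ignore=...)`.

```python
from types import SimpleNamespace

def collect_symbols_to_strip(generator):
    # Mirrors the fallback added to decode(): multilingual generators expose
    # the full set of symbols to drop from output; older generators only
    # define an EOS index, which we wrap in a set.
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    return {generator.eos}

old_gen = SimpleNamespace(eos=2)                                # no symbol set
new_gen = SimpleNamespace(eos=2, symbols_to_strip_from_output={2, 5})
print(collect_symbols_to_strip(old_gen))   # {2}
print(collect_symbols_to_strip(new_gen))   # {2, 5}
```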