Commit 4f452308 by xuchen

prepend target language tag (to enc)

parent ab42136f
......@@ -95,6 +95,13 @@ class S2TDataConfig(object):
return self.config.get("prepend_tgt_lang_tag", False)
@property
def prepend_tgt_lang_tag_to_enc(self) -> bool:
    """Whether the target-language ID tag token should additionally be
    prepended for the encoder side (e.g. for a to-many multilingual
    setting), read from the ``prepend_tgt_lang_tag_to_enc`` config key.

    Returns:
        bool: the configured flag; defaults to False when the key is absent.

    NOTE(review): the original docstring was copy-pasted from
    ``prepend_tgt_lang_tag`` ("as the target BOS"); during inference this
    presumably still requires ``--prefix-size 1`` so BOS is forced to the
    lang-ID token — confirm against the generation code.
    """
    return self.config.get("prepend_tgt_lang_tag_to_enc", False)
@property
def input_feat_per_channel(self):
    """Dimensionality of the input features for each audio channel.

    Read from the ``input_feat_per_channel`` config key; falls back to 80
    (the usual log-mel filterbank size) when unset.
    """
    default_feat_dim = 80
    return self.config.get("input_feat_per_channel", default_feat_dim)
......@@ -317,6 +324,7 @@ class SpeechToTextDataset(FairseqDataset):
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples}, '
f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
f"prepend_tgt_lang_tag_to_enc={self.data_cfg.prepend_tgt_lang_tag_to_enc}, "
f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
)
......@@ -326,7 +334,7 @@ class SpeechToTextDataset(FairseqDataset):
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
......@@ -361,7 +369,7 @@ class SpeechToTextDataset(FairseqDataset):
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
......@@ -372,7 +380,7 @@ class SpeechToTextDataset(FairseqDataset):
aligned_target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
aligned_target = torch.cat((torch.LongTensor([lang_tag_idx]), aligned_target), 0)
......@@ -383,7 +391,7 @@ class SpeechToTextDataset(FairseqDataset):
ctc_target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
ctc_target = torch.cat((torch.LongTensor([lang_tag_idx]), ctc_target), 0)
......@@ -415,6 +423,7 @@ class SpeechToTextDataset(FairseqDataset):
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
tgt_lang_idx = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[t for _, _, t, _, _, _ in samples],
......@@ -424,6 +433,11 @@ class SpeechToTextDataset(FairseqDataset):
move_eos_to_beginning=False,
)
target = target.index_select(0, order)
if self.data_cfg.prepend_tgt_lang_tag_to_enc:
tgt_lang_idx = target[:, 0]
if not self.data_cfg.prepend_tgt_lang_tag:
target = target[:, 1:]
target_lengths = torch.tensor(
[t.size(0) for _, _, t, _, _, _ in samples], dtype=torch.long
).index_select(0, order)
......@@ -436,6 +450,9 @@ class SpeechToTextDataset(FairseqDataset):
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t, _, _, _ in samples)
if self.data_cfg.prepend_tgt_lang_tag_to_enc and not self.data_cfg.prepend_tgt_lang_tag:
prev_output_tokens = torch.cat((prev_output_tokens[:, 0:1], prev_output_tokens[:, 2:]), dim=1)
ntokens -= 1
if self.aligned_tgt_texts is not None:
aligned_target = fairseq_data_utils.collate_tokens(
......@@ -493,6 +510,7 @@ class SpeechToTextDataset(FairseqDataset):
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
"tgt_lang_idx": tgt_lang_idx,
},
"transcript": {
"tokens": transcript,
......
......@@ -97,6 +97,13 @@ class S2TDataConfig(object):
return self.config.get("prepend_tgt_lang_tag", False)
@property
def prepend_tgt_lang_tag_to_enc(self) -> bool:
    """Whether the target-language ID tag token should additionally be
    prepended for the encoder side (e.g. for a to-many multilingual
    setting), read from the ``prepend_tgt_lang_tag_to_enc`` config key.

    Returns:
        bool: the configured flag; defaults to False when the key is absent.

    NOTE(review): the original docstring was copy-pasted from
    ``prepend_tgt_lang_tag`` ("as the target BOS"); during inference this
    presumably still requires ``--prefix-size 1`` so BOS is forced to the
    lang-ID token — confirm against the generation code.
    """
    return self.config.get("prepend_tgt_lang_tag_to_enc", False)
@property
def input_feat_per_channel(self):
    """Per-channel input feature dimension (e.g. number of filterbank bins).

    Looked up under ``input_feat_per_channel``; 80 is used when the config
    does not specify a value.
    """
    feat_dim_fallback = 80
    return self.config.get("input_feat_per_channel", feat_dim_fallback)
......@@ -347,6 +354,7 @@ class SpeechToTextDataset(FairseqDataset):
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples}, '
f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
f"prepend_tgt_lang_tag_to_enc={self.data_cfg.prepend_tgt_lang_tag_to_enc}, "
f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
)
......@@ -356,12 +364,12 @@ class SpeechToTextDataset(FairseqDataset):
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
assert all(t in self.tgt_dict for t in tgt_lang_tags), tgt_lang_tags
def tokenize_text(self, text: str, is_src=False):
if self.pre_tokenizer is not None:
......@@ -391,7 +399,7 @@ class SpeechToTextDataset(FairseqDataset):
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
......@@ -424,6 +432,7 @@ class SpeechToTextDataset(FairseqDataset):
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
tgt_lang_idx = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[t for _, _, t, _ in samples],
......@@ -433,6 +442,12 @@ class SpeechToTextDataset(FairseqDataset):
move_eos_to_beginning=False,
)
target = target.index_select(0, order)
if self.data_cfg.prepend_tgt_lang_tag_to_enc:
tgt_lang_idx = target[:, 0]
if not self.data_cfg.prepend_tgt_lang_tag:
target = target[:, 1:]
target_lengths = torch.tensor(
[t.size(0) for _, _, t, _ in samples], dtype=torch.long
).index_select(0, order)
......@@ -445,6 +460,9 @@ class SpeechToTextDataset(FairseqDataset):
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t, _ in samples)
if self.data_cfg.prepend_tgt_lang_tag_to_enc and not self.data_cfg.prepend_tgt_lang_tag:
prev_output_tokens = torch.cat((prev_output_tokens[:, 0:1], prev_output_tokens[:, 2:]), dim=1)
ntokens -= 1
if self.src_dict is not None and self.src_texts is not None:
transcript = fairseq_data_utils.collate_tokens(
......@@ -470,6 +488,7 @@ class SpeechToTextDataset(FairseqDataset):
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
"tgt_lang_idx": tgt_lang_idx,
},
"transcript": {
"tokens": transcript,
......@@ -611,6 +630,7 @@ class SpeechToTextDatasetCreator(object):
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
logger.info("Start loading dataset {}.".format(tsv_path))
with open(tsv_path) as f:
reader = csv.DictReader(
f,
......
......@@ -487,8 +487,8 @@ class SpeechToTextTask(LegacyFairseqTask):
bleu = comp_bleu(
correct=meters["_bleu_counts"].sum,
total=meters["_bleu_totals"].sum,
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
sys_len=int(meters["_bleu_sys_len"].sum),
ref_len=int(meters["_bleu_ref_len"].sum),
**smooth
)
return round(bleu.score, 2)
......@@ -515,6 +515,7 @@ class SpeechToTextTask(LegacyFairseqTask):
'Please set "--prefix-size 1" since '
"target language ID token is prepended as BOS."
)
self.prefix_size = getattr(args, "prefix_size", 0)
lang_token_ids = {
i
for s, i in self.tgt_dict.indices.items()
......@@ -567,7 +568,12 @@ class SpeechToTextTask(LegacyFairseqTask):
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
if self.data_cfg.prepend_tgt_lang_tag and self.prefix_size > 0:
prefix_tokens = sample["target"][:, : self.prefix_size]
else:
prefix_tokens = None
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=prefix_tokens)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论