Commit aed36ae4 by xuchen

optimize the implementation of the lang tag

parent a64cdfcc
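
At a high level, the commit threads optional per-sentence language indices from the dataset's collater, through the criteria, into the encoder, which prepends the corresponding embeddings to its input. A minimal sketch of that flow, assuming the field names used in the hunks below (the tensors and index values are invented for illustration):

    # Illustrative only: how the per-sentence language indices travel from the
    # collated batch into the encoder; shapes and index values are invented.
    import torch

    net_input = {
        "src_tokens": torch.randn(2, 50, 80),      # (B, T, feat) audio features
        "src_lengths": torch.tensor([50, 43]),
        "prev_output_tokens": torch.tensor([[2, 7, 8], [2, 9, 10]]),
        "src_lang_idx": torch.tensor([5, 5]),      # e.g. index of "<lang:de>" in src_dict
        "tgt_lang_idx": torch.tensor([6, 6]),      # e.g. index of "<lang:en>" in tgt_dict
    }

    # the criteria read the optional indices with .get(..., None) so that datasets
    # without language tags keep working, and pass them through to the encoder:
    src_lang_idx = net_input.get("src_lang_idx", None)
    tgt_lang_idx = net_input.get("tgt_lang_idx", None)
    # encoder_out = model.encoder(net_input["src_tokens"], net_input["src_lengths"],
    #                             src_lang_idx=src_lang_idx, tgt_lang_idx=tgt_lang_idx)
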
......@@ -367,9 +367,11 @@ class CtcCriterion(FairseqCriterion):
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
src_lang_idx = sample["net_input"].get("src_lang_idx", None)
tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
with torch.no_grad():
encoder_out = model.encoder(src_tokens, src_lengths,
src_lang_idx=src_lang_idx,
tgt_lang_idx=tgt_lang_idx)
ctc_logit = None
......
......@@ -82,6 +82,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
prev_output_tokens = sample["net_input"]["prev_output_tokens"]
src_lang_idx = sample["net_input"].get("src_lang_idx", None)
tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
train_enc_only = False
......@@ -105,10 +106,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
encoder_out = model.encoder(src_tokens, src_lengths,
ctc_alignment_oracle=ctc_alignment_oracle,
src_lang_idx=src_lang_idx,
tgt_lang_idx=tgt_lang_idx)
else:
encoder_out = model.encoder(src_tokens=src_tokens,
src_lengths=src_lengths,
src_lang_idx=src_lang_idx,
tgt_lang_idx=tgt_lang_idx)
net_output = model.decoder(
......
......@@ -104,6 +104,13 @@ class S2TDataConfig(object):
return self.config.get("prepend_tgt_lang_tag_to_enc", False)
@property
def prepend_src_lang_tag(self) -> bool:
"""Prepend source lang ID token as the source BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_src_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get("input_feat_per_channel", 80)
......@@ -312,7 +319,8 @@ class SpeechToTextDataset(FairseqDataset):
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
src_bpe_tokenizer=None
src_bpe_tokenizer=None,
kwargs=None,
):
self.split, self.is_train_split = split, is_train_split
self.data_cfg = data_cfg
......@@ -321,6 +329,7 @@ class SpeechToTextDataset(FairseqDataset):
self.n_samples = len(audio_paths)
if data_cfg.share_src_and_tgt:
src_texts = tgt_texts
assert len(n_frames) == self.n_samples > 0
assert src_texts is None or len(src_texts) == self.n_samples
assert tgt_texts is None or len(tgt_texts) == self.n_samples
......@@ -347,12 +356,22 @@ class SpeechToTextDataset(FairseqDataset):
self.bpe_tokenizer = bpe_tokenizer
self.src_bpe_tokenizer = src_bpe_tokenizer
if "aligned_tgt_texts" in kwargs:
aligned_tgt_texts = kwargs["aligned_tgt_texts"]
assert aligned_tgt_texts is None or len(aligned_tgt_texts) == self.n_samples
self.aligned_tgt_texts = aligned_tgt_texts
if "ctc_tgt_texts" in kwargs:
ctc_tgt_texts = kwargs["ctc_tgt_texts"]
assert ctc_tgt_texts is None or len(ctc_tgt_texts) == self.n_samples
self.ctc_tgt_texts = ctc_tgt_texts
logger.info(self.__repr__())
def __repr__(self):
return (
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples}, '
f"prepend_src_lang_tag={self.data_cfg.prepend_src_lang_tag}, "
f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
f"prepend_tgt_lang_tag_to_enc={self.data_cfg.prepend_tgt_lang_tag_to_enc}, "
f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
......@@ -369,7 +388,14 @@ class SpeechToTextDataset(FairseqDataset):
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags), tgt_lang_tags
assert all(t in self.tgt_dict for t in tgt_lang_tags)
if self.data_cfg.prepend_src_lang_tag:
assert self.src_langs is not None and self.src_dict is not None
src_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.src_langs)
]
assert all(t in self.src_dict for t in src_lang_tags)
def tokenize_text(self, text: str, is_src=False):
if self.pre_tokenizer is not None:
......@@ -380,7 +406,10 @@ class SpeechToTextDataset(FairseqDataset):
def __getitem__(
self, index: int
) -> Tuple[int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> Dict:
sample = dict()
sample["indice"] = index
source = get_features_or_waveform(
self.audio_paths[index],
need_waveform=self.data_cfg.use_audio_input or (self.is_train_split and self.speed_perturb)
......@@ -392,6 +421,17 @@ class SpeechToTextDataset(FairseqDataset):
assert not self.data_cfg.use_audio_input
source = self.feature_transforms(source)
source = torch.from_numpy(source).float()
sample["source"] = source
if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
tgt_lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
tgt_lang_tag_idx = self.tgt_dict.index(tgt_lang_tag)
sample["tgt_lang_tag_idx"] = tgt_lang_tag_idx
if self.data_cfg.prepend_src_lang_tag:
src_lang_tag = self.LANG_TAG_TEMPLATE.format(self.src_langs[index])
src_lang_tag_idx = self.src_dict.index(src_lang_tag)
sample["src_lang_tag_idx"] = src_lang_tag_idx
target = None
if self.tgt_texts is not None:
......@@ -403,28 +443,54 @@ class SpeechToTextDataset(FairseqDataset):
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
sample["target"] = target
transcript = None
aligned_target = None
if hasattr(self, "aligned_tgt_texts"):
tokenized = self.tokenize_text(self.aligned_tgt_texts[index])
aligned_target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
aligned_target = torch.cat((torch.LongTensor([tgt_lang_tag_idx]), aligned_target), 0)
sample["aligned_target"] = aligned_target
ctc_target = None
if hasattr(self, "ctc_tgt_texts"):
tokenized = self.tokenize_text(self.ctc_tgt_texts[index])
ctc_target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
ctc_target = torch.cat((torch.LongTensor([tgt_lang_tag_idx]), ctc_target), 0)
sample["ctc_target"] = ctc_target
transcript = None
if self.src_dict is not None and self.src_texts is not None and self.src_bpe_tokenizer is not None:
tokenized = self.tokenize_text(self.src_texts[index], True)
transcript = self.src_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
return index, source, target, transcript
if self.data_cfg.prepend_src_lang_tag:
transcript = torch.cat((torch.LongTensor([src_lang_tag_idx]), transcript), 0)
sample["transcript"] = transcript
return sample
def __len__(self):
return self.n_samples
def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]]) -> Dict:
def collater(self, samples: List[Dict]) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([i for i, _, _, _ in samples], dtype=torch.long)
indices = torch.tensor([sample["indice"] for sample in samples], dtype=torch.long)
frames = _collate_frames(
[s for _, s, _, _ in samples], self.data_cfg.use_audio_input
[sample["source"] for sample in samples], self.data_cfg.use_audio_input
)
# sort samples by descending number of frames
n_frames = torch.tensor([s.size(0) for _, s, _, _ in samples], dtype=torch.long)
n_frames = torch.tensor([sample["source"].size(0) for sample in samples], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
......@@ -432,10 +498,15 @@ class SpeechToTextDataset(FairseqDataset):
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
transcript = None
transcript_lengths = None
transcript_ntokens = None
src_lang_idx = None
tgt_lang_idx = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[t for _, _, t, _ in samples],
[sample["target"] for sample in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
......@@ -443,30 +514,60 @@ class SpeechToTextDataset(FairseqDataset):
)
target = target.index_select(0, order)
if self.data_cfg.prepend_tgt_lang_tag_to_enc:
tgt_lang_idx = target[:, 0]
if not self.data_cfg.prepend_tgt_lang_tag:
target = target[:, 1:]
target_lengths = torch.tensor(
[t.size(0) for _, _, t, _ in samples], dtype=torch.long
[sample["target"].size(0) for sample in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[t for _, _, t, _ in samples],
[sample["target"] for sample in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t, _ in samples)
if self.data_cfg.prepend_tgt_lang_tag_to_enc and not self.data_cfg.prepend_tgt_lang_tag:
prev_output_tokens = torch.cat((prev_output_tokens[:, 0], prev_output_tokens[:, 2:]), dim=1)
ntokens -= 1
ntokens = sum(sample["target"].size(0) for sample in samples)
if "tgt_lang_tag_idx" in samples[0]:
tgt_lang_idx = torch.tensor([sample["tgt_lang_tag_idx"] for sample in samples], dtype=torch.long)
tgt_lang_idx = tgt_lang_idx[order]
if "src_lang_tag_idx" in samples[0]:
src_lang_idx = torch.tensor([sample["src_lang_tag_idx"] for sample in samples], dtype=torch.long)
src_lang_idx = src_lang_idx[order]
aligned_target = None
aligned_target_lengths = None
if hasattr(self, "aligned_tgt_texts"):
aligned_target = fairseq_data_utils.collate_tokens(
[sample["aligned_target"] for sample in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
aligned_target = aligned_target.index_select(0, order)
aligned_target_lengths = torch.tensor(
[sample["aligned_target"].size(0) for sample in samples], dtype=torch.long
).index_select(0, order)
ctc_target = None
ctc_target_lengths = None
if hasattr(self, "ctc_tgt_texts"):
ctc_target = fairseq_data_utils.collate_tokens(
[sample["ctc_target"] for sample in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
ctc_target = ctc_target.index_select(0, order)
ctc_target_lengths = torch.tensor(
[sample["ctc_target"].size(0) for sample in samples], dtype=torch.long
).index_select(0, order)
if self.src_dict is not None and self.src_texts is not None:
transcript_list = [sample["transcript"] for sample in samples]
transcript = fairseq_data_utils.collate_tokens(
[t for _, _, _, t in samples],
transcript_list,
self.src_dict.pad(),
self.src_dict.eos(),
left_pad=False,
......@@ -474,13 +575,9 @@ class SpeechToTextDataset(FairseqDataset):
)
transcript = transcript.index_select(0, order)
transcript_lengths = torch.tensor(
[t.size(0) for _, _, _, t in samples], dtype=torch.long
[item.size(0) for item in transcript_list], dtype=torch.long
).index_select(0, order)
transcript_ntokens = sum(t.size(0) for _, _, _, t in samples)
else:
transcript = None
transcript_lengths = None
transcript_ntokens = None
transcript_ntokens = sum(item.size(0) for item in transcript_list)
out = {
"id": indices,
......@@ -488,6 +585,7 @@ class SpeechToTextDataset(FairseqDataset):
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
"src_lang_idx": src_lang_idx,
"tgt_lang_idx": tgt_lang_idx,
},
"transcript": {
......@@ -500,6 +598,16 @@ class SpeechToTextDataset(FairseqDataset):
"ntokens": ntokens,
"nsentences": len(samples),
}
if aligned_target is not None:
out["aligned_target"] = {
"tokens": aligned_target,
"lengths": aligned_target_lengths,
}
if ctc_target is not None:
out["ctc_target"] = {
"tokens": ctc_target,
"lengths": ctc_target_lengths,
}
return out
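
For reference, the collated batch built above is a nested dict. A sketch of its layout for a batch of B sentences, with the two optional blocks present (keys follow the code; regions elided by the diff are marked as such):

    # Sketch of the batch returned by collater (sizes illustrative):
    # out = {
    #     "id": LongTensor of B sample indices (sorted by descending n_frames),
    #     "net_input": {
    #         "src_tokens": padded features, FloatTensor (B, T, feat) or raw audio,
    #         "src_lengths": LongTensor (B,),
    #         "prev_output_tokens": LongTensor (B, U), eos moved to the front,
    #         "src_lang_idx": LongTensor (B,) or None,   # per-sentence <lang:xx> index
    #         "tgt_lang_idx": LongTensor (B,) or None,
    #     },
    #     "transcript": {...},                           # elided in this diff
    #     ...                                            # elided in this diff
    #     "ntokens": ntokens,
    #     "nsentences": B,
    #     # only when the optional TSV columns are present:
    #     "aligned_target": {"tokens": LongTensor (B, U'), "lengths": LongTensor (B,)},
    #     "ctc_target": {"tokens": LongTensor (B, U''), "lengths": LongTensor (B,)},
    # }
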
......@@ -538,6 +646,8 @@ class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
KEY_ALIGNED_TGT_TEXT = "aligned_tgt_text"
KEY_CTC_TGT_TEXT = "ctc_tgt_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
......@@ -558,6 +668,7 @@ class SpeechToTextDatasetCreator(object):
src_bpe_tokenizer=None
) -> SpeechToTextDataset:
audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
aligned_tgt_texts, ctc_tgt_texts = [], []
speakers, src_langs, tgt_langs = [], [], []
for s in samples:
ids.extend([ss[cls.KEY_ID] for ss in s])
......@@ -572,6 +683,18 @@ class SpeechToTextDatasetCreator(object):
speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
kwargs = dict()
if len(s) > 0 and cls.KEY_ALIGNED_TGT_TEXT not in s[0]:
aligned_tgt_texts = None
else:
aligned_tgt_texts.extend([ss[cls.KEY_ALIGNED_TGT_TEXT] for ss in s])
kwargs["aligned_tgt_texts"] = aligned_tgt_texts
if len(s) > 0 and cls.KEY_CTC_TGT_TEXT not in s[0]:
ctc_tgt_texts = None
else:
ctc_tgt_texts.extend([ss[cls.KEY_CTC_TGT_TEXT] for ss in s])
kwargs["ctc_tgt_texts"] = ctc_tgt_texts
return SpeechToTextDataset(
split_name,
is_train_split,
......@@ -588,7 +711,8 @@ class SpeechToTextDatasetCreator(object):
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
src_bpe_tokenizer
src_bpe_tokenizer,
kwargs
)
@classmethod
......
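
The creator picks the two new targets up from optional manifest columns, and sets the corresponding kwargs entry to None whenever a shard lacks the column. A hedged example of one parsed manifest row, where the column names follow the KEY_* constants above and the values are invented:

    # Hypothetical manifest (TSV) row as parsed into a dict by the dataset creator;
    # all values are invented for illustration.
    example_row = {
        "id": "ted_0001_0",
        "audio": "fbank80.zip:1024:5120",
        "n_frames": "512",
        "tgt_text": "▁hello ▁world",
        "src_text": "▁hallo ▁welt",
        "speaker": "spk_01",
        "src_lang": "de",
        "tgt_lang": "en",
        # optional columns introduced by this commit:
        "aligned_tgt_text": "▁hello ▁world",
        "ctc_tgt_text": "▁hello ▁world",
    }
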
......@@ -1364,6 +1364,8 @@ class S2TTransformerEncoder(FairseqEncoder):
self.compression_stat = False
self.log_flag_dict = dict()
# gather cosine similarity
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.gather_cos_sim_dis = 2
......@@ -1775,27 +1777,15 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None:
self.history.clean()
src_lang_idx = kwargs.get("src_lang_idx", None)
tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
has_add_lang_tag = False
# (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1)
input_lengths = src_lengths
org_bsz = x.size(1)
if (
self.mixup
and layer_idx == mixup_layer
):
if tgt_lang_idx is not None:
assert self.embed_tokens is not None
tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
if mixup is not None:
pass
x = torch.cat((tgt_lang_embed, x), 0)
input_lengths += 1
has_add_lang_tag = True
if (
(self.training or self.mixup_infer)
and self.mixup
and layer_idx == mixup_layer
......@@ -1815,15 +1805,26 @@ class S2TTransformerEncoder(FairseqEncoder):
x, input_lengths = self.subsample(x, input_lengths)
self.show_debug(x, "x after subsampling")
#if tgt_lang_idx is not None and False:
if tgt_lang_idx is not None and not has_add_lang_tag:
if src_lang_idx is not None:
assert self.embed_tokens is not None
src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
x = torch.cat((src_lang_embed, x), 0)
input_lengths += 1
if "prepend_src_lang" not in self.log_flag_dict:
self.log_flag_dict["prepend_src_lang"] = True
logger.info("Prepend the source language tag into the encoder input.")
if tgt_lang_idx is not None:
assert self.embed_tokens is not None
tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
if mixup is not None:
    # no extra handling when mixup is active: the tag embedding is prepended as-is
    pass
x = torch.cat((tgt_lang_embed, x), 0)
input_lengths += 1
if "prepend_tgt_lang" not in self.log_flag_dict:
self.log_flag_dict["prepend_tgt_lang"] = True
logger.info("Prepend the target language tag into the encoder input.")
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if encoder_padding_mask.size(1) < x.size(0):
bsz = encoder_padding_mask.size(0)
......@@ -2248,12 +2249,12 @@ class S2TTransformerEncoder(FairseqEncoder):
)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder output", is_top=True)
ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder CTC output", is_top=True)
self.show_debug(x, "x after ctc")
if self.use_xctc and xctc_logit is None:
xctc_logit = self.xctc(
x, encoder_padding_mask, "Encoder output", is_top=True
x, encoder_padding_mask, "Encoder XCTC output", is_top=True
)
self.show_debug(x, "x after xctc")
......
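
After subsampling, the encoder tensors are laid out as (T, B, D), so prepending a language tag amounts to a one-step concat along the time axis plus a length increment before the padding mask is rebuilt. A self-contained sketch of that step, assuming a stand-in embedding table rather than the model's real modules:

    import torch
    import torch.nn as nn

    def lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
        # True at padded positions, mirroring fairseq.data.data_utils.lengths_to_padding_mask
        max_len = int(lengths.max())
        return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

    B, T, D, V = 2, 6, 8, 100
    embed_tokens = nn.Embedding(V, D)              # stand-in for the shared embeddings
    x = torch.randn(T, B, D)                       # subsampled encoder input, (T, B, D)
    input_lengths = torch.tensor([6, 4])
    tgt_lang_idx = torch.tensor([7, 9])            # one <lang:xx> index per sentence

    tag = embed_tokens(tgt_lang_idx).unsqueeze(0)  # (1, B, D)
    x = torch.cat((tag, x), dim=0)                 # tag becomes the first time step
    input_lengths = input_lengths + 1
    encoder_padding_mask = lengths_to_padding_mask(input_lengths)
    assert x.size(0) == encoder_padding_mask.size(1)
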