Commit aed36ae4 by xuchen

optimize the implementation of the lang tag

parent a64cdfcc
@@ -367,9 +367,11 @@ class CtcCriterion(FairseqCriterion):
         src_tokens = sample["net_input"]["src_tokens"]
         src_lengths = sample["net_input"]["src_lengths"]
+        src_lang_idx = sample["net_input"].get("src_lang_idx", None)
         tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)

         with torch.no_grad():
             encoder_out = model.encoder(src_tokens, src_lengths,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)

         ctc_logit = None
......
@@ -82,6 +82,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         src_tokens = sample["net_input"]["src_tokens"]
         src_lengths = sample["net_input"]["src_lengths"]
         prev_output_tokens = sample["net_input"]["prev_output_tokens"]
+        src_lang_idx = sample["net_input"].get("src_lang_idx", None)
         tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)

         train_enc_only = False
@@ -105,10 +106,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
             encoder_out = model.encoder(src_tokens, src_lengths,
                                         ctc_alignment_oracle=ctc_alignment_oracle,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)
         else:
             encoder_out = model.encoder(src_tokens=src_tokens,
                                         src_lengths=src_lengths,
+                                        src_lang_idx=src_lang_idx,
                                         tgt_lang_idx=tgt_lang_idx)

         net_output = model.decoder(
......
@@ -104,6 +104,13 @@ class S2TDataConfig(object):
         return self.config.get("prepend_tgt_lang_tag_to_enc", False)

     @property
+    def prepend_src_lang_tag(self) -> bool:
+        """Prepend the source lang ID token as the source BOS
+        (e.g. for a many-to-one multilingual setting)."""
+        return self.config.get("prepend_src_lang_tag", False)
+
+    @property
     def input_feat_per_channel(self):
         """The dimension of input features (per audio channel)"""
         return self.config.get("input_feat_per_channel", 80)
@@ -312,7 +319,8 @@ class SpeechToTextDataset(FairseqDataset):
         tgt_dict: Optional[Dictionary] = None,
         pre_tokenizer=None,
         bpe_tokenizer=None,
-        src_bpe_tokenizer=None
+        src_bpe_tokenizer=None,
+        kwargs=None,
     ):
         self.split, self.is_train_split = split, is_train_split
         self.data_cfg = data_cfg
@@ -321,6 +329,7 @@ class SpeechToTextDataset(FairseqDataset):
         self.n_samples = len(audio_paths)
         if data_cfg.share_src_and_tgt:
             src_texts = tgt_texts
         assert len(n_frames) == self.n_samples > 0
         assert src_texts is None or len(src_texts) == self.n_samples
         assert tgt_texts is None or len(tgt_texts) == self.n_samples
@@ -347,12 +356,22 @@ class SpeechToTextDataset(FairseqDataset):
         self.bpe_tokenizer = bpe_tokenizer
         self.src_bpe_tokenizer = src_bpe_tokenizer

+        if "aligned_tgt_texts" in kwargs:
+            aligned_tgt_texts = kwargs["aligned_tgt_texts"]
+            assert aligned_tgt_texts is None or len(aligned_tgt_texts) == self.n_samples
+            self.aligned_tgt_texts = aligned_tgt_texts
+        if "ctc_tgt_texts" in kwargs:
+            ctc_tgt_texts = kwargs["ctc_tgt_texts"]
+            assert ctc_tgt_texts is None or len(ctc_tgt_texts) == self.n_samples
+            self.ctc_tgt_texts = ctc_tgt_texts
+
         logger.info(self.__repr__())

     def __repr__(self):
         return (
             self.__class__.__name__
             + f'(split="{self.split}", n_samples={self.n_samples}, '
+            f"prepend_src_lang_tag={self.data_cfg.prepend_src_lang_tag}, "
             f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
             f"prepend_tgt_lang_tag_to_enc={self.data_cfg.prepend_tgt_lang_tag_to_enc}, "
             f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
@@ -369,7 +388,14 @@ class SpeechToTextDataset(FairseqDataset):
             tgt_lang_tags = [
                 self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
             ]
-            assert all(t in self.tgt_dict for t in tgt_lang_tags), tgt_lang_tags
+            assert all(t in self.tgt_dict for t in tgt_lang_tags)
+        if self.data_cfg.prepend_src_lang_tag:
+            assert self.src_langs is not None and self.src_dict is not None
+            src_lang_tags = [
+                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.src_langs)
+            ]
+            assert all(t in self.src_dict for t in src_lang_tags)

     def tokenize_text(self, text: str, is_src=False):
         if self.pre_tokenizer is not None:
@@ -380,7 +406,10 @@ class SpeechToTextDataset(FairseqDataset):
     def __getitem__(
         self, index: int
-    ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Dict:
+        sample = dict()
+        sample["indice"] = index
+
         source = get_features_or_waveform(
             self.audio_paths[index],
             need_waveform=self.data_cfg.use_audio_input or (self.is_train_split and self.speed_perturb)
@@ -392,6 +421,17 @@ class SpeechToTextDataset(FairseqDataset):
             assert not self.data_cfg.use_audio_input
             source = self.feature_transforms(source)
         source = torch.from_numpy(source).float()
+        sample["source"] = source
+
+        if self.data_cfg.prepend_tgt_lang_tag or self.data_cfg.prepend_tgt_lang_tag_to_enc:
+            tgt_lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
+            tgt_lang_tag_idx = self.tgt_dict.index(tgt_lang_tag)
+            sample["tgt_lang_tag_idx"] = tgt_lang_tag_idx
+
+        if self.data_cfg.prepend_src_lang_tag:
+            src_lang_tag = self.LANG_TAG_TEMPLATE.format(self.src_langs[index])
+            src_lang_tag_idx = self.src_dict.index(src_lang_tag)
+            sample["src_lang_tag_idx"] = src_lang_tag_idx

         target = None
         if self.tgt_texts is not None:
@@ -403,28 +443,54 @@ class SpeechToTextDataset(FairseqDataset):
                 lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
                 lang_tag_idx = self.tgt_dict.index(lang_tag)
                 target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
+        sample["target"] = target
+
+        aligned_target = None
+        if hasattr(self, "aligned_tgt_texts"):
+            tokenized = self.tokenize_text(self.aligned_tgt_texts[index])
+            aligned_target = self.tgt_dict.encode_line(
+                tokenized, add_if_not_exist=False, append_eos=True
+            ).long()
+            if self.data_cfg.prepend_tgt_lang_tag:
+                aligned_target = torch.cat((torch.LongTensor([tgt_lang_tag_idx]), aligned_target), 0)
+        sample["aligned_target"] = aligned_target
+
+        ctc_target = None
+        if hasattr(self, "ctc_tgt_texts"):
+            tokenized = self.tokenize_text(self.ctc_tgt_texts[index])
+            ctc_target = self.tgt_dict.encode_line(
+                tokenized, add_if_not_exist=False, append_eos=True
+            ).long()
+            if self.data_cfg.prepend_tgt_lang_tag:
+                ctc_target = torch.cat((torch.LongTensor([tgt_lang_tag_idx]), ctc_target), 0)
+        sample["ctc_target"] = ctc_target

         transcript = None
         if self.src_dict is not None and self.src_texts is not None and self.src_bpe_tokenizer is not None:
             tokenized = self.tokenize_text(self.src_texts[index], True)
             transcript = self.src_dict.encode_line(
                 tokenized, add_if_not_exist=False, append_eos=True
             ).long()
-        return index, source, target, transcript
+            if self.data_cfg.prepend_src_lang_tag:
+                transcript = torch.cat((torch.LongTensor([src_lang_tag_idx]), transcript), 0)
+        sample["transcript"] = transcript
+
+        return sample
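With source-side and target-side tags enabled, a single item is now a dict rather than the old 4-tuple. A rough sketch of its shape, assuming the usual <lang:xx> tag template; the key names follow the code above, the values are invented:

# Illustrative item produced by __getitem__ when prepend_src_lang_tag and
# prepend_tgt_lang_tag(_to_enc) are enabled; not part of the diff.
item = {
    "indice": 42,              # dataset index
    "source": ...,             # FloatTensor (n_frames, feat_dim) audio features
    "tgt_lang_tag_idx": 6,     # e.g. tgt_dict.index("<lang:de>"), index value made up
    "src_lang_tag_idx": 5,     # e.g. src_dict.index("<lang:en>"), index value made up
    "target": ...,             # LongTensor, starts with the tgt lang tag if prepend_tgt_lang_tag
    "transcript": ...,         # LongTensor, starts with the src lang tag if prepend_src_lang_tag
    # "aligned_target" / "ctc_target" carry tensors only when the extra TSV columns exist
}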
     def __len__(self):
         return self.n_samples

-    def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]]) -> Dict:
+    def collater(self, samples) -> Dict:
         if len(samples) == 0:
             return {}
-        indices = torch.tensor([i for i, _, _, _ in samples], dtype=torch.long)
+        indices = torch.tensor([sample["indice"] for sample in samples], dtype=torch.long)
         frames = _collate_frames(
-            [s for _, s, _, _ in samples], self.data_cfg.use_audio_input
+            [sample["source"] for sample in samples], self.data_cfg.use_audio_input
         )
         # sort samples by descending number of frames
-        n_frames = torch.tensor([s.size(0) for _, s, _, _ in samples], dtype=torch.long)
+        n_frames = torch.tensor([sample["source"].size(0) for sample in samples], dtype=torch.long)
         n_frames, order = n_frames.sort(descending=True)
         indices = indices.index_select(0, order)
         frames = frames.index_select(0, order)
@@ -432,10 +498,15 @@ class SpeechToTextDataset(FairseqDataset):
         target, target_lengths = None, None
         prev_output_tokens = None
         ntokens = None
+        transcript = None
+        transcript_lengths = None
+        transcript_ntokens = None
+        src_lang_idx = None
         tgt_lang_idx = None
         if self.tgt_texts is not None:
             target = fairseq_data_utils.collate_tokens(
-                [t for _, _, t, _ in samples],
+                [sample["target"] for sample in samples],
                 self.tgt_dict.pad(),
                 self.tgt_dict.eos(),
                 left_pad=False,
@@ -443,30 +514,60 @@ class SpeechToTextDataset(FairseqDataset):
             )
             target = target.index_select(0, order)
-            if self.data_cfg.prepend_tgt_lang_tag_to_enc:
-                tgt_lang_idx = target[:, 0]
-                if not self.data_cfg.prepend_tgt_lang_tag:
-                    target = target[:, 1:]
             target_lengths = torch.tensor(
-                [t.size(0) for _, _, t, _ in samples], dtype=torch.long
+                [sample["target"].size(0) for sample in samples], dtype=torch.long
             ).index_select(0, order)
             prev_output_tokens = fairseq_data_utils.collate_tokens(
-                [t for _, _, t, _ in samples],
+                [sample["target"] for sample in samples],
                 self.tgt_dict.pad(),
                 self.tgt_dict.eos(),
                 left_pad=False,
                 move_eos_to_beginning=True,
             )
             prev_output_tokens = prev_output_tokens.index_select(0, order)
-            ntokens = sum(t.size(0) for _, _, t, _ in samples)
-            if self.data_cfg.prepend_tgt_lang_tag_to_enc and not self.data_cfg.prepend_tgt_lang_tag:
-                prev_output_tokens = torch.cat((prev_output_tokens[:, 0], prev_output_tokens[:, 2:]), dim=1)
-                ntokens -= 1
+            ntokens = sum(sample["target"].size(0) for sample in samples)
+
+        if "tgt_lang_tag_idx" in samples[0]:
+            tgt_lang_idx = torch.tensor([sample["tgt_lang_tag_idx"] for sample in samples], dtype=torch.long)
+            tgt_lang_idx = tgt_lang_idx[order]
+        if "src_lang_tag_idx" in samples[0]:
+            src_lang_idx = torch.tensor([sample["src_lang_tag_idx"] for sample in samples], dtype=torch.long)
+            src_lang_idx = src_lang_idx[order]
+
+        aligned_target = None
+        aligned_target_lengths = None
+        if hasattr(self, "aligned_tgt_texts"):
+            aligned_target = fairseq_data_utils.collate_tokens(
+                [sample["aligned_target"] for sample in samples],
+                self.tgt_dict.pad(),
+                self.tgt_dict.eos(),
+                left_pad=False,
+                move_eos_to_beginning=False,
+            )
+            aligned_target = aligned_target.index_select(0, order)
+            aligned_target_lengths = torch.tensor(
+                [sample["aligned_target"].size(0) for sample in samples], dtype=torch.long
+            ).index_select(0, order)
+
+        ctc_target = None
+        ctc_target_lengths = None
+        if hasattr(self, "ctc_tgt_texts"):
+            ctc_target = fairseq_data_utils.collate_tokens(
+                [sample["ctc_target"] for sample in samples],
+                self.tgt_dict.pad(),
+                self.tgt_dict.eos(),
+                left_pad=False,
+                move_eos_to_beginning=False,
+            )
+            ctc_target = ctc_target.index_select(0, order)
+            ctc_target_lengths = torch.tensor(
+                [sample["ctc_target"].size(0) for sample in samples], dtype=torch.long
+            ).index_select(0, order)

         if self.src_dict is not None and self.src_texts is not None:
+            transcript_list = [sample["transcript"] for sample in samples]
             transcript = fairseq_data_utils.collate_tokens(
-                [t for _, _, _, t in samples],
+                transcript_list,
                 self.src_dict.pad(),
                 self.src_dict.eos(),
                 left_pad=False,
@@ -474,13 +575,9 @@ class SpeechToTextDataset(FairseqDataset):
             )
             transcript = transcript.index_select(0, order)
             transcript_lengths = torch.tensor(
-                [t.size(0) for _, _, _, t in samples], dtype=torch.long
+                [item.size(0) for item in transcript_list], dtype=torch.long
             ).index_select(0, order)
-            transcript_ntokens = sum(t.size(0) for _, _, _, t in samples)
-        else:
-            transcript = None
-            transcript_lengths = None
-            transcript_ntokens = None
+            transcript_ntokens = sum(item.size(0) for item in transcript_list)

         out = {
             "id": indices,
@@ -488,6 +585,7 @@ class SpeechToTextDataset(FairseqDataset):
                 "src_tokens": frames,
                 "src_lengths": n_frames,
                 "prev_output_tokens": prev_output_tokens,
+                "src_lang_idx": src_lang_idx,
                 "tgt_lang_idx": tgt_lang_idx,
             },
             "transcript": {
@@ -500,6 +598,16 @@ class SpeechToTextDataset(FairseqDataset):
             "ntokens": ntokens,
             "nsentences": len(samples),
         }
+        if aligned_target is not None:
+            out["aligned_target"] = {
+                "tokens": aligned_target,
+                "lengths": aligned_target_lengths,
+            }
+        if ctc_target is not None:
+            out["ctc_target"] = {
+                "tokens": ctc_target,
+                "lengths": ctc_target_lengths,
+            }
         return out
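The collated batch therefore carries the language indices inside net_input and two optional sub-dicts at the top level. An illustrative sketch of the output structure (B is the batch size; only the keys visible in the code above are taken from the commit, shapes are indicative):

# Illustrative collater output; not part of the diff.
batch = {
    "id": ...,                        # LongTensor (B,)
    "net_input": {
        "src_tokens": ...,            # (B, T_max, feat_dim) padded audio features
        "src_lengths": ...,           # (B,)
        "prev_output_tokens": ...,    # (B, U_max)
        "src_lang_idx": ...,          # (B,) or None, from sample["src_lang_tag_idx"]
        "tgt_lang_idx": ...,          # (B,) or None, from sample["tgt_lang_tag_idx"]
    },
    "transcript": {...},              # collated source-text tokens / lengths / ntokens
    "ntokens": ..., "nsentences": ...,
    # only when the corresponding TSV columns exist:
    # "aligned_target": {"tokens": ..., "lengths": ...},
    # "ctc_target": {"tokens": ..., "lengths": ...},
}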
@@ -538,6 +646,8 @@ class SpeechToTextDatasetCreator(object):
     # mandatory columns
     KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
     KEY_TGT_TEXT = "tgt_text"
+    KEY_ALIGNED_TGT_TEXT = "aligned_tgt_text"
+    KEY_CTC_TGT_TEXT = "ctc_tgt_text"
     # optional columns
     KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
     KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
@@ -558,6 +668,7 @@ class SpeechToTextDatasetCreator(object):
         src_bpe_tokenizer=None
     ) -> SpeechToTextDataset:
         audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
+        aligned_tgt_texts, ctc_tgt_texts = [], []
         speakers, src_langs, tgt_langs = [], [], []
         for s in samples:
             ids.extend([ss[cls.KEY_ID] for ss in s])
@@ -572,6 +683,18 @@ class SpeechToTextDatasetCreator(object):
             speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
             src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
             tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
+
+            kwargs = dict()
+            if len(s) > 0 and cls.KEY_ALIGNED_TGT_TEXT not in s[0]:
+                aligned_tgt_texts = None
+            else:
+                aligned_tgt_texts.extend([ss[cls.KEY_ALIGNED_TGT_TEXT] for ss in s])
+                kwargs["aligned_tgt_texts"] = aligned_tgt_texts
+            if len(s) > 0 and cls.KEY_CTC_TGT_TEXT not in s[0]:
+                ctc_tgt_texts = None
+            else:
+                ctc_tgt_texts.extend([ss[cls.KEY_CTC_TGT_TEXT] for ss in s])
+                kwargs["ctc_tgt_texts"] = ctc_tgt_texts
+
         return SpeechToTextDataset(
             split_name,
             is_train_split,
@@ -588,7 +711,8 @@ class SpeechToTextDatasetCreator(object):
             tgt_dict,
             pre_tokenizer,
             bpe_tokenizer,
-            src_bpe_tokenizer
+            src_bpe_tokenizer,
+            kwargs
         )

     @classmethod
......
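For context, the creator still reads the standard S2T TSV manifest; the two new optional columns are picked up per row when present. A hypothetical parsed row (the `ss` dict iterated above): the column names follow the KEY_* constants, while the paths and text are invented for illustration.

# Hypothetical manifest row; audio path and text values are made up.
ss = {
    "id": "ted_1_0",
    "audio": "fbank80.zip:1024:5632",
    "n_frames": "5632",
    "tgt_text": "▁Hallo ▁Welt",
    "src_text": "▁hello ▁world",
    "src_lang": "en",
    "tgt_lang": "de",
    "aligned_tgt_text": "▁Hallo ▁Welt",   # new optional column
    "ctc_tgt_text": "▁Hallo ▁Welt",       # new optional column
}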
@@ -1364,6 +1364,8 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.compression_stat = False

+        self.log_flag_dict = dict()
+
         # gather cosine similarity
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
         self.gather_cos_sim_dis = 2
@@ -1775,27 +1777,15 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

+        src_lang_idx = kwargs.get("src_lang_idx", None)
         tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
-        has_add_lang_tag = False

         # (B, T, D) -> (T, B, D)
         x = src_tokens.transpose(0, 1)
         input_lengths = src_lengths

         org_bsz = x.size(1)
-        if (
-            self.mixup
-            and layer_idx == mixup_layer
-        ):
-            if tgt_lang_idx is not None:
-                assert self.embed_tokens is not None
-                tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
-                if mixup is not None:
-                    pass
-                x = torch.cat((tgt_lang_embed, x), 0)
-                input_lengths += 1
-                has_add_lang_tag = True
-
         if (
             (self.training or self.mixup_infer)
             and self.mixup
             and layer_idx == mixup_layer
@@ -1815,15 +1805,26 @@ class S2TTransformerEncoder(FairseqEncoder):
         x, input_lengths = self.subsample(x, input_lengths)
         self.show_debug(x, "x after subsampling")

-        #if tgt_lang_idx is not None and False:
-        if tgt_lang_idx is not None and not has_add_lang_tag:
+        if src_lang_idx is not None:
+            assert self.embed_tokens is not None
+            src_lang_embed = self.embed_tokens(src_lang_idx).unsqueeze(0)
+            x = torch.cat((src_lang_embed, x), 0)
+            input_lengths += 1
+            if "prepend_src_lang" not in self.log_flag_dict:
+                self.log_flag_dict["prepend_src_lang"] = True
+                logger.info("Prepend the source language tag into the encoder input.")
+
+        if tgt_lang_idx is not None:
             assert self.embed_tokens is not None
             tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
-            if mixup is not None:
-                pass
             x = torch.cat((tgt_lang_embed, x), 0)
             input_lengths += 1
+            if "prepend_tgt_lang" not in self.log_flag_dict:
+                self.log_flag_dict["prepend_tgt_lang"] = True
+                logger.info("Prepend the target language tag into the encoder input.")

         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
         if encoder_padding_mask.size(1) < x.size(0):
             bsz = encoder_padding_mask.size(0)
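The prepend step is easiest to read in terms of shapes: the language-tag embedding is concatenated along the time axis, so the sequence grows by one position and the lengths are incremented to keep the padding mask consistent. A standalone sketch of that step with invented dimensions (embed_tokens stands in for self.embed_tokens, which the encoder shares with the decoder embedding table):

import torch
import torch.nn as nn

# Standalone sketch of the lang-tag prepend; dimensions are made up.
T, B, D, vocab = 50, 4, 256, 1000
embed_tokens = nn.Embedding(vocab, D)

x = torch.randn(T, B, D)                                   # encoder input after subsampling, (T, B, D)
input_lengths = torch.full((B,), T, dtype=torch.long)
src_lang_idx = torch.randint(0, vocab, (B,))               # one lang-tag index per sentence

src_lang_embed = embed_tokens(src_lang_idx).unsqueeze(0)   # (1, B, D)
x = torch.cat((src_lang_embed, x), 0)                      # now (T + 1, B, D)
input_lengths = input_lengths + 1

assert x.size(0) == T + 1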
@@ -2248,12 +2249,12 @@ class S2TTransformerEncoder(FairseqEncoder):
         )

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder output", is_top=True)
+            ctc_logit = self.ctc(x, encoder_padding_mask, "Encoder CTC output", is_top=True)
             self.show_debug(x, "x after ctc")

         if self.use_xctc and xctc_logit is None:
             xctc_logit = self.xctc(
-                x, encoder_padding_mask, "Encoder output", is_top=True
+                x, encoder_padding_mask, "Encoder XCTC output", is_top=True
             )
             self.show_debug(x, "x after xctc")
......