Commit 5d84c743 by xuchen

Enable an additional label for CTC learning

parent e40eac14
...@@ -190,6 +190,7 @@ def get_features_or_waveform_from_audio( ...@@ -190,6 +190,7 @@ def get_features_or_waveform_from_audio(
else get_fbank(path, offset=offset, size=size) else get_fbank(path, offset=offset, size=size)
return features_or_waveform return features_or_waveform
def get_features_or_waveform(path: str, need_waveform=False): def get_features_or_waveform(path: str, need_waveform=False):
"""Get speech features from .npy file or waveform from .wav/.flac file. """Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte The file may be inside an uncompressed ZIP file and is accessed via byte
......
...@@ -50,12 +50,19 @@ class SpeechToTextTask(LegacyFairseqTask): ...@@ -50,12 +50,19 @@ class SpeechToTextTask(LegacyFairseqTask):
metavar="N", metavar="N",
help="max number of tokens in the target sequence", help="max number of tokens in the target sequence",
) )
parser.add_argument(
"--use-aligned-text",
default=False,
action="store_true",
help="use aligned text for loss",
)
def __init__(self, args, tgt_dict, src_dict=None): def __init__(self, args, tgt_dict, src_dict=None):
super().__init__(args) super().__init__(args)
self.src_dict = src_dict self.src_dict = src_dict
self.tgt_dict = tgt_dict self.tgt_dict = tgt_dict
self.speed_perturb = args.speed_perturb self.speed_perturb = args.speed_perturb
self.use_aligned_text = getattr(args, "use_aligned_text", False)
self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml)) self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
self.data_cfg.config["speed_perturb"] = self.speed_perturb self.data_cfg.config["speed_perturb"] = self.speed_perturb
...@@ -111,7 +118,11 @@ class SpeechToTextTask(LegacyFairseqTask): ...@@ -111,7 +118,11 @@ class SpeechToTextTask(LegacyFairseqTask):
# src_bpe_tokenizer = bpe_tokenizer # src_bpe_tokenizer = bpe_tokenizer
# else: # else:
# src_bpe_tokenizer = None # src_bpe_tokenizer = None
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( if self.use_aligned_text:
from fairseq.data.audio.aligned_speech_to_text_dataset import SpeechToTextDatasetCreator as Creator
else:
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDatasetCreator as Creator
self.datasets[split] = Creator.from_tsv(
self.args.data, self.args.data,
self.data_cfg, self.data_cfg,
split, split,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论