Commit 306dd6fc by xuchen

fix the bug of writing and reading the source txt

parent ad46064b
......@@ -110,7 +110,8 @@ def process(args):
# Generate config YAML
gen_config_yaml(
out_root, spm_filename_prefix + ".model", specaugment_policy="ld",
asr_spm_filename=spm_filename_prefix + ".model"
asr_spm_filename=spm_filename_prefix + ".model",
share_src_and_tgt=True
)
# Clean up
shutil.rmtree(feature_root)
......
......@@ -234,7 +234,8 @@ def process(args):
cur_root / "gcmvn.npz" if args.cmvn_type == "global"
else None
),
asr_spm_filename=asr_spm_filename
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if args.task == "asr" else False
)
# Clean up
shutil.rmtree(feature_root)
......
......@@ -85,7 +85,7 @@ class S2TDataConfig(object):
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_bpe_tokenizer", {"bpe": None})
return self.config.get("src_bpe_tokenizer", None)
@property
def prepend_tgt_lang_tag(self) -> bool:
......
......@@ -64,13 +64,17 @@ class SpeechToTextTask(LegacyFairseqTask):
)
src_dict = None
if getattr(data_cfg, "asr_vocab_filename", None):
dict_path = op.join(args.data, data_cfg.asr_vocab_filename)
if getattr(data_cfg, "share_src_and_tgt", False):
asr_vocab_filename = data_cfg.vocab_filename
else:
asr_vocab_filename = getattr(data_cfg, "asr_vocab_filename", None)
if asr_vocab_filename is not None:
dict_path = op.join(args.data, asr_vocab_filename)
if not op.isfile(dict_path):
raise FileNotFoundError(f"Dict not found: {dict_path}")
src_dict = Dictionary.load(dict_path)
logger.info(
f"asr dictionary size ({data_cfg.asr_vocab_filename}): " f"{len(src_dict):,}"
f"asr dictionary size ({asr_vocab_filename}): " f"{len(src_dict):,}"
)
if getattr(args, "train_subset", None) is not None:
......@@ -92,10 +96,14 @@ class SpeechToTextTask(LegacyFairseqTask):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
if self.data_cfg.bpe_tokenizer != self.data_cfg.src_bpe_tokenizer:
if self.data_cfg.src_bpe_tokenizer is not None:
src_bpe_tokenizer = self.build_src_bpe(self.args)
else:
src_bpe_tokenizer = bpe_tokenizer
# if self.data_cfg.share_src_and_tgt:
# src_bpe_tokenizer = bpe_tokenizer
# else:
# src_bpe_tokenizer = None
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
self.args.data,
self.data_cfg,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论