Commit 71eae3fd by xuchen

Modify the preprocessing of the S2T (speech-to-text) data

parent a8105353
@@ -10,6 +10,7 @@ from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
+import string
import pandas as pd
import torchaudio
@@ -54,7 +55,8 @@ class CoVoST(Dataset):
)
VERSIONS = {2}
-SPLITS = ["train", "dev", "test"]
+# SPLITS = ["train", "dev", "test"]
+SPLITS = ["train"]
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
@@ -130,7 +132,12 @@ class CoVoST(Dataset):
cv_tsv_path = self.root / "validated.tsv"
assert cv_tsv_path.is_file()
+cv_tsv = load_df_from_tsv(cv_tsv_path)
+if self.no_translation:
+print("No target translation.")
+df = cv_tsv[["path", "sentence", "client_id"]]
+else:
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
@@ -139,7 +146,6 @@ class CoVoST(Dataset):
download_url(covost_url, self.root.as_posix(), hash_value=None)
extract_archive(covost_archive.as_posix())
-cv_tsv = load_df_from_tsv(cv_tsv_path)
covost_tsv = load_df_from_tsv(
self.root / Path(covost_url).name.replace(".tar.gz", "")
)
@@ -153,20 +159,21 @@ class CoVoST(Dataset):
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
data = df.to_dict(orient="index").items()
data = [v for k, v in sorted(data, key=lambda x: x[0])]
self.data = []
for e in data:
try:
-path = self.root / "clips" / e["path"]
-_ = torchaudio.info(path.as_posix())
+# path = self.root / "clips" / e["path"]
+# _ = torchaudio.info(path.as_posix())
self.data.append(e)
except RuntimeError:
pass
def __getitem__(
self, n: int
-) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
+) -> Tuple[Path, int, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
@@ -178,12 +185,14 @@ class CoVoST(Dataset):
"""
data = self.data[n]
path = self.root / "clips" / data["path"]
-waveform, sample_rate = torchaudio.load(path)
+info = torchaudio.info(path)
+sample_rate = info.sample_rate
+n_frames = info.num_frames
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
-return waveform, sample_rate, sentence, translation, speaker_id, _id
+return path, sample_rate, n_frames, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
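Note on the __getitem__ change above: the dataset now returns the clip path together with audio metadata (sample rate and sample count from torchaudio.info) instead of a decoded waveform, so audio is only decoded where features are actually computed. A minimal consumer sketch under that assumption (variable names are illustrative):

    import torchaudio

    dataset = CoVoST(root, "train", "fr", "en")
    for wav_path, sample_rate, n_frames, src, tgt, speaker, utt_id in dataset:
        # Decode lazily, only at feature-extraction time.
        waveform, _ = torchaudio.load(wav_path)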
@@ -191,23 +200,35 @@ class CoVoST(Dataset):
def process(args):
root = Path(args.data_root).absolute() / args.src_lang
+output_root = Path(args.output_root).absolute()
+if args.tgt_lang is not None:
+output_root = output_root / f"{args.src_lang}-{args.tgt_lang}"
+else:
+output_root = output_root / f"{args.src_lang}"
if not root.is_dir():
raise NotADirectoryError(f"{root} does not exist")
+zip_path = output_root / "fbank80.zip"
+if not zip_path.exists():
# Extract features
-feature_root = root / "fbank80"
+feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in CoVoST.SPLITS:
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
print("Extracting log mel filter bank features...")
-for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+for wav_path, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
+waveform, sample_rate = torchaudio.load(wav_path)
extract_fbank_features(
waveform, sample_rate, feature_root / f"{utt_id}.npy"
)
# Pack features into ZIP
-zip_path = root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
+# # Clean up
+# shutil.rmtree(feature_root)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
@@ -218,41 +239,74 @@ def process(args):
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
+if args.task == "st" and args.add_src:
+manifest["src_text"] = []
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
-for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+for _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
-duration_ms = int(wav.size(1) / sr * 1000)
+duration_ms = int(n_frames / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+if args.lowercase_src:
+src_utt = src_utt.lower()
+if args.rm_punc_src:
+for w in string.punctuation:
+src_utt = src_utt.replace(w, "")
+src_utt = src_utt.replace(" ", "")
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
+if args.tgt_lang is not None:
+manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
+if args.task == "st" and args.add_src and args.share:
+train_text.extend(manifest["src_text"])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
-save_df_to_tsv(df, root / f"{split}_{task}.tsv")
+save_df_to_tsv(df, output_root / f"{split}_{task}.tsv")
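The n_frames column written to the manifest above is now estimated from the audio metadata rather than a loaded waveform: the duration in milliseconds is derived from the sample count, and the frame count is the number of 25 ms analysis windows at a 10 ms shift. A quick worked example with illustrative numbers:

    n_samples, sr = 144_000, 48_000                     # a 3.0 s clip
    duration_ms = int(n_samples / sr * 1000)            # 3000
    manifest_frames = int(1 + (duration_ms - 25) / 10)  # 298 filter-bank frames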
# Generate vocab
-vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}"
+asr_spm_filename = None
+if args.task == "st" and args.add_src:
+if args.share:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+else:
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+asr_spm_filename = spm_filename_prefix + ".model"
+else:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+assert args.asr_prefix is not None
+asr_spm_filename = args.asr_prefix + ".model"
+elif args.task == "asr":
+if args.asr_prefix is not None:
+spm_filename_prefix = args.asr_prefix
+if args.st_spm_prefix is None:
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
-root / spm_filename_prefix,
+output_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size
)
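The block above decides whether to train a new SentencePiece model or reuse an existing one, and which model labels the source side in the config. A few example outcomes under the default unigram/1000 settings (flag values are illustrative):

    --task st --add-src --share
        -> trains spm_unigram1000_st_share on source + target text and also uses it
           as asr_spm_filename, so transcription and translation share one dictionary
    --task st --add-src --st-spm-prefix OLD_ST --asr-prefix OLD_ASR
        -> trains nothing new; OLD_ST.model is used for the target side and
           OLD_ASR.model for the source side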
# Generate config YAML
gen_config_yaml(
-root,
+output_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
+cmvn_type=args.cmvn_type,
+asr_spm_filename=asr_spm_filename,
+share_src_and_tgt=True if args.task == "asr" else False
)
-# Clean up
-shutil.rmtree(feature_root)
def main():
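With the new --output-root option every artifact is grouped by language pair instead of being written into the Common Voice download directory. For --src-lang fr --tgt-lang en and default settings, the layout would be roughly (names inferred from the code above, shown for illustration):

    <output_root>/fr-en/
        fbank80/                      # per-utterance .npy features (cleanup is now commented out)
        fbank80.zip
        train_st_fr_en.tsv
        spm_unigram1000_st_fr_en.*    # SentencePiece model and vocabulary files
        config_st_fr_en.yaml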
@@ -262,6 +316,10 @@ def main():
help="data root with sub-folders for each language <root>/<src_lang>"
)
parser.add_argument(
+"--output-root", "-o", required=True, type=str,
+help="output root to save the results"
+)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
@@ -270,7 +328,18 @@ def main():
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
+parser.add_argument("--task", type=str, default="asr", choices=["asr", "st"])
parser.add_argument("--tgt-lang", "-t", type=str)
+parser.add_argument("--share", action="store_true",
+help="share the tokenizer and dictionary of the transcription and translation")
+parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
+parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
+parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
+parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
+parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
+parser.add_argument("--cmvn-type", default="utterance",
+choices=["global", "utterance"],
+help="The type of cepstral mean and variance normalization")
args = parser.parse_args()
process(args)
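Putting the new options together, a preprocessing run for CoVoST ST with a shared source/target dictionary might look like the following; the script path and data locations are placeholders, only the flags come from the argparse definitions above:

    python prep_covost_data.py \
        --data-root /path/to/commonvoice --output-root /path/to/covost-prep \
        --src-lang fr --tgt-lang en --task st --add-src --share \
        --lowercase-src --rm-punc-src \
        --vocab-type unigram --vocab-size 1000 --cmvn-type utterance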
......
@@ -50,10 +50,12 @@ class MUSTC(Dataset):
# SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
-def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
+def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False, tokenizer: bool = False) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = Path(root) / f"en-{lang}" / "data" / split
wav_root, txt_root = _root / "wav", _root / "txt"
+if tokenizer:
+txt_root = _root / "txt.tok"
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
# Load audio segments
try:
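The new tokenizer flag switches MUSTC to read segments from a txt.tok directory instead of txt, i.e. from text that has already been tokenized offline. Assuming it mirrors the standard MuST-C layout (an assumption, the commit does not spell this out), the expected structure would be:

    en-de/data/train/
        wav/                          # audio, unchanged
        txt/train.{yaml,en,de}        # original segments
        txt.tok/train.{yaml,en,de}    # pre-tokenized copies with the same segmentation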
@@ -162,26 +164,23 @@ def process(args):
else:
output_root = Path(args.output_root).absolute() / f"en-{lang}"
-# Extract features
-if args.speed_perturb:
-feature_root = output_root / "fbank80_sp"
-else:
-feature_root = output_root / "fbank80"
-feature_root.mkdir(exist_ok=True)
if args.speed_perturb:
zip_path = output_root / "fbank80_sp.zip"
else:
zip_path = output_root / "fbank80.zip"
index = 0
-gen_feature_flag = False
-if not Path.exists(zip_path):
-gen_feature_flag = True
+# Extract features
+if args.overwrite or not Path.exists(zip_path):
+if args.speed_perturb:
+feature_root = output_root / "fbank80_sp"
+else:
+feature_root = output_root / "fbank80"
+feature_root.mkdir(exist_ok=True)
-if args.overwrite or gen_feature_flag:
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
-dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
+dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer)
is_train_split = split.startswith("train")
print("Extracting log mel filter bank features...")
if is_train_split and args.cmvn_type == "global":
@@ -193,7 +192,6 @@ def process(args):
index += 1
waveform, sr, _, _, _, _, utt_id = item
-if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
@@ -214,6 +212,9 @@ def process(args):
print("ZIPing features...")
create_zip(feature_root, zip_path)
+# # Clean up
+# shutil.rmtree(feature_root)
gen_manifest_flag = False
for split in MUSTC.SPLITS:
if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
@@ -232,7 +233,7 @@ def process(args):
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
-dataset = MUSTC(args.data_root, lang, split, args.speed_perturb)
+dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
for idx in range(len(dataset)):
items = dataset.get_fast(idx)
for item in items:
@@ -262,23 +263,11 @@ def process(args):
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
-# Generate vocab
-v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
-if args.task == "st" and args.add_src:
-if args.share:
-spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
-asr_spm_filename = spm_filename_prefix + ".model"
-else:
-asr_spm_filename = args.asr_prefix + ".model"
-else:
-asr_spm_filename = None
if len(train_text) == 0:
print("Loading the training text to build dictionary...")
for split in MUSTC.SPLITS:
if split.startswith("train"):
-dataset = MUSTC(args.data_root, lang, split)
+dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
src_text = dataset.get_src_text()
tgt_text = dataset.get_tgt_text()
for src_utt, tgt_utt in zip(src_text, tgt_text):
@@ -292,6 +281,18 @@ def process(args):
train_text.append(src_utt)
train_text.append(tgt_utt)
+# Generate vocab
+v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+if args.task == "st" and args.add_src:
+if args.share:
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+asr_spm_filename = spm_filename_prefix + ".model"
+else:
+asr_spm_filename = args.asr_prefix + ".model"
+else:
+asr_spm_filename = None
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
@@ -320,9 +321,6 @@ def process(args):
share_src_and_tgt=True if args.task == "asr" else False
)
-# Clean up
-shutil.rmtree(feature_root)
def process_joint(args):
cur_root = Path(args.data_root)
@@ -392,6 +390,7 @@ def main():
parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
+parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
parser.add_argument("--cmvn-type", default="utterance",
choices=["global", "utterance"],
help="The type of cepstral mean and variance normalization")
......
@@ -11,9 +11,7 @@ from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
-from typing import Tuple
import string
-import pickle
import numpy as np
import pandas as pd
@@ -74,11 +72,11 @@ class ST_Dataset(Dataset):
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
-seg_group = sorted(_seg_group, key=lambda x: x["offset"])
+seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
-_id = f"{wav_path.stem}_{i}"
+_id = f"{split}_{wav_path.stem}_{i}"
self.data.append(
(
wav_path.as_posix(),
@@ -87,7 +85,7 @@ class ST_Dataset(Dataset):
sample_rate,
segment[src_lang],
segment[tgt_lang],
-segment["speaker_id"],
+segment["speaker_id"] if "speaker_id" in segment else "spk1",
_id,
)
)
@@ -188,7 +186,7 @@ def process(args):
for items in tqdm(dataset):
for item in items:
index += 1
-waveform, sr, _, _, _, utt_id = item
+waveform, sr, _, _, _, _, utt_id = item
if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix()
@@ -259,16 +257,29 @@ def process(args):
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+asr_spm_filename = None
if args.task == "st" and args.add_src:
if args.share:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+else:
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+assert args.asr_prefix is not None
asr_spm_filename = args.asr_prefix + ".model"
+elif args.task == "asr":
+if args.asr_prefix is not None:
+spm_filename_prefix = args.asr_prefix
else:
-asr_spm_filename = None
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+if args.st_spm_prefix is None:
if len(train_text) == 0:
print("Loading the training text to build dictionary...")
for split in splits:
@@ -294,6 +305,7 @@ def process(args):
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
yaml_filename = f"config_{args.task}.yaml"
if args.task == "st" and args.add_src and args.share:
@@ -324,7 +336,6 @@ def main():
parser.add_argument(
"--vocab-type",
default="unigram",
-required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
@@ -339,7 +350,8 @@ def main():
parser.add_argument("--share", action="store_true",
help="share the tokenizer and dictionary of the transcription and translation")
parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
-parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
+parser.add_argument("--asr-prefix", type=str, default=None, help="prefix of the asr dict")
+parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
parser.add_argument("--cmvn-type", default="utterance",
......