Commit 71eae3fd by xuchen

Modify the preprocessing scripts for the speech-to-text (S2T) data

parent a8105353
......@@ -10,6 +10,7 @@ from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import string
import pandas as pd
import torchaudio
......@@ -54,7 +55,8 @@ class CoVoST(Dataset):
)
VERSIONS = {2}
SPLITS = ["train", "dev", "test"]
# SPLITS = ["train", "dev", "test"]
SPLITS = ["train"]
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
......@@ -130,7 +132,12 @@ class CoVoST(Dataset):
cv_tsv_path = self.root / "validated.tsv"
assert cv_tsv_path.is_file()
cv_tsv = load_df_from_tsv(cv_tsv_path)
if self.no_translation:
print("No target translation.")
df = cv_tsv[["path", "sentence", "client_id"]]
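# ASR-only mode: keep just the Common Voice transcript columns; the CoVoST
# translation TSV handled in the else-branch below is neither downloaded nor joined.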
else:
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
......@@ -139,7 +146,6 @@ class CoVoST(Dataset):
download_url(covost_url, self.root.as_posix(), hash_value=None)
extract_archive(covost_archive.as_posix())
cv_tsv = load_df_from_tsv(cv_tsv_path)
covost_tsv = load_df_from_tsv(
self.root / Path(covost_url).name.replace(".tar.gz", "")
)
......@@ -153,20 +159,21 @@ class CoVoST(Dataset):
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
data = df.to_dict(orient="index").items()
data = [v for k, v in sorted(data, key=lambda x: x[0])]
self.data = []
for e in data:
try:
path = self.root / "clips" / e["path"]
_ = torchaudio.info(path.as_posix())
# path = self.root / "clips" / e["path"]
# _ = torchaudio.info(path.as_posix())
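# NOTE: per-clip validation via torchaudio.info is skipped here, so unreadable
# clips are no longer filtered out at dataset construction time and would only
# surface later, when the audio is actually read.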
self.data.append(e)
except RuntimeError:
pass
def __getitem__(
self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
) -> Tuple[Path, int, int, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
......@@ -178,12 +185,14 @@ class CoVoST(Dataset):
"""
data = self.data[n]
path = self.root / "clips" / data["path"]
waveform, sample_rate = torchaudio.load(path)
info = torchaudio.info(path)
sample_rate = info.sample_rate
n_frames = info.num_frames
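# torchaudio.info only reads the file header, so no audio is decoded here;
# the caller loads the waveform on demand from the returned path.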
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
return path, sample_rate, n_frames, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
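For reference, a minimal usage sketch (not part of this patch; the data root and language pair are made-up examples) of the new item layout: CoVoST now yields the clip path and the header-derived frame count instead of a decoded waveform, so callers decode audio only when they actually need it.

from pathlib import Path
import torchaudio

root = Path("/data/common_voice/fr")          # hypothetical data root
dataset = CoVoST(root, "train", "fr", "en")   # hypothetical language pair
for wav_path, sample_rate, n_frames, sentence, translation, speaker_id, utt_id in dataset:
    waveform, _ = torchaudio.load(wav_path)   # audio is decoded lazily, per utterance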
......@@ -191,23 +200,35 @@ class CoVoST(Dataset):
def process(args):
root = Path(args.data_root).absolute() / args.src_lang
output_root = Path(args.output_root).absolute()
if args.tgt_lang is not None:
output_root = output_root / f"{args.src_lang}-{args.tgt_lang}"
else:
output_root = output_root / f"{args.src_lang}"
if not root.is_dir():
raise NotADirectoryError(f"{root} does not exist")
zip_path = output_root / "fbank80.zip"
if not zip_path.exists():
# Extract features
feature_root = root / "fbank80"
feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in CoVoST.SPLITS:
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
for wav_path, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
waveform, sample_rate = torchaudio.load(wav_path)
extract_fbank_features(
waveform, sample_rate, feature_root / f"{utt_id}.npy"
)
# Pack features into ZIP
zip_path = root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
# # Clean up
# shutil.rmtree(feature_root)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
......@@ -218,41 +239,74 @@ def process(args):
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
for _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
duration_ms = int(n_frames / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
src_utt = src_utt.replace(" ", "")
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
if args.tgt_lang is not None:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(manifest["src_text"])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, root / f"{split}_{task}.tsv")
save_df_to_tsv(df, output_root / f"{split}_{task}.tsv")
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}"
asr_spm_filename = None
if args.task == "st" and args.add_src:
if args.share:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
assert args.asr_prefix is not None
asr_spm_filename = args.asr_prefix + ".model"
elif args.task == "asr":
if args.asr_prefix is not None:
spm_filename_prefix = args.asr_prefix
if args.st_spm_prefix is None:
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
root / spm_filename_prefix,
output_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size
)
# Generate config YAML
gen_config_yaml(
root,
output_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
cmvn_type=args.cmvn_type,
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if args.task == "asr" else False
)
# Clean up
shutil.rmtree(feature_root)
def main():
......@@ -262,6 +316,10 @@ def main():
help="data root with sub-folders for each language <root>/<src_lang>"
)
parser.add_argument(
"--output-root", "-o", required=True, type=str,
help="output root to save the results"
)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
......@@ -270,7 +328,18 @@ def main():
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--task", type=str, default="asr", choices=["asr", "st"])
parser.add_argument("--tgt-lang", "-t", type=str)
parser.add_argument("--share", action="store_true",
help="share the tokenizer and dictionary of the transcription and translation")
parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
parser.add_argument("--cmvn-type", default="utterance",
choices=["global", "utterance"],
help="The type of cepstral mean and variance normalization")
args = parser.parse_args()
process(args)
......
......@@ -50,10 +50,12 @@ class MUSTC(Dataset):
# SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False, tokenizer: bool = False) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = Path(root) / f"en-{lang}" / "data" / split
wav_root, txt_root = _root / "wav", _root / "txt"
if tokenizer:
txt_root = _root / "txt.tok"
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
# Load audio segments
try:
......@@ -162,26 +164,23 @@ def process(args):
else:
output_root = Path(args.output_root).absolute() / f"en-{lang}"
# Extract features
if args.speed_perturb:
feature_root = output_root / "fbank80_sp"
else:
feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
if args.speed_perturb:
zip_path = output_root / "fbank80_sp.zip"
else:
zip_path = output_root / "fbank80.zip"
index = 0
gen_feature_flag = False
if not Path.exists(zip_path):
gen_feature_flag = True
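# features are only (re)extracted when the zip archive does not yet exist or
# --overwrite is set; otherwise the existing archive is reused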
# Extract features
if args.overwrite or not Path.exists(zip_path):
if args.speed_perturb:
feature_root = output_root / "fbank80_sp"
else:
feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
if args.overwrite or gen_feature_flag:
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer)
is_train_split = split.startswith("train")
print("Extracting log mel filter bank features...")
if is_train_split and args.cmvn_type == "global":
......@@ -193,7 +192,6 @@ def process(args):
index += 1
waveform, sr, _, _, _, _, utt_id = item
if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
......@@ -214,6 +212,9 @@ def process(args):
print("ZIPing features...")
create_zip(feature_root, zip_path)
# # Clean up
# shutil.rmtree(feature_root)
gen_manifest_flag = False
for split in MUSTC.SPLITS:
if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
......@@ -232,7 +233,7 @@ def process(args):
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
dataset = MUSTC(args.data_root, lang, split, args.speed_perturb)
dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
for idx in range(len(dataset)):
items = dataset.get_fast(idx)
for item in items:
......@@ -262,23 +263,11 @@ def process(args):
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
if args.task == "st" and args.add_src:
if args.share:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
asr_spm_filename = args.asr_prefix + ".model"
else:
asr_spm_filename = None
if len(train_text) == 0:
print("Loading the training text to build dictionary...")
for split in MUSTC.SPLITS:
if split.startswith("train"):
dataset = MUSTC(args.data_root, lang, split)
dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
src_text = dataset.get_src_text()
tgt_text = dataset.get_tgt_text()
for src_utt, tgt_utt in zip(src_text, tgt_text):
......@@ -292,6 +281,18 @@ def process(args):
train_text.append(src_utt)
train_text.append(tgt_utt)
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
if args.task == "st" and args.add_src:
if args.share:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
asr_spm_filename = args.asr_prefix + ".model"
else:
asr_spm_filename = None
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
......@@ -320,9 +321,6 @@ def process(args):
share_src_and_tgt=True if args.task == "asr" else False
)
# Clean up
shutil.rmtree(feature_root)
def process_joint(args):
cur_root = Path(args.data_root)
......@@ -392,6 +390,7 @@ def main():
parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
parser.add_argument("--cmvn-type", default="utterance",
choices=["global", "utterance"],
help="The type of cepstral mean and variance normalization")
......
......@@ -11,9 +11,7 @@ from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
import pickle
import numpy as np
import pandas as pd
......@@ -74,11 +72,11 @@ class ST_Dataset(Dataset):
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
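# the try/except above covers both torchaudio.info APIs: the legacy one returning
# (signal_info, encoding_info) with a .rate field, and the current one returning
# AudioMetaData with .sample_rate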
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
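# sort segments by numeric offset; a plain string sort would order "10.0" before "2.0"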
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{wav_path.stem}_{i}"
_id = f"{split}_{wav_path.stem}_{i}"
self.data.append(
(
wav_path.as_posix(),
......@@ -87,7 +85,7 @@ class ST_Dataset(Dataset):
sample_rate,
segment[src_lang],
segment[tgt_lang],
segment["speaker_id"],
segment["speaker_id"] if "speaker_id" in segment else "spk1",
_id,
)
)
......@@ -188,7 +186,7 @@ def process(args):
for items in tqdm(dataset):
for item in items:
index += 1
waveform, sr, _, _, _, utt_id = item
waveform, sr, _, _, _, _, utt_id = item
if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix()
......@@ -259,16 +257,29 @@ def process(args):
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
asr_spm_filename = None
if args.task == "st" and args.add_src:
if args.share:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
assert args.asr_prefix is not None
asr_spm_filename = args.asr_prefix + ".model"
elif args.task == "asr":
if args.asr_prefix is not None:
spm_filename_prefix = args.asr_prefix
else:
asr_spm_filename = None
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
if args.st_spm_prefix is None:
if len(train_text) == 0:
print("Loading the training text to build dictionary...")
for split in splits:
......@@ -294,6 +305,7 @@ def process(args):
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
yaml_filename = f"config_{args.task}.yaml"
if args.task == "st" and args.add_src and args.share:
......@@ -324,7 +336,6 @@ def main():
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
......@@ -339,7 +350,8 @@ def main():
parser.add_argument("--share", action="store_true",
help="share the tokenizer and dictionary of the transcription and translation")
parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
parser.add_argument("--asr-prefix", type=str, default=None, help="prefix of the asr dict")
parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
parser.add_argument("--cmvn-type", default="utterance",
......