Commit 81caa4ca by xuchen

Add speed perturbation for the MuST-C dataset

parent 6a2f4065
@@ -41,6 +41,7 @@ share_dict=1
 org_data_dir=/media/data/${dataset}
 data_dir=~/st/data/${dataset}/st
+data_dir=~/st/data/${dataset}/st_perturb_2
 test_subset=(tst-COMMON)

 # exp
@@ -104,6 +105,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir}/${lang} ]]; then
        mkdir -p ${data_dir}/${lang}
    fi
+    source audio/bin/activate
    cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
        --data-root ${org_data_dir}
@@ -118,6 +120,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
        cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
            --data-root ${org_data_dir}
            --output-root ${data_dir}
+            --speed-perturb
            --task st
            --add-src
            --cmvn-type utterance
@@ -133,6 +136,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo -e "\033[34mRun command: \n${cmd} \033[0m"
    [[ $eval -eq 1 ]] && eval ${cmd}
+    deactivate
fi
data_dir=${data_dir}/${lang}
......
@@ -13,6 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 import string
+import pickle

 import numpy as np
 import pandas as pd
@@ -28,7 +29,6 @@ from examples.speech_to_text.data_utils import (
     save_df_to_tsv,
     cal_gcmvn_stats,
 )
-from torch import Tensor
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -46,14 +46,14 @@ class MUSTC(Dataset):
     utterance_id
     """

-    SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+    SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
     LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

-    def __init__(self, root: str, lang: str, split: str) -> None:
+    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
         assert split in self.SPLITS and lang in self.LANGUAGES
         _root = Path(root) / f"en-{lang}" / "data" / split
         wav_root, txt_root = _root / "wav", _root / "txt"
-        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
+        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
         # Load audio segments
         try:
             import yaml
@@ -61,6 +61,8 @@ class MUSTC(Dataset):
             print("Please install PyYAML to load the MuST-C YAML files")
         with open(txt_root / f"{split}.yaml") as f:
             segments = yaml.load(f, Loader=yaml.BaseLoader)
+
+        self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None
         # Load source and target utterances
         for _lang in ["en", lang]:
             with open(txt_root / f"{split}.{_lang}") as f:
@@ -72,7 +74,8 @@ class MUSTC(Dataset):
         self.data = []
         for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
             wav_path = wav_root / wav_filename
-            sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
+            # sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
+            sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
             seg_group = sorted(_seg_group, key=lambda x: x["offset"])
             for i, segment in enumerate(seg_group):
                 offset = int(float(segment["offset"]) * sample_rate)
@@ -91,10 +94,52 @@ class MUSTC(Dataset):
                     )
                 )

-    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
+    def __getitem__(self, n: int):
         wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
-        waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
-        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
+
+        items = []
+        if self.speed_perturb is None:
+            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+        else:
+            for speed in self.speed_perturb:
+                sp_utt_id = f"sp{speed}_" + utt_id
+                if speed == 1.0:
+                    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+                else:
+                    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+                    effects = [
+                        ["speed", f"{speed}"],
+                        ["rate", f"{sr}"]
+                    ]
+                    waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
+                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+        return items
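Review note: the loader above applies Kaldi-style 3-way speed perturbation (factors 0.9, 1.0, 1.1) through sox effects: "speed" time-scales the signal (also shifting its pitch), and "rate" resamples back to the original sample rate so every variant keeps a uniform rate for fbank extraction. A minimal standalone sketch of the same call, assuming a torchaudio build with sox support (the file path is illustrative):

    import torchaudio

    def perturb_speed(waveform, sample_rate, factor):
        # Speed up or slow down via sox, then resample back to the
        # original rate; a factor of 1.1 yields a ~9% shorter waveform.
        if factor == 1.0:
            return waveform
        effects = [["speed", f"{factor}"], ["rate", f"{sample_rate}"]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform, sample_rate, effects
        )
        return perturbed

    # wav, sr = torchaudio.load("talk.wav")  # illustrative path
    # faster = perturb_speed(wav, sr, 1.1)

Incidentally, the torchaudio.load call is identical in both branches of the speed loop, so the speed == 1.0 special case only skips the sox call.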
+    def get_fast(self, n: int):
+        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
+
+        items = []
+        if self.speed_perturb is None:
+            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+        else:
+            for speed in self.speed_perturb:
+                sp_utt_id = f"sp{speed}_" + utt_id
+                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+        return items
+
+    def get_src_text(self):
+        src_text = []
+        for item in self.data:
+            src_text.append(item[4])
+        return src_text
+
+    def get_tgt_text(self):
+        tgt_text = []
+        for item in self.data:
+            tgt_text.append(item[5])
+        return tgt_text
+
     def __len__(self) -> int:
         return len(self.data)
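With this change, __getitem__ returns a list of items rather than a single tuple: for a training split with perturbation enabled, each index yields three variants whose utterance ids gain an "sp{factor}_" prefix (including "sp1.0_" for the unperturbed copy), while get_fast() produces the same metadata without decoding any audio. A hedged usage sketch, assuming a MuST-C tree under an illustrative ./data root:

    dataset = MUSTC("./data", lang="de", split="train", speed_perturb=True)

    # One index now yields three (waveform, rate, texts, speaker, id) items.
    for waveform, sr, src_utt, tgt_utt, spk_id, utt_id in dataset[0]:
        # utt_id is "sp0.9_<id>", "sp1.0_<id>", "sp1.1_<id>" in turn
        print(utt_id, waveform.shape, sr)

    # Metadata only, no torchaudio.load: useful for building text lists.
    meta_items = dataset.get_fast(0)

Note that the unperturbed copy is also prefixed ("sp1.0_"), which is relevant to the startswith("sp") guard used later when collecting features for global CMVN.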
@@ -116,33 +161,77 @@ def process(args):
     feature_root = output_root / "fbank80"
     feature_root.mkdir(exist_ok=True)
     zip_path = output_root / "fbank80.zip"
-    if args.overwrite or not Path.exists(zip_path):
+    manifest_dict = {}
+    train_text = []
+
+    gen_feature_flag = False
+    if not Path.exists(zip_path):
+        gen_feature_flag = True
+
+    for split in MUSTC.SPLITS:
+        if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
+            gen_feature_flag = True
+            break
+
+    if args.overwrite or gen_feature_flag:
         for split in MUSTC.SPLITS:
             print(f"Fetching split {split}...")
-            dataset = MUSTC(root.as_posix(), lang, split)
+            dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
+            is_train_split = split.startswith("train")
             print("Extracting log mel filter bank features...")
-            if split == 'train' and args.cmvn_type == "global":
+            if is_train_split and args.cmvn_type == "global":
                 print("And estimating cepstral mean and variance stats...")
                 gcmvn_feature_list = []
-            for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
-                features = extract_fbank_features(waveform, sample_rate)
-                np.save(
-                    (feature_root / f"{utt_id}.npy").as_posix(),
-                    features
-                )
+
+            manifest = {c: [] for c in MANIFEST_COLUMNS}
+            if args.task == "st" and args.add_src:
+                manifest["src_text"] = []
+
+            for items in tqdm(dataset):
+                for item in items:
+                    # waveform, sample_rate, _, _, _, utt_id = item
+                    waveform, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                    features_path = (feature_root / f"{utt_id}.npy").as_posix()
+                    features = extract_fbank_features(waveform, sr, Path(features_path))
+                    # np.save(
+                    #     (feature_root / f"{utt_id}.npy").as_posix(),
+                    #     features
+                    # )
+
+                    manifest["id"].append(utt_id)
+                    duration_ms = int(waveform.size(1) / sr * 1000)
+                    # duration_ms = int(time_dict[utt_id] / sr * 1000)
+                    manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in string.punctuation:
+                            src_utt = src_utt.replace(w, "")
+                    manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
+                    if args.task == "st" and args.add_src:
+                        manifest["src_text"].append(src_utt)
+                    manifest["speaker"].append(speaker_id)

-                if split == 'train' and args.cmvn_type == "global":
-                    if len(gcmvn_feature_list) < args.gcmvn_max_num:
-                        gcmvn_feature_list.append(features)
+                    if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
+                        if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                            gcmvn_feature_list.append(features)

-            if split == 'train' and args.cmvn_type == "global":
+                if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
+                    break
+
+            if is_train_split:
+                if args.task == "st" and args.add_src and args.share:
+                    train_text.extend(list(set(tuple(manifest["src_text"]))))
+                train_text.extend(dataset.get_tgt_text())
+
+            if is_train_split and args.cmvn_type == "global":
                 # Estimate and save cmv
                 stats = cal_gcmvn_stats(gcmvn_feature_list)
                 with open(output_root / "gcmvn.npz", "wb") as f:
                     np.savez(f, mean=stats["mean"], std=stats["std"])

+            manifest_dict[split] = manifest
+
     # Pack features into ZIP
     print("ZIPing features...")
     create_zip(feature_root, zip_path)
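The n_frames bookkeeping above assumes the standard 25 ms window and 10 ms hop of the 80-dimensional fbank front end, so frame counts are derived directly from the waveform length instead of reloading the saved feature matrix. A quick sanity check of that arithmetic (the window and hop values are the usual defaults, assumed here):

    def fbank_num_frames(num_samples: int, sample_rate: int,
                         window_ms: int = 25, hop_ms: int = 10) -> int:
        # One frame at t=0, then one more per hop while the window fits.
        duration_ms = int(num_samples / sample_rate * 1000)
        return int(1 + (duration_ms - window_ms) / hop_ms)

    # A 1-second clip at 16 kHz: 1 + (1000 - 25) / 10 = 98.5, truncated to 98
    assert fbank_num_frames(16000, 16000) == 98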
@@ -159,33 +248,13 @@ def process(args):
     zip_manifest = get_zip_manifest(zip_path)

     # Generate TSV manifest
     print("Generating manifest...")
-    for split in MUSTC.SPLITS:
+    for split, manifest in manifest_dict.items():
         is_train_split = split.startswith("train")
-        manifest = {c: [] for c in MANIFEST_COLUMNS}
-        if args.task == "st" and args.add_src:
-            manifest["src_text"] = []
-        dataset = MUSTC(args.data_root, lang, split)
-        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
-            manifest["id"].append(utt_id)
+        for utt_id in manifest["id"]:
             manifest["audio"].append(zip_manifest[utt_id])
-            duration_ms = int(wav.size(1) / sr * 1000)
-            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
-            if args.lowercase_src:
-                src_utt = src_utt.lower()
-            if args.rm_punc_src:
-                for w in string.punctuation:
-                    src_utt = src_utt.replace(w, "")
-            manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
-            if args.task == "st" and args.add_src:
-                manifest["src_text"].append(src_utt)
-            manifest["speaker"].append(speaker_id)
-            if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
-                break
-        if is_train_split:
-            if args.task == "st" and args.add_src and args.share:
-                train_text.extend(manifest["src_text"])
-                train_text.extend(manifest["tgt_text"])
+
         df = pd.DataFrame.from_dict(manifest)
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
@@ -207,7 +276,9 @@ def process(args):
         for split in MUSTC.SPLITS:
             if split.startswith("train"):
                 dataset = MUSTC(args.data_root, lang, split)
-                for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in dataset:
+                src_text = dataset.get_src_text()
+                tgt_text = dataset.get_tgt_text()
+                for src_utt, tgt_utt in zip(src_text, tgt_text):
                     if args.task == "st" and args.add_src and args.share:
                         if args.lowercase_src:
                             src_utt = src_utt.lower()
@@ -215,6 +286,7 @@ def process(args):
                         src_utt = src_utt.translate(None, string.punctuation)
                     train_text.append(src_utt)
                     train_text.append(tgt_utt)
+
     with NamedTemporaryFile(mode="w") as f:
         for t in train_text:
             f.write(t + "\n")
@@ -242,8 +314,9 @@ def process(args):
         asr_spm_filename=asr_spm_filename,
         share_src_and_tgt=True if args.task == "asr" else False
     )
+
     # Clean up
-    shutil.rmtree(feature_root)
+    # shutil.rmtree(feature_root)


 def process_joint(args):
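For reference, the pooled train_text (transcriptions plus translations when --share is set) is what gets written to the temporary file and fed to SentencePiece vocabulary learning. A rough equivalent using the sentencepiece package directly, with illustrative file names and the unigram model type assumed rather than taken from this diff:

    import sentencepiece as spm

    # Train one shared subword model over transcriptions + translations,
    # then load it for encoding.
    spm.SentencePieceTrainer.train(
        input="train_text.txt",                   # illustrative path
        model_prefix="spm_unigram8000_st_share",  # illustrative prefix
        vocab_size=8000,                          # matches the --vocab-size default
        model_type="unigram",
    )
    sp = spm.SentencePieceProcessor(model_file="spm_unigram8000_st_share.model")
    print(sp.encode("hello world", out_type=str))

Sharing one model between source and target text is what keeps the ASR transcription and ST translation in a single dictionary, which the updated --share help string now states explicitly.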
@@ -305,8 +378,11 @@ def main():
     parser.add_argument("--vocab-size", default=8000, type=int)
     parser.add_argument("--task", type=str, choices=["asr", "st"])
     parser.add_argument("--size", default=-1, type=int)
+    parser.add_argument("--speed-perturb", action="store_true", default=False,
+                        help="apply speed perturbation on wave file")
     parser.add_argument("--joint", action="store_true", help="")
-    parser.add_argument("--share", action="store_true", help="share the transcription and translation")
+    parser.add_argument("--share", action="store_true",
+                        help="share the tokenizer and dictionary of the transcription and translation")
     parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
     parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
@@ -81,7 +81,6 @@ def _main(cfg: DictConfig, output_file):
     # Load dataset splits
     task = tasks.setup_task(cfg.task)
-
     # Set dictionaries
     try:
         src_dict = getattr(task, "source_dictionary", None)
......