Commit 81caa4ca by xuchen

Add speed perturbation for the MuST-C dataset

parent 6a2f4065
......@@ -41,6 +41,7 @@ share_dict=1
org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/st
data_dir=~/st/data/${dataset}/st_perturb_2
test_subset=(tst-COMMON)
# exp
......@@ -104,6 +105,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang}
fi
source audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir}
......@@ -118,6 +120,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--speed-perturb
--task st
--add-src
--cmvn-type utterance
......@@ -133,6 +136,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
deactivate
fi
data_dir=${data_dir}/${lang}
......
......@@ -13,6 +13,7 @@ from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
import pickle
import numpy as np
import pandas as pd
......@@ -28,7 +29,6 @@ from examples.speech_to_text.data_utils import (
save_df_to_tsv,
cal_gcmvn_stats,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm
......@@ -46,14 +46,14 @@ class MUSTC(Dataset):
utterance_id
"""
SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None:
def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = Path(root) / f"en-{lang}" / "data" / split
wav_root, txt_root = _root / "wav", _root / "txt"
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
# Load audio segments
try:
import yaml
......@@ -61,6 +61,8 @@ class MUSTC(Dataset):
print("Please install PyYAML to load the MuST-C YAML files")
with open(txt_root / f"{split}.yaml") as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None
# Load source and target utterances
for _lang in ["en", lang]:
with open(txt_root / f"{split}.{_lang}") as f:
......@@ -72,7 +74,8 @@ class MUSTC(Dataset):
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
# sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
......@@ -91,10 +94,52 @@ class MUSTC(Dataset):
)
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
def __getitem__(self, n: int):
    """Load segment ``n`` and return one item per speed variant.

    Args:
        n: Index into ``self.data``.

    Returns:
        A list of ``[waveform, sr, src_utt, tgt_utt, spk_id, utt_id]`` lists.
        When ``self.speed_perturb`` is ``None`` the list holds a single entry
        with the unmodified utterance id. Otherwise it holds one entry per
        factor in ``self.speed_perturb``, each id prefixed with ``sp{speed}_``.
    """
    wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]

    # Read the segment from disk once. Previously the identical
    # torchaudio.load call was repeated for every speed factor.
    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)

    if self.speed_perturb is None:
        return [[waveform, sr, src_utt, tgt_utt, spk_id, utt_id]]

    items = []
    for speed in self.speed_perturb:
        if speed == 1.0:
            # Factor 1.0 keeps the audio untouched; only the id gets the prefix.
            perturbed = waveform
        else:
            # NOTE(review): per the sox effect docs, "speed" alters the
            # effective sample rate and "rate" resamples back to ``sr`` so
            # every variant is reported at the same rate - confirm against
            # the installed sox/torchaudio version.
            effects = [
                ["speed", f"{speed}"],
                ["rate", f"{sr}"],
            ]
            perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
        items.append([perturbed, sr, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id])
    return items
def get_fast(self, n: int):
# Builds the same per-speed item lists as __getitem__, but with wav_path in the
# first slot of each item instead of a loaded waveform.
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
# NOTE(review): the next two lines look like the removed side of this diff
# (the previous __getitem__ body, which used the `offset=` keyword) - TODO confirm.
# Read literally, this return makes everything below it unreachable.
waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
# Nothing is loaded or resampled from here on; across speed factors only the
# utterance id differs (prefixed with "sp{speed}_").
items = []
if self.speed_perturb is None:
items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
else:
for speed in self.speed_perturb:
sp_utt_id = f"sp{speed}_" + utt_id
items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
return items
def get_src_text(self):
    """Return the source utterance of every entry in ``self.data``.

    Entries are read positionally: element 4 is the slot that
    ``__getitem__`` unpacks as ``src_utt``. No audio is touched.

    Returns:
        A new list with one source utterance per entry, in ``self.data`` order.
    """
    return [item[4] for item in self.data]
def get_tgt_text(self):
    """Return the target utterance of every entry in ``self.data``.

    Entries are read positionally: element 5 is the slot that
    ``__getitem__`` unpacks as ``tgt_utt``. No audio is touched.

    Returns:
        A new list with one target utterance per entry, in ``self.data`` order.
    """
    return [item[5] for item in self.data]
def __len__(self) -> int:
# Number of entries held in self.data.
return len(self.data)
......@@ -116,33 +161,77 @@ def process(args):
feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
zip_path = output_root / "fbank80.zip"
if args.overwrite or not Path.exists(zip_path):
manifest_dict = {}
train_text = []
gen_feature_flag = False
if not Path.exists(zip_path):
gen_feature_flag = True
for split in MUSTC.SPLITS:
if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
gen_feature_flag = True
break
if args.overwrite or gen_feature_flag:
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
dataset = MUSTC(root.as_posix(), lang, split)
dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
is_train_split = split.startswith("train")
print("Extracting log mel filter bank features...")
if split == 'train' and args.cmvn_type == "global":
if is_train_split and args.cmvn_type == "global":
print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = []
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
features = extract_fbank_features(waveform, sample_rate)
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
np.save(
(feature_root / f"{utt_id}.npy").as_posix(),
features
)
for items in tqdm(dataset):
for item in items:
# waveform, sample_rate, _, _, _, utt_id = item
waveform, sr, src_utt, tgt_utt, speaker_id, utt_id = item
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
# np.save(
# (feature_root / f"{utt_id}.npy").as_posix(),
# features
# )
manifest["id"].append(utt_id)
duration_ms = int(waveform.size(1) / sr * 1000)
# duration_ms = int(time_dict[utt_id] / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
if split == 'train' and args.cmvn_type == "global":
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num:
gcmvn_feature_list.append(features)
if split == 'train' and args.cmvn_type == "global":
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(list(set(tuple(manifest["src_text"]))))
train_text.extend(dataset.get_tgt_text())
if is_train_split and args.cmvn_type == "global":
# Estimate and save cmv
stats = cal_gcmvn_stats(gcmvn_feature_list)
with open(output_root / "gcmvn.npz", "wb") as f:
np.savez(f, mean=stats["mean"], std=stats["std"])
manifest_dict[split] = manifest
# Pack features into ZIP
print("ZIPing features...")
create_zip(feature_root, zip_path)
......@@ -159,33 +248,13 @@ def process(args):
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
for split in MUSTC.SPLITS:
for split, manifest in manifest_dict.items():
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
for utt_id in manifest["id"]:
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(manifest["src_text"])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
......@@ -207,7 +276,9 @@ def process(args):
for split in MUSTC.SPLITS:
if split.startswith("train"):
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in dataset:
src_text = dataset.get_src_text()
tgt_text = dataset.get_tgt_text()
for src_utt, tgt_utt in zip(src_text, tgt_text):
if args.task == "st" and args.add_src and args.share:
if args.lowercase_src:
src_utt = src_utt.lower()
......@@ -215,6 +286,7 @@ def process(args):
src_utt = src_utt.translate(None, string.punctuation)
train_text.append(src_utt)
train_text.append(tgt_utt)
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
......@@ -242,8 +314,9 @@ def process(args):
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if args.task == "asr" else False
)
# Clean up
shutil.rmtree(feature_root)
# shutil.rmtree(feature_root)
def process_joint(args):
......@@ -305,8 +378,11 @@ def main():
parser.add_argument("--vocab-size", default=8000, type=int)
parser.add_argument("--task", type=str, choices=["asr", "st"])
parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--speed-perturb", action="store_true", default=False,
help="apply speed perturbation on wave file")
parser.add_argument("--joint", action="store_true", help="")
parser.add_argument("--share", action="store_true", help="share the transcription and translation")
parser.add_argument("--share", action="store_true",
help="share the tokenizer and dictionary of the transcription and translation")
parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
......@@ -81,7 +81,6 @@ def _main(cfg: DictConfig, output_file):
# Load dataset splits
task = tasks.setup_task(cfg.task)
# Set dictionaries
try:
src_dict = getattr(task, "source_dictionary", None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论