Commit ce4936bd by xuchen

optimize the speed perturb

parent b78c7894
...@@ -36,6 +36,7 @@ dataset=mustc ...@@ -36,6 +36,7 @@ dataset=mustc
task=speech_to_text task=speech_to_text
vocab_type=unigram vocab_type=unigram
vocab_size=5000 vocab_size=5000
speed_perturb=1
org_data_dir=/media/data/${dataset} org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/asr data_dir=~/st/data/${dataset}/asr
...@@ -80,8 +81,14 @@ if [[ -z ${exp_name} ]]; then ...@@ -80,8 +81,14 @@ if [[ -z ${exp_name} ]]; then
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
fi
model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name} model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
...@@ -96,6 +103,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -96,6 +103,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang} mkdir -p ${data_dir}/${lang}
fi fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
...@@ -103,6 +111,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -103,6 +111,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--task asr --task asr
--vocab-type ${vocab_type} --vocab-type ${vocab_type}
--vocab-size ${vocab_size}" --vocab-size ${vocab_size}"
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd [[ $eval -eq 1 ]] && eval $cmd
fi fi
...@@ -138,7 +150,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -138,7 +150,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
--train-config ${train_config} --train-config ${train_config}
--task speech_to_text --task ${task}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--update-freq ${update_freq} --update-freq ${update_freq}
--log-interval 100 --log-interval 100
......
...@@ -11,8 +11,9 @@ log-interval: 100 ...@@ -11,8 +11,9 @@ log-interval: 100
seed: 1 seed: 1
report-accuracy: True report-accuracy: True
#load-params: # load-params:
#load-pretrained-encoder-from: load-pretrained-encoder-from:
load-pretrained-decoder-from:
arch: s2t_transformer_s arch: s2t_transformer_s
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
......
...@@ -38,10 +38,10 @@ vocab_type=unigram ...@@ -38,10 +38,10 @@ vocab_type=unigram
asr_vocab_size=5000 asr_vocab_size=5000
vocab_size=10000 vocab_size=10000
share_dict=1 share_dict=1
speed_perturb=1
org_data_dir=/media/data/${dataset} org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/st data_dir=~/st/data/${dataset}/st
data_dir=~/st/data/${dataset}/st_perturb_2
test_subset=(tst-COMMON) test_subset=(tst-COMMON)
# exp # exp
...@@ -89,8 +89,14 @@ if [[ -z ${exp_name} ]]; then ...@@ -89,8 +89,14 @@ if [[ -z ${exp_name} ]]; then
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
fi
model_dir=$root_dir/../checkpoints/$dataset/st/${exp_name} model_dir=$root_dir/../checkpoints/$dataset/st/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
...@@ -105,7 +111,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -105,7 +111,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang} mkdir -p ${data_dir}/${lang}
fi fi
source audio/bin/activate source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
...@@ -113,6 +119,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -113,6 +119,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--task asr --task asr
--vocab-type ${vocab_type} --vocab-type ${vocab_type}
--vocab-size ${asr_vocab_size}" --vocab-size ${asr_vocab_size}"
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 && ${share_dict} -ne 1 ]] && eval $cmd [[ $eval -eq 1 && ${share_dict} -ne 1 ]] && eval $cmd
...@@ -120,7 +130,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -120,7 +130,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
--output-root ${data_dir} --output-root ${data_dir}
--speed-perturb
--task st --task st
--add-src --add-src
--cmvn-type utterance --cmvn-type utterance
...@@ -133,6 +142,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -133,6 +142,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd cmd="$cmd
--asr-prefix spm_${vocab_type}${asr_vocab_size}_asr" --asr-prefix spm_${vocab_type}${asr_vocab_size}_asr"
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd} [[ $eval -eq 1 ]] && eval ${cmd}
......
...@@ -46,7 +46,8 @@ class MUSTC(Dataset): ...@@ -46,7 +46,8 @@ class MUSTC(Dataset):
utterance_id utterance_id
""" """
SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"] # SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"] LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None: def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
...@@ -74,7 +75,9 @@ class MUSTC(Dataset): ...@@ -74,7 +75,9 @@ class MUSTC(Dataset):
self.data = [] self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename wav_path = wav_root / wav_filename
# sample_rate = torchaudio.info(wav_path.as_posix())[0].rate try:
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: x["offset"]) seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
...@@ -158,21 +161,28 @@ def process(args): ...@@ -158,21 +161,28 @@ def process(args):
output_root = Path(args.output_root).absolute() / f"en-{lang}" output_root = Path(args.output_root).absolute() / f"en-{lang}"
# Extract features # Extract features
if args.speed_perturb:
feature_root = output_root / "fbank80_sp"
else:
feature_root = output_root / "fbank80" feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True) feature_root.mkdir(exist_ok=True)
if args.speed_perturb:
zip_path = output_root / "fbank80_sp.zip"
else:
zip_path = output_root / "fbank80.zip" zip_path = output_root / "fbank80.zip"
manifest_dict = {} frame_path = output_root / "frame.pkl"
train_text = [] frame_dict = {}
index = 0
gen_feature_flag = False gen_feature_flag = False
if not Path.exists(zip_path): if not Path.exists(zip_path):
gen_feature_flag = True gen_feature_flag = True
for split in MUSTC.SPLITS:
if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
gen_feature_flag = True
break
if args.overwrite or gen_feature_flag: gen_frame_flag = False
if not Path.exists(frame_path):
gen_frame_flag = True
if args.overwrite or gen_feature_flag or gen_frame_flag:
for split in MUSTC.SPLITS: for split in MUSTC.SPLITS:
print(f"Fetching split {split}...") print(f"Fetching split {split}...")
dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb) dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
...@@ -182,55 +192,31 @@ def process(args): ...@@ -182,55 +192,31 @@ def process(args):
print("And estimating cepstral mean and variance stats...") print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = [] gcmvn_feature_list = []
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
for items in tqdm(dataset): for items in tqdm(dataset):
for item in items: for item in items:
# waveform, sample_rate, _, _, _, utt_id = item index += 1
waveform, sr, src_utt, tgt_utt, speaker_id, utt_id = item waveform, sr, _, _, _, utt_id = item
frame_dict[utt_id] = waveform.size(1)
if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix() features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path)) features = extract_fbank_features(waveform, sr, Path(features_path))
# np.save(
# (feature_root / f"{utt_id}.npy").as_posix(),
# features
# )
manifest["id"].append(utt_id)
duration_ms = int(waveform.size(1) / sr * 1000)
# duration_ms = int(time_dict[utt_id] / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"): if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num: if len(gcmvn_feature_list) < args.gcmvn_max_num:
gcmvn_feature_list.append(features) gcmvn_feature_list.append(features)
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size: if is_train_split and args.size != -1 and index > args.size:
break break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(list(set(tuple(manifest["src_text"]))))
train_text.extend(dataset.get_tgt_text())
if is_train_split and args.cmvn_type == "global": if is_train_split and args.cmvn_type == "global":
# Estimate and save cmv # Estimate and save cmv
stats = cal_gcmvn_stats(gcmvn_feature_list) stats = cal_gcmvn_stats(gcmvn_feature_list)
with open(output_root / "gcmvn.npz", "wb") as f: with open(output_root / "gcmvn.npz", "wb") as f:
np.savez(f, mean=stats["mean"], std=stats["std"]) np.savez(f, mean=stats["mean"], std=stats["std"])
manifest_dict[split] = manifest with open(frame_path, "wb") as f:
pickle.dump(frame_dict, f)
# Pack features into ZIP # Pack features into ZIP
print("ZIPing features...") print("ZIPing features...")
...@@ -244,17 +230,44 @@ def process(args): ...@@ -244,17 +230,44 @@ def process(args):
train_text = [] train_text = []
if args.overwrite or gen_manifest_flag: if args.overwrite or gen_manifest_flag:
if len(frame_dict) == 0:
with open(frame_path, "rb") as f:
frame_dict = pickle.load(f)
print("Fetching ZIP manifest...") print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path) zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest # Generate TSV manifest
print("Generating manifest...") print("Generating manifest...")
for split in MUSTC.SPLITS:
for split, manifest in manifest_dict.items():
is_train_split = split.startswith("train") is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
for utt_id in manifest["id"]: if args.task == "st" and args.add_src:
manifest["src_text"] = []
dataset = MUSTC(args.data_root, lang, split)
for idx in range(len(dataset)):
items = dataset.get_fast(idx)
for item in items:
_, sr, src_utt, tgt_utt, speaker_id, utt_id = item
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id]) manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(frame_dict[utt_id] / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(manifest["src_text"])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest) df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split) df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv") save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
...@@ -316,7 +329,7 @@ def process(args): ...@@ -316,7 +329,7 @@ def process(args):
) )
# Clean up # Clean up
# shutil.rmtree(feature_root) shutil.rmtree(feature_root)
def process_joint(args): def process_joint(args):
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import multiprocessing as mp
import string
import pickle
import numpy as np
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
load_df_from_tsv,
save_df_to_tsv,
cal_gcmvn_stats,
)
from timeit import default_timer as timer
from torch.utils.data import Dataset
from tqdm import tqdm
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Base columns of every TSV manifest written by `process`; a "src_text"
# column is appended there for the ST task when --add-src is given.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class MUSTC(Dataset):
    """
    Create a Dataset for MuST-C.

    Indexing returns a *list* of entries for one segment. Each entry is a
    list of the form:
    [waveform, sample_rate, source utterance, target utterance, speaker_id,
    utterance_id]

    Without speed perturbation the list holds a single entry. With speed
    perturbation enabled (training splits only) it holds one entry per speed
    factor, and the utterance id is prefixed with ``sp{speed}_``.
    """

    SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
    LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
        """
        Index all segments of ``<root>/en-<lang>/data/<split>``.

        Args:
            root: directory that contains the ``en-<lang>`` folders.
            lang: target language, one of ``LANGUAGES``.
            split: data split, one of ``SPLITS``.
            speed_perturb: if True and the split name starts with "train",
                every segment is served at speeds 0.9, 1.0 and 1.1.

        Raises:
            AssertionError: on an unknown split/language or a missing
                ``wav``/``txt`` directory, or when a text file does not have
                one line per segment.
            ImportError: if PyYAML is not installed.
        """
        assert split in self.SPLITS and lang in self.LANGUAGES
        _root = Path(root) / f"en-{lang}" / "data" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)

        # Load audio segments
        try:
            import yaml
        except ImportError as err:
            # Fail here with a clear message instead of printing and then
            # crashing with a NameError on `yaml.load` below.
            raise ImportError(
                "Please install PyYAML to load the MuST-C YAML files"
            ) from err
        with open(txt_root / f"{split}.yaml") as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)

        self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None

        # Load source and target utterances (one line per segment)
        for _lang in ["en", lang]:
            with open(txt_root / f"{split}.{_lang}") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u

        # Gather info
        self.data = []
        # groupby only merges *consecutive* segments that share a wav file.
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_path = wav_root / wav_filename
            sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
            seg_group = sorted(_seg_group, key=lambda x: x["offset"])
            for i, segment in enumerate(seg_group):
                # The YAML is read with BaseLoader, so numbers arrive as strings.
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment["en"],
                        segment[lang],
                        segment["speaker_id"],
                        _id,
                    )
                )

    def __getitem__(self, n: int):
        """Load the audio of segment ``n`` and return its list of entries."""
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None:
            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
            return [[waveform, sr, src_utt, tgt_utt, spk_id, utt_id]]

        items = []
        for speed in self.speed_perturb:
            # A fresh load per speed keeps every entry's tensor independent.
            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
            if speed != 1.0:
                # Change the speed, then resample back to the original rate.
                effects = [
                    ["speed", f"{speed}"],
                    ["rate", f"{sr}"],
                ]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
            items.append([waveform, sr, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id])
        return items

    def get_fast(self, n: int):
        """
        Return the entries of segment ``n`` without reading any audio.

        Same layout as ``__getitem__``, except that the first field is the
        wav file path instead of the waveform.
        """
        wav_path, _offset, _n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None:
            return [[wav_path, sr, src_utt, tgt_utt, spk_id, utt_id]]
        return [
            [wav_path, sr, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id]
            for speed in self.speed_perturb
        ]

    def get_src_text(self):
        """Return the source (English) utterance of every segment, in order."""
        return [item[4] for item in self.data]

    def get_tgt_text(self):
        """Return the target-language utterance of every segment, in order."""
        return [item[5] for item in self.data]

    def __len__(self) -> int:
        """Number of segments (not multiplied by the number of speed factors)."""
        return len(self.data)
def get_feature(nargs):
    """
    Extract filter-bank features for every entry of one dataset item.

    Written as a single-argument function so it can be used with
    ``multiprocessing.Pool.map``.

    Args:
        nargs: tuple ``(dataset, index, feature_root, frame_dict)`` where
            ``dataset[index]`` yields entries of the form
            ``[waveform, sample_rate, _, _, _, utt_id]``, ``feature_root`` is
            the directory (a ``Path``) for the ``<utt_id>.npy`` outputs and
            ``frame_dict`` is a mapping updated in place with
            ``utt_id -> waveform.size(1)``.

    Returns:
        dict with the same ``utt_id -> waveform.size(1)`` pairs that were
        written to ``frame_dict``. A plain dict mutated inside a worker
        process is a copy, so the parent never sees those writes; the return
        value lets a caller collect the counts from ``Pool.map`` instead.
    """
    dataset, index, feature_root, frame_dict = nargs
    frames = {}
    for waveform, sr, _, _, _, utt_id in dataset[index]:
        n_samples = waveform.size(1)
        frame_dict[utt_id] = n_samples
        frames[utt_id] = n_samples
        features_path = (feature_root / f"{utt_id}.npy").as_posix()
        # The output path is handed to the helper; presumably it writes the
        # features there -- confirm in data_utils.extract_fbank_features.
        extract_fbank_features(waveform, sr, Path(features_path))
    return frames
# Translation table that deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def _normalize_src_text(text, lowercase, rm_punc):
    """Apply the optional --lowercase-src / --rm-punc-src normalisation."""
    if lowercase:
        text = text.lower()
    if rm_punc:
        text = text.translate(_PUNCTUATION_TABLE)
    return text


def _iter_entries_fast(dataset):
    """Yield every metadata entry of ``dataset`` without loading audio."""
    for idx in range(len(dataset)):
        yield from dataset.get_fast(idx)


def _extract_features(args, root, lang, feature_root, cores):
    """Extract features for all splits in parallel; return utt_id -> sample count."""
    print(f"Staring on {cores} cores")
    # A plain dict handed to pool workers is copied into each worker and the
    # updates are lost; a Manager dict is a proxy whose writes reach the parent.
    with mp.Manager() as manager, mp.Pool(processes=cores) as pool:
        shared_frames = manager.dict()
        for split in MUSTC.SPLITS:
            print(f"Fetching split {split}...")
            dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
            print("Extracting log mel filter bank features...")
            if split.startswith("train") and args.cmvn_type == "global":
                # NOTE(review): the global CMVN estimation is not implemented
                # in this parallel path, so gcmvn.npz is not (re)generated.
                print(
                    "Warning: global CMVN stats are not estimated by this "
                    "script; an existing gcmvn.npz is expected in the output "
                    "directory."
                )
            start = timer()
            nargs = [
                (dataset, i, feature_root, shared_frames)
                for i in range(len(dataset))
            ]
            pool.map(get_feature, nargs)
            end = timer()
            print(f'elapsed time: {end - start}')
        # Copy out of the proxy before the manager process shuts down.
        return shared_frames.copy()


def _build_manifest(args, dataset, zip_manifest, frame_dict, is_train_split):
    """Build the column -> values manifest dict for one split."""
    manifest = {c: [] for c in MANIFEST_COLUMNS}
    add_src = args.task == "st" and args.add_src
    if add_src:
        manifest["src_text"] = []
    for _, sr, src_utt, tgt_utt, speaker_id, utt_id in _iter_entries_fast(dataset):
        manifest["id"].append(utt_id)
        manifest["audio"].append(zip_manifest[utt_id])
        # frame_dict stores the number of waveform samples per utterance.
        duration_ms = int(frame_dict[utt_id] / sr * 1000)
        manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
        src_utt = _normalize_src_text(src_utt, args.lowercase_src, args.rm_punc_src)
        manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
        if add_src:
            manifest["src_text"].append(src_utt)
        manifest["speaker"].append(speaker_id)
        if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
            break
    return manifest


def process(args):
    """
    Prepare MuST-C for every language directory found under ``args.data_root``.

    For each existing ``en-<lang>`` directory this:
      1. extracts features into ``fbank80/`` with a process pool, records the
         per-utterance sample counts in ``frame.pkl`` and packs the features
         into ``fbank80.zip`` (skipped when the zip exists and --overwrite is
         not set);
      2. writes one ``<split>_<task>.tsv`` manifest per split (skipped when
         all of them exist and --overwrite is not set);
      3. builds the vocabulary from the training text;
      4. writes the config YAML;
      5. removes the ``fbank80/`` working directory.

    Args:
        args: parsed command-line namespace produced by ``main``.
    """
    root = Path(args.data_root).absolute()
    # Half of the CPUs, but never fewer than one worker.
    cores = max(1, mp.cpu_count() // 2)

    for lang in MUSTC.LANGUAGES:
        cur_root = root / f"en-{lang}"
        if not cur_root.is_dir():
            print(f"{cur_root.as_posix()} does not exist. Skipped.")
            continue
        if args.output_root is None:
            output_root = cur_root
        else:
            output_root = Path(args.output_root).absolute() / f"en-{lang}"

        # Extract features
        feature_root = output_root / "fbank80"
        feature_root.mkdir(parents=True, exist_ok=True)
        zip_path = output_root / "fbank80.zip"
        frame_path = output_root / "frame.pkl"
        frame_dict = {}
        if args.overwrite or not zip_path.exists():
            frame_dict = _extract_features(args, root, lang, feature_root, cores)
            with open(frame_path, "wb") as f:
                pickle.dump(frame_dict, f)
            # Pack features into ZIP
            print("ZIPing features...")
            create_zip(feature_root, zip_path)

        gen_manifest_flag = any(
            not (output_root / f"{split}_{args.task}.tsv").exists()
            for split in MUSTC.SPLITS
        )

        train_text = []
        if args.overwrite or gen_manifest_flag:
            if not frame_dict:
                # Written by this script in an earlier run; only load a
                # frame.pkl you trust, since unpickling executes code.
                with open(frame_path, "rb") as f:
                    frame_dict = pickle.load(f)
            print("Fetching ZIP manifest...")
            zip_manifest = get_zip_manifest(zip_path)
            # Generate TSV manifest
            print("Generating manifest...")
            for split in MUSTC.SPLITS:
                is_train_split = split.startswith("train")
                # Must use the same speed-perturb setting as the extraction
                # so that the "sp<speed>_" utterance ids line up.
                dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
                manifest = _build_manifest(
                    args, dataset, zip_manifest, frame_dict, is_train_split
                )
                if is_train_split:
                    if args.task == "st" and args.add_src and args.share:
                        train_text.extend(manifest["src_text"])
                    train_text.extend(manifest["tgt_text"])
                df = pd.DataFrame.from_dict(manifest)
                df = filter_manifest_df(df, is_train_split=is_train_split)
                save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")

        # Generate vocab
        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
        if args.task == "st" and args.add_src:
            if args.share:
                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
                asr_spm_filename = spm_filename_prefix + ".model"
            else:
                asr_spm_filename = args.asr_prefix + ".model"
        else:
            asr_spm_filename = None

        if not train_text:
            # Manifests were not regenerated: rebuild the training text from
            # the raw data, mirroring what the manifest path collects.
            print("Loading the training text to build dictionary...")
            for split in MUSTC.SPLITS:
                if not split.startswith("train"):
                    continue
                dataset = MUSTC(args.data_root, lang, split)
                for src_utt, tgt_utt in zip(dataset.get_src_text(), dataset.get_tgt_text()):
                    src_utt = _normalize_src_text(
                        src_utt, args.lowercase_src, args.rm_punc_src
                    )
                    if args.task == "st" and args.add_src and args.share:
                        train_text.append(src_utt)
                    # For ASR the "target" text is the source transcript.
                    train_text.append(src_utt if args.task == "asr" else tgt_utt)

        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
            # gen_vocab reads the file by name, so push buffered text to disk.
            f.flush()
            gen_vocab(
                Path(f.name),
                output_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )

        # Generate config YAML
        yaml_filename = f"config_{args.task}.yaml"
        if args.task == "st" and args.add_src and args.share:
            yaml_filename = f"config_{args.task}_share.yaml"
        gen_config_yaml(
            output_root,
            spm_filename_prefix + ".model",
            yaml_filename=yaml_filename,
            specaugment_policy="lb",
            cmvn_type=args.cmvn_type,
            gcmvn_path=(
                output_root / "gcmvn.npz" if args.cmvn_type == "global"
                else None
            ),
            asr_spm_filename=asr_spm_filename,
            share_src_and_tgt=True if args.task == "asr" else False
        )

        # Clean up
        shutil.rmtree(feature_root)
def process_joint(args):
    """
    Build a vocabulary and config shared by all MuST-C language pairs.

    Reads the ``train_<task>.tsv`` manifest of every language under the
    output root, trains one vocabulary on their ``tgt_text`` columns, writes
    a joint config YAML and symlinks each per-language manifest into the
    output root as ``<split>_<lang>_<task>.tsv``.

    Args:
        args: parsed command-line namespace produced by ``main``.

    Raises:
        AssertionError: if any ``en-<lang>`` directory is missing under
            ``args.data_root``.
    """
    cur_root = Path(args.data_root)
    assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \
        "do not have downloaded data available for all 8 languages"
    if args.output_root is None:
        # Absolute, so that the symlink targets created below do not resolve
        # relative to the directory that contains the link.
        output_root = cur_root.absolute()
    else:
        output_root = Path(args.output_root).absolute()

    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
    with NamedTemporaryFile(mode="w") as f:
        for lang in MUSTC.LANGUAGES:
            tsv_path = output_root / f"en-{lang}" / f"train_{args.task}.tsv"
            df = load_df_from_tsv(tsv_path)
            for t in df["tgt_text"]:
                f.write(t + "\n")
        # gen_vocab reads the file by name, so push buffered text to disk.
        f.flush()
        special_symbols = None
        if args.task == 'st':
            special_symbols = [f'<lang:{lang}>' for lang in MUSTC.LANGUAGES]
        gen_vocab(
            Path(f.name),
            output_root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
            special_symbols=special_symbols
        )

    # Generate config YAML
    gen_config_yaml(
        output_root,
        spm_filename_prefix + ".model",
        yaml_filename=f"config_{args.task}.yaml",
        specaugment_policy="ld",
        prepend_tgt_lang_tag=(args.task == "st"),
    )

    # Make symbolic links to manifests
    for lang in MUSTC.LANGUAGES:
        for split in MUSTC.SPLITS:
            src_path = output_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
            desc_path = output_root / f"{split}_{lang}_{args.task}.tsv"
            if not desc_path.is_symlink():
                os.symlink(src_path, desc_path)
def main():
    """Parse the command-line options and run the joint or per-language preparation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument("--output-root", "-o", default=None, type=str)
    # Required option: a default would never be used, so none is declared.
    parser.add_argument(
        "--vocab-type",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    parser.add_argument("--task", type=str, choices=["asr", "st"])
    parser.add_argument("--size", default=-1, type=int)
    parser.add_argument("--speed-perturb", action="store_true", default=False,
                        help="apply speed perturbation on wave file")
    parser.add_argument("--joint", action="store_true", help="")
    parser.add_argument("--share", action="store_true",
                        help="share the tokenizer and dictionary of the transcription and translation")
    parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
    parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
    parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
    parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
    parser.add_argument("--cmvn-type", default="utterance",
                        choices=["global", "utterance"],
                        help="The type of cepstral mean and variance normalization")
    parser.add_argument("--overwrite", action="store_true", help="overwrite the existing files")
    parser.add_argument("--gcmvn-max-num", default=150000, type=int,
                        help=(
                            # Trailing space keeps the two literals from
                            # fusing into "estimateglobal".
                            "Maximum number of sentences to use to estimate "
                            "global mean and variance"
                        ))
    args = parser.parse_args()

    if args.joint:
        process_joint(args)
    else:
        process(args)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论