Commit ce4936bd by xuchen

Optimize speed perturbation

parent b78c7894
......@@ -36,6 +36,7 @@ dataset=mustc
task=speech_to_text
vocab_type=unigram
vocab_size=5000
speed_perturb=1
org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/asr
......@@ -80,8 +81,14 @@ if [[ -z ${exp_name} ]]; then
if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag}
fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
fi
model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
......@@ -96,6 +103,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang}
fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir}
......@@ -103,6 +111,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--task asr
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd
fi
......@@ -138,7 +150,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
${data_dir}
--config-yaml ${data_config}
--train-config ${train_config}
--task speech_to_text
--task ${task}
--max-tokens ${max_tokens}
--update-freq ${update_freq}
--log-interval 100
......
......@@ -11,8 +11,9 @@ log-interval: 100
seed: 1
report-accuracy: True
#load-params:
#load-pretrained-encoder-from:
# load-params:
load-pretrained-encoder-from:
load-pretrained-decoder-from:
arch: s2t_transformer_s
share-decoder-input-output-embed: True
......
......@@ -38,10 +38,10 @@ vocab_type=unigram
asr_vocab_size=5000
vocab_size=10000
share_dict=1
speed_perturb=1
org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/st
data_dir=~/st/data/${dataset}/st_perturb_2
test_subset=(tst-COMMON)
# exp
......@@ -89,8 +89,14 @@ if [[ -z ${exp_name} ]]; then
if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag}
fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
fi
model_dir=$root_dir/../checkpoints/$dataset/st/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
......@@ -105,7 +111,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang}
fi
source audio/bin/activate
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir}
......@@ -113,6 +119,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--task asr
--vocab-type ${vocab_type}
--vocab-size ${asr_vocab_size}"
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 && ${share_dict} -ne 1 ]] && eval $cmd
......@@ -120,7 +130,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--speed-perturb
--task st
--add-src
--cmvn-type utterance
......@@ -133,6 +142,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd
--asr-prefix spm_${vocab_type}${asr_vocab_size}_asr"
fi
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
......
......@@ -46,7 +46,8 @@ class MUSTC(Dataset):
utterance_id
"""
SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
# SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
......@@ -74,8 +75,10 @@ class MUSTC(Dataset):
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename
# sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
try:
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
......@@ -158,21 +161,28 @@ def process(args):
output_root = Path(args.output_root).absolute() / f"en-{lang}"
# Extract features
feature_root = output_root / "fbank80"
if args.speed_perturb:
feature_root = output_root / "fbank80_sp"
else:
feature_root = output_root / "fbank80"
feature_root.mkdir(exist_ok=True)
zip_path = output_root / "fbank80.zip"
manifest_dict = {}
train_text = []
if args.speed_perturb:
zip_path = output_root / "fbank80_sp.zip"
else:
zip_path = output_root / "fbank80.zip"
frame_path = output_root / "frame.pkl"
frame_dict = {}
index = 0
gen_feature_flag = False
if not Path.exists(zip_path):
gen_feature_flag = True
for split in MUSTC.SPLITS:
if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
gen_feature_flag = True
break
if args.overwrite or gen_feature_flag:
gen_frame_flag = False
if not Path.exists(frame_path):
gen_frame_flag = True
if args.overwrite or gen_feature_flag or gen_frame_flag:
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
......@@ -182,59 +192,35 @@ def process(args):
print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = []
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
for items in tqdm(dataset):
for item in items:
# waveform, sample_rate, _, _, _, utt_id = item
waveform, sr, src_utt, tgt_utt, speaker_id, utt_id = item
index += 1
waveform, sr, _, _, _, utt_id = item
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
# np.save(
# (feature_root / f"{utt_id}.npy").as_posix(),
# features
# )
frame_dict[utt_id] = waveform.size(1)
if gen_feature_flag:
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
manifest["id"].append(utt_id)
duration_ms = int(waveform.size(1) / sr * 1000)
# duration_ms = int(time_dict[utt_id] / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num:
gcmvn_feature_list.append(features)
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num:
gcmvn_feature_list.append(features)
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
if is_train_split and args.size != -1 and index > args.size:
break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(list(set(tuple(manifest["src_text"]))))
train_text.extend(dataset.get_tgt_text())
if is_train_split and args.cmvn_type == "global":
# Estimate and save cmv
stats = cal_gcmvn_stats(gcmvn_feature_list)
with open(output_root / "gcmvn.npz", "wb") as f:
np.savez(f, mean=stats["mean"], std=stats["std"])
manifest_dict[split] = manifest
with open(frame_path, "wb") as f:
pickle.dump(frame_dict, f)
# Pack features into ZIP
print("ZIPing features...")
create_zip(feature_root, zip_path)
# Pack features into ZIP
print("ZIPing features...")
create_zip(feature_root, zip_path)
gen_manifest_flag = False
for split in MUSTC.SPLITS:
......@@ -244,17 +230,44 @@ def process(args):
train_text = []
if args.overwrite or gen_manifest_flag:
if len(frame_dict) == 0:
with open(frame_path, "rb") as f:
frame_dict = pickle.load(f)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
for split, manifest in manifest_dict.items():
for split in MUSTC.SPLITS:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src:
manifest["src_text"] = []
dataset = MUSTC(args.data_root, lang, split)
for idx in range(len(dataset)):
items = dataset.get_fast(idx)
for item in items:
_, sr, src_utt, tgt_utt, speaker_id, utt_id = item
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(frame_dict[utt_id] / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id)
for utt_id in manifest["id"]:
manifest["audio"].append(zip_manifest[utt_id])
if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
break
if is_train_split:
if args.task == "st" and args.add_src and args.share:
train_text.extend(manifest["src_text"])
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
......@@ -316,7 +329,7 @@ def process(args):
)
# Clean up
# shutil.rmtree(feature_root)
shutil.rmtree(feature_root)
def process_joint(args):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论