Commit 42321f76 by xuchen

modify the egs scripts to set the output dir

parent 0ab28954
@@ -37,7 +37,8 @@ task=speech_to_text
 vocab_type=unigram
 vocab_size=5000
-data_dir=~/st/data/${dataset}
+org_data_dir=/media/data/${dataset}
+data_dir=~/st/data/${dataset}/asr
 test_subset=(tst-COMMON)

 # exp
@@ -47,7 +48,7 @@ exp_tag=baseline
 exp_name=

 # config
-train_config=asr_train_ctc.yaml
+train_config=train_ctc.yaml
 data_config=config_asr.yaml

 # training setting
@@ -92,8 +93,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     ### Task dependent. You have to make data the following preparation part by yourself.
     ### But you can utilize Kaldi recipes in most cases
     echo "stage 0: ASR Data Preparation"
+    if [[ ! -e ${data_dir} ]]; then
+        mkdir -p ${data_dir}
+    fi
     cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
-        --data-root ${data_dir}
+        --data-root ${org_data_dir}
+        --output-root ${data_dir}
         --task asr
         --vocab-type ${vocab_type}
         --vocab-size ${vocab_size}"
@@ -101,6 +107,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     [[ $eval -eq 1 ]] && eval $cmd
 fi

+data_dir=${data_dir}/${lang}
+
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     echo "stage 1: ASR Network Training"
     [[ ! -d $data_dir ]] && echo "The data dir $data_dir is not existing!" && exit 1;
@@ -114,7 +122,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         fi
     fi

-    echo -e "dev=${device} data=$data_dir model=${model_dir}"
+    echo -e "dev=${device} data=${data_dir} model=${model_dir}"

     if [[ ! -d ${model_dir} ]]; then
         mkdir -p ${model_dir}
@@ -127,7 +135,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     cp ${train_config} ${model_dir}

     cmd="python3 -u ${root_dir}/fairseq_cli/train.py
-        $data_dir/$lang
+        ${data_dir}
         --config-yaml ${data_config}
         --train-config ${train_config}
         --task speech_to_text
@@ -179,7 +187,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # save info
     log=./history.log
-    echo "${time} | ${device} | $data_dir | ${model_dir} " >> $log
+    echo "${time} | ${device} | ${data_dir} | ${model_dir} " >> $log
     cat $log | tail -n 50 > tmp.log
     mv tmp.log $log

     export CUDA_VISIBLE_DEVICES=${device}
@@ -227,7 +235,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     for subset in ${test_subset[@]}; do
         subset=${subset}_asr
         cmd="python ${root_dir}/fairseq_cli/generate.py
-            ${data_dir}/$lang
+            ${data_dir}
             --config-yaml ${data_config}
             --gen-subset ${subset}
             --task speech_to_text
......
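The net effect in the ASR recipe is a clean split between the raw corpus (read from org_data_dir) and generated artifacts (written to data_dir). A minimal sketch of the resulting stage-0 invocation; the paths and dataset name below are illustrative assumptions, not part of the commit:

    # illustrative values for the recipe's variables (assumptions)
    root_dir=~/fairseq                   # fairseq checkout, as ${root_dir} in run.sh
    org_data_dir=/media/data/mustc       # raw MuST-C release, treated as read-only
    data_dir=~/st/data/mustc/asr         # processed output: features, manifests, vocab

    python ${root_dir}/examples/speech_to_text/prep_mustc_data.py \
        --data-root ${org_data_dir} \
        --output-root ${data_dir} \
        --task asr \
        --vocab-type unigram \
        --vocab-size 5000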
@@ -13,7 +13,7 @@ get_devices(){
     do
         line=`expr $dev + 2`
         use=`cat $record | head -n $line | tail -1 | cut -d '|' -f3 | cut -d '/' -f1`
-        if [[ $use -eq 0 ]]; then
+        if [[ $use -lt 10 ]]; then
             device[$count]=$dev
             count=`expr $count + 1`
             if [[ $count -eq $gpu_num ]]; then
......
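The device-picker change relaxes the idle test from "exactly idle" to "nearly idle". Given how use is extracted (third '|'-separated field, first number before the '/'), it presumably reads used memory from a used/total pair; a sketch with a hypothetical record row (the row format is an assumption, not from the commit):

    # hypothetical record row: "0 | GeForce | 8/11178 |"
    use=`echo "0 | GeForce | 8/11178 |" | cut -d '|' -f3 | cut -d '/' -f1`
    [[ $use -lt 10 ]] && echo "GPU considered free"   # now passes: 8 < 10; the old -eq 0 would not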
@@ -39,7 +39,8 @@ asr_vocab_size=5000
 vocab_size=10000
 share_dict=1
-data_dir=~/st/data/${dataset}
+org_data_dir=/media/data/${dataset}
+data_dir=~/st/data/${dataset}/st
 test_subset=(tst-COMMON)

 # exp
@@ -49,7 +50,7 @@ exp_tag=baseline
 exp_name=

 # config
-train_config=st_train_ctc.yaml
+train_config=train_ctc.yaml

 # training setting
 fp16=1
@@ -100,8 +101,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     ### Task dependent. You have to make data the following preparation part by yourself.
     ### But you can utilize Kaldi recipes in most cases
     echo "stage 0: ASR Data Preparation"
+    if [[ ! -e ${data_dir} ]]; then
+        mkdir -p ${data_dir}
+    fi
     cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
-        --data-root ${data_dir}
+        --data-root ${org_data_dir}
+        --output-root ${data_dir}
         --task asr
         --vocab-type ${vocab_type}
         --vocab-size ${asr_vocab_size}"
@@ -110,7 +116,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: ST Data Preparation"
     cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
-        --data-root ${data_dir}
+        --data-root ${org_data_dir}
+        --output-root ${data_dir}
         --task st
         --add-src
         --cmvn-type utterance
@@ -128,6 +135,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     [[ $eval -eq 1 ]] && eval ${cmd}
 fi

+data_dir=${data_dir}/${lang}
+
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     echo "stage 1: ST Network Training"
     [[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} is not existing!" && exit 1;
@@ -154,7 +163,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     cp ${train_config} ${model_dir}

     cmd="python3 -u ${root_dir}/fairseq_cli/train.py
-        ${data_dir}/$lang
+        ${data_dir}
         --config-yaml ${data_config}
         --train-config ${train_config}
         --task speech_to_text
@@ -263,7 +272,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     for subset in ${test_subset[@]}; do
         subset=${subset}_st
         cmd="python ${root_dir}/fairseq_cli/generate.py
-            ${data_dir}/$lang
+            ${data_dir}
             --config-yaml ${data_config}
             --gen-subset ${subset}
             --task speech_to_text
......
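The ST recipe gets the same input/output split, applied to both stage-0 preparation calls. A sketch of the second one, under the same illustrative variable values as the ASR example above; the hunk truncates the command, so further vocab options that follow in the script are omitted here:

    python ${root_dir}/examples/speech_to_text/prep_mustc_data.py \
        --data-root ${org_data_dir} \
        --output-root ${data_dir} \
        --task st \
        --add-src \
        --cmvn-type utterance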
@@ -13,7 +13,7 @@ extra_parameter=
 #extra_parameter="${extra_parameter} "
 exp_tag=
-train_config=st_train_ctc.yaml
+train_config=train_ctc.yaml

 cmd="./run.sh
     --stage 1
......
@@ -107,10 +107,15 @@ def process(args):
         if not cur_root.is_dir():
             print(f"{cur_root.as_posix()} does not exist. Skipped.")
             continue
+        if args.output_root is None:
+            output_root = cur_root
+        else:
+            output_root = Path(args.output_root).absolute() / f"en-{lang}"
         # Extract features
-        feature_root = cur_root / "fbank80"
+        feature_root = output_root / "fbank80"
         feature_root.mkdir(exist_ok=True)
-        zip_path = cur_root / "fbank80.zip"
+        zip_path = output_root / "fbank80.zip"
         if args.overwrite or not Path.exists(zip_path):
             for split in MUSTC.SPLITS:
                 print(f"Fetching split {split}...")
@@ -135,7 +140,7 @@ def process(args):
             if split == 'train' and args.cmvn_type == "global":
                 # Estimate and save cmv
                 stats = cal_gcmvn_stats(gcmvn_feature_list)
-                with open(cur_root / "gcmvn.npz", "wb") as f:
+                with open(output_root / "gcmvn.npz", "wb") as f:
                     np.savez(f, mean=stats["mean"], std=stats["std"])
         # Pack features into ZIP
@@ -144,7 +149,7 @@ def process(args):
         gen_manifest_flag = False
         for split in MUSTC.SPLITS:
-            if not Path.exists(cur_root / f"{split}_{args.task}.tsv"):
+            if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
                 gen_manifest_flag = True
                 break
@@ -183,7 +188,7 @@ def process(args):
                 train_text.extend(manifest["tgt_text"])
             df = pd.DataFrame.from_dict(manifest)
             df = filter_manifest_df(df, is_train_split=is_train_split)
-            save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
+            save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
         # Generate vocab
         v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
@@ -215,7 +220,7 @@ def process(args):
                 f.write(t + "\n")
             gen_vocab(
                 Path(f.name),
-                cur_root / spm_filename_prefix,
+                output_root / spm_filename_prefix,
                 args.vocab_type,
                 args.vocab_size,
             )
@@ -225,13 +230,13 @@ def process(args):
             yaml_filename = f"config_{args.task}_share.yaml"
         gen_config_yaml(
-            cur_root,
+            output_root,
             spm_filename_prefix + ".model",
             yaml_filename=yaml_filename,
             specaugment_policy="lb",
             cmvn_type=args.cmvn_type,
             gcmvn_path=(
-                cur_root / "gcmvn.npz" if args.cmvn_type == "global"
+                output_root / "gcmvn.npz" if args.cmvn_type == "global"
                 else None
             ),
             asr_spm_filename=asr_spm_filename,
@@ -245,12 +250,17 @@ def process_joint(args):
     cur_root = Path(args.data_root)
     assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \
         "do not have downloaded data available for all 8 languages"
+    if args.output_root is None:
+        output_root = cur_root
+    else:
+        output_root = Path(args.output_root).absolute()
     # Generate vocab
     vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
     spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
     with NamedTemporaryFile(mode="w") as f:
         for lang in MUSTC.LANGUAGES:
-            tsv_path = cur_root / f"en-{lang}" / f"train_{args.task}.tsv"
+            tsv_path = output_root / f"en-{lang}" / f"train_{args.task}.tsv"
             df = load_df_from_tsv(tsv_path)
             for t in df["tgt_text"]:
                 f.write(t + "\n")
@@ -259,14 +269,14 @@ def process_joint(args):
             special_symbols = [f'<lang:{lang}>' for lang in MUSTC.LANGUAGES]
         gen_vocab(
             Path(f.name),
-            cur_root / spm_filename_prefix,
+            output_root / spm_filename_prefix,
             args.vocab_type,
             args.vocab_size,
             special_symbols=special_symbols
         )
     # Generate config YAML
     gen_config_yaml(
-        cur_root,
+        output_root,
         spm_filename_prefix + ".model",
         yaml_filename=f"config_{args.task}.yaml",
         specaugment_policy="ld",
@@ -275,8 +285,8 @@ def process_joint(args):
     # Make symbolic links to manifests
     for lang in MUSTC.LANGUAGES:
         for split in MUSTC.SPLITS:
-            src_path = cur_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
-            desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
+            src_path = output_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
+            desc_path = output_root / f"{split}_{lang}_{args.task}.tsv"
             if not desc_path.is_symlink():
                 os.symlink(src_path, desc_path)
@@ -284,6 +294,7 @@ def process_joint(args):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--data-root", "-d", required=True, type=str)
+    parser.add_argument("--output-root", "-o", default=None, type=str)
     parser.add_argument(
         "--vocab-type",
         default="unigram",
......
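Taken together, prep_mustc_data.py keeps its old in-place behavior when --output-root is omitted (output_root falls back to cur_root) and otherwise mirrors the per-pair en-{lang} layout under the new root. A sketch of the layout the shell recipes then rely on, assuming dataset=mustc and the German pair; file names are inferred from the prefixes in the script and are illustrative only:

    # ~/st/data/mustc/asr/en-de/          (hypothetical output root for one pair)
    #     fbank80.zip                     packed filterbank features
    #     train_asr.tsv dev_asr.tsv tst-COMMON_asr.tsv    manifests per split
    #     spm_unigram5000_asr.model       SentencePiece model from gen_vocab
    #     config_asr.yaml                 generated by gen_config_yaml

    # hence the new line in both recipes, which presumably holds a pair name like en-de:
    data_dir=${data_dir}/${lang}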