Commit 9d4fe566 by xuchen

optimize the shell for joint decoding

parent f5fb7d7d
...@@ -5,20 +5,20 @@ ...@@ -5,20 +5,20 @@
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Python: Remote Connection", "name": "Python: Remote Attach",
"type": "python", "type": "python",
"request": "attach", "request": "attach",
"listen": { "connect": {
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 5678 "port": 5678
}, },
"preLaunchTask": "Launch",
"pathMappings": [ "pathMappings": [
{ {
"localRoot": "${workspaceFolder}", "localRoot": "${workspaceFolder}",
"remoteRoot": "." "remoteRoot": "."
} }
] ],
} "justMyCode": true
},
] ]
} }
\ No newline at end of file
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// 原理见文档 https://bytedance.feishu.cn/docx/doxcnZQgAWroFtQrvb1eB2jB3Lf // 原理见文档 https://bytedance.feishu.cn/docx/doxcnZQgAWroFtQrvb1eB2jB3Lf
// MuST-C ST train // MuST-C ST train
"command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st_tok --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/at12_big.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/conformer.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/ttr.yaml --fp16 --train-subset train_ttr_all_kd --valid-subset dev_ttr_all_kd --ctc-weight 0.2 --xctc-weight 0.2 --axctc-weight 0 --inter-ctc-weight 0.1 --inter-ctc-layers 6,9 --inter-xctc-weight 0.1 --inter-xctc-layers 9 --inter-axctc-weight 0.2 --inter-axctc-layers 6 --ctc-pae inter_league --adapter inter_league --ctc-pae-ground-truth-ratio 0.3 --xctc-pae-ground-truth-ratio 0.3 --xctc-pae-ground-truth-ratio-decay 50:100:0 --only-train-enc-prob 0.3 --ctc-masked-loss", // "command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st_tok --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/at12_big.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/conformer.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/ttr.yaml --fp16 --train-subset train_ttr_all_kd --valid-subset dev_ttr_all_kd --ctc-weight 0.2 --xctc-weight 0.2 --axctc-weight 0 --inter-ctc-weight 0.1 --inter-ctc-layers 6,9 --inter-xctc-weight 0.1 --inter-xctc-layers 9 --inter-axctc-weight 0.2 --inter-axctc-layers 6 --ctc-pae inter_league --adapter inter_league --ctc-pae-ground-truth-ratio 0.3 --xctc-pae-ground-truth-ratio 0.3 --xctc-pae-ground-truth-ratio-decay 50:100:0 --only-train-enc-prob 0.3 --ctc-masked-loss",
"command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/xctc.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/ctc.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/purexctc.yaml --fp16 --inter-xctc-weight 0.1 --inter-xctc-layers 6,9 --ctc-pae inter_league --xctc-pae-ground-truth-ratio 0.3", // "command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/xctc.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/ctc.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/purexctc.yaml --fp16 --inter-xctc-weight 0.1 --inter-xctc-layers 6,9 --ctc-pae inter_league --xctc-pae-ground-truth-ratio 0.3",
// MuST-C ST train // MuST-C ST train
// "command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/base.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/mixup.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/ctc.yaml --fp16", // "command": "launch --gpu 1 -- python3 -m debugpy --connect $(hostname -i | awk '{print $2}'):5678 train.py /mnt/bd/data-model/data/must_c/en-de/st --config-yaml config_share.yaml --task speech_to_text --max-tokens 10000 --skip-invalid-size-inputs-valid-test --log-interval 1 --save-dir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --tensorboard-logdir /opt/tiger/s2t/checkpoints/must_c/en-de/st/debug --train-config /root/st/Fairseq-S2T/egs/mustc/st/conf/basis.yaml --train-config1 /root/st/Fairseq-S2T/egs/mustc/st/conf/base.yaml --train-config2 /root/st/Fairseq-S2T/egs/mustc/st/conf/mixup.yaml --train-config3 /root/st/Fairseq-S2T/egs/mustc/st/conf/ctc.yaml --fp16",
......
...@@ -10,10 +10,14 @@ if [ "$#" -eq 1 ]; then ...@@ -10,10 +10,14 @@ if [ "$#" -eq 1 ]; then
exp_name=$1 exp_name=$1
fi fi
ctc_infer=0
n_average=10 n_average=10
beam_size=5 beam_size=5
infer_ctc_weight=0.3
len_penalty=1.0 len_penalty=1.0
max_tokens=100000 max_tokens=50000
batch_size=1
infer_debug=0
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
...@@ -24,8 +28,12 @@ cmd="./run.sh ...@@ -24,8 +28,12 @@ cmd="./run.sh
--n_average ${n_average} --n_average ${n_average}
--beam_size ${beam_size} --beam_size ${beam_size}
--len_penalty ${len_penalty} --len_penalty ${len_penalty}
--batch_size ${batch_size}
--max_tokens ${max_tokens} --max_tokens ${max_tokens}
--dec_model ${dec_model} --dec_model ${dec_model}
--ctc_infer ${ctc_infer}
--infer_ctc_weight ${infer_ctc_weight}
--infer_debug ${infer_debug}
" "
if [[ -n ${data_dir} ]]; then if [[ -n ${data_dir} ]]; then
......
...@@ -15,4 +15,6 @@ ctc_infer_sort=${infer_dir}/${tag}_ctc_infer_sort ...@@ -15,4 +15,6 @@ ctc_infer_sort=${infer_dir}/${tag}_ctc_infer_sort
cut -f1 ${s2s_infer_file} > ${idx} cut -f1 ${s2s_infer_file} > ${idx}
paste ${idx} ${org_ctc_infer_file} > ${ctc_infer} paste ${idx} ${org_ctc_infer_file} > ${ctc_infer}
sort -n -t $'\t' ${ctc_infer} | cut -f2 > ${ctc_infer_sort} sort -n -t $'\t' ${ctc_infer} | cut -f2 > ${ctc_infer_sort}
python3 ./cal_wer.py ${ref} ${ctc_infer_sort} cmd="python3 ./cal_wer.py ${ref} ${ctc_infer_sort}"
\ No newline at end of file echo $cmd
eval $cmd
...@@ -21,7 +21,6 @@ for s in samples: ...@@ -21,7 +21,6 @@ for s in samples:
if extract_item in s: if extract_item in s:
fw.write("%s\n" % s[extract_item]) fw.write("%s\n" % s[extract_item])
else: else:
print("Error in sample: ") print("Error in sample: ", s, "when extract ", extract_item)
print(s)
exit() exit()
...@@ -72,6 +72,7 @@ step_valid=0 ...@@ -72,6 +72,7 @@ step_valid=0
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cer=0 cer=0
ctc_infer=0 ctc_infer=0
infer_ctc_weight=0
ctc_self_ensemble=0 ctc_self_ensemble=0
ctc_inter_logit=0 ctc_inter_logit=0
n_average=10 n_average=10
...@@ -80,6 +81,7 @@ len_penalty=1.0 ...@@ -80,6 +81,7 @@ len_penalty=1.0
single=0 single=0
epoch_ensemble=0 epoch_ensemble=0
best_ensemble=1 best_ensemble=1
infer_debug=0
infer_score=0 infer_score=0
# infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy" # infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
...@@ -337,7 +339,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -337,7 +339,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
for dec_model in ${dec_models[@]}; do for dec_model in ${dec_models[@]}; do
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens} suffix=alpha${len_penalty}
model_str=`echo $dec_model | sed -e "s#checkpoint##" | sed "s#.pt##"` model_str=`echo $dec_model | sed -e "s#checkpoint##" | sed "s#.pt##"`
suffix=${suffix}_${model_str} suffix=${suffix}_${model_str}
if [[ -n ${cer} && ${cer} -eq 1 ]]; then if [[ -n ${cer} && ${cer} -eq 1 ]]; then
...@@ -345,6 +347,13 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -345,6 +347,13 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
else else
suffix=${suffix}_wer suffix=${suffix}_wer
fi fi
suffix=${suffix}_beam${beam_size}
if [[ ${batch_size} -ne 0 ]]; then
suffix=${suffix}_batch${batch_size}
else
suffix=${suffix}_tokens${max_tokens}
fi
if [[ ${ctc_infer} -eq 1 ]]; then if [[ ${ctc_infer} -eq 1 ]]; then
suffix=${suffix}_ctc suffix=${suffix}_ctc
fi fi
...@@ -354,6 +363,12 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -354,6 +363,12 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ ${ctc_inter_logit} -ne 0 ]]; then if [[ ${ctc_inter_logit} -ne 0 ]]; then
suffix=${suffix}_logit${ctc_inter_logit} suffix=${suffix}_logit${ctc_inter_logit}
fi fi
if (( $(echo "${infer_ctc_weight} > 0" | bc -l) )); then
suffix=${suffix}_ctc${infer_ctc_weight}
fi
if [[ ${infer_score} -eq 1 ]]; then
suffix=${suffix}_score
fi
suffix=`echo $suffix | sed -e "s#__#_#"` suffix=`echo $suffix | sed -e "s#__#_#"`
result_file=${model_dir}/decode_result_${suffix} result_file=${model_dir}/decode_result_${suffix}
...@@ -362,16 +377,23 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -362,16 +377,23 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
test_subset=${test_subset//,/ } test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset} subset=${subset}
cmd="python3 ${code_dir}/fairseq_cli/generate.py if [[ ${infer_debug} -ne 0 ]]; then
cmd="python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client"
else
cmd="python3 "
fi
cmd="$cmd ${code_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
--gen-subset ${subset} --gen-subset ${subset}
--task speech_to_text --task speech_to_text
--path ${model_dir}/${dec_model} --path ${model_dir}/${dec_model}
--results-path ${model_dir} --results-path ${model_dir}
--batch-size ${batch_size}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--beam ${beam_size} --beam ${beam_size}
--lenpen ${len_penalty} --lenpen ${len_penalty}
--infer-ctc-weight ${infer_ctc_weight}
--scoring wer" --scoring wer"
if [[ ${cer} -eq 1 ]]; then if [[ ${cer} -eq 1 ]]; then
...@@ -390,10 +412,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -390,10 +412,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmd="${cmd} cmd="${cmd}
--ctc-inter-logit ${ctc_inter_logit}" --ctc-inter-logit ${ctc_inter_logit}"
fi fi
if [[ ${infer_score} -eq 1 ]]; then
cmd="${cmd}
--score-reference"
fi
if [[ -n ${infer_parameters} ]]; then if [[ -n ${infer_parameters} ]]; then
cmd="${cmd} cmd="${cmd}
${infer_parameters}" ${infer_parameters}"
fi fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
...@@ -422,12 +448,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -422,12 +448,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ ${ctc_infer} -eq 1 && -f ${model_dir}/${ctc_file} ]]; then if [[ ${ctc_infer} -eq 1 && -f ${model_dir}/${ctc_file} ]]; then
ref_file=${model_dir}/${subset}.${src_lang} ref_file=${model_dir}/${subset}.${src_lang}
if [[ ! -f ${ref_file} ]]; then if [[ ! -f ${ref_file} ]]; then
python3 ./local/extract_txt_from_tsv.py ${data_dir}/${subset}.tsv ${ref_file} "src_text" python3 ./local/extract_txt_from_tsv.py ${data_dir}/${subset}.tsv ${ref_file} "tgt_text"
fi fi
if [[ -f ${ref_file} ]]; then if [[ -f ${ref_file} ]]; then
ctc=$(mktemp -t temp.record.XXXXXX) ctc=$(mktemp -t temp.record.XXXXXX)
cd ./local cd ./local
./cal_wer.sh ${model_dir} ${subset} ${trans_file} ${ctc_file} ${ref_file} > ${ctc} cmd="./cal_wer.sh ${model_dir} ${subset} ${trans_file} ${ctc_file} ${ref_file} > ${ctc}"
#echo $cmd
eval $cmd
cd .. cd ..
echo "CTC WER" >> ${result_file} echo "CTC WER" >> ${result_file}
......
...@@ -11,11 +11,14 @@ if [ "$#" -eq 1 ]; then ...@@ -11,11 +11,14 @@ if [ "$#" -eq 1 ]; then
fi fi
sacrebleu=1 sacrebleu=1
ctc_infer=0 ctc_infer=1
n_average=10 n_average=10
beam_size=5 beam_size=5
infer_ctc_weight=0.1
len_penalty=1.0 len_penalty=1.0
max_tokens=50000 max_tokens=50000
batch_size=1
infer_debug=0
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
...@@ -24,12 +27,15 @@ cmd="./run.sh ...@@ -24,12 +27,15 @@ cmd="./run.sh
--gpu_num ${gpu_num} --gpu_num ${gpu_num}
--exp_name ${exp_name} --exp_name ${exp_name}
--sacrebleu ${sacrebleu} --sacrebleu ${sacrebleu}
--ctc_infer ${ctc_infer}
--n_average ${n_average} --n_average ${n_average}
--beam_size ${beam_size} --beam_size ${beam_size}
--len_penalty ${len_penalty} --len_penalty ${len_penalty}
--batch_size ${batch_size}
--max_tokens ${max_tokens} --max_tokens ${max_tokens}
--dec_model ${dec_model} --dec_model ${dec_model}
--ctc_infer ${ctc_infer}
--infer_ctc_weight ${infer_ctc_weight}
--infer_debug ${infer_debug}
" "
if [[ -n ${data_dir} ]]; then if [[ -n ${data_dir} ]]; then
......
dir=/xuchen/st/checkpoints/must_c/en-de/st/JointCTC/big
tag=JointCTC/big
for d in `ls $dir`; do
echo $d
./run.sh --stage 2 --max_tokens 10000 --batch_size 1 --ctc_infer 1 --infer_ctc_weight 0.1 --exp_name $tag/$d
./run.sh --stage 2 --max_tokens 10000 --batch_size 1 --ctc_infer 1 --infer_ctc_weight 0.2 --exp_name $tag/$d
./run.sh --stage 2 --max_tokens 10000 --batch_size 1 --ctc_infer 1 --infer_ctc_weight 0.3 --exp_name $tag/$d
./run.sh --stage 2 --max_tokens 10000 --batch_size 1 --ctc_infer 1 --infer_ctc_weight 0.4 --exp_name $tag/$d
./run.sh --stage 2 --max_tokens 10000 --batch_size 1 --ctc_infer 1 --infer_ctc_weight 0.5 --exp_name $tag/$d
done
\ No newline at end of file
...@@ -82,9 +82,11 @@ bleu_valid=0 ...@@ -82,9 +82,11 @@ bleu_valid=0
sacrebleu=1 sacrebleu=1
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
ctc_infer=0 ctc_infer=0
infer_ctc_weight=0
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
infer_debug=0
infer_score=0 infer_score=0
# infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy" # infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
...@@ -401,34 +403,60 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -401,34 +403,60 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
fi fi
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens} suffix=alpha${len_penalty}
if [[ ${n_average} -ne 1 ]]; then model_str=`echo $dec_model | sed -e "s#checkpoint##" | sed "s#.pt##"`
suffix=${suffix}_${n_average} suffix=${suffix}_${model_str}
fi
if [[ ${sacrebleu} -eq 1 ]]; then if [[ ${sacrebleu} -eq 1 ]]; then
suffix=${suffix}_sacrebleu suffix=${suffix}_sacrebleu
else else
suffix=${suffix}_multibleu suffix=${suffix}_multibleu
fi fi
suffix=${suffix}_beam${beam_size}
if [[ ${batch_size} -ne 0 ]]; then
suffix=${suffix}_batch${batch_size}
else
suffix=${suffix}_tokens${max_tokens}
fi
if [[ ${ctc_infer} -eq 1 ]]; then
suffix=${suffix}_ctc
fi
if [[ ${ctc_self_ensemble} -eq 1 ]]; then
suffix=${suffix}_ensemble
fi
if [[ ${ctc_inter_logit} -ne 0 ]]; then
suffix=${suffix}_logit${ctc_inter_logit}
fi
if (( $(echo "${infer_ctc_weight} > 0" | bc -l) )); then
suffix=${suffix}_ctc${infer_ctc_weight}
fi
if [[ ${infer_score} -eq 1 ]]; then if [[ ${infer_score} -eq 1 ]]; then
suffix=${suffix}_score suffix=${suffix}_score
fi fi
suffix=`echo $suffix | sed -e "s#__#_#"`
result_file=${model_dir}/decode_result_${suffix} result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ } test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset} subset=${subset}
cmd="python3 ${code_dir}/fairseq_cli/generate.py if [[ ${infer_debug} -ne 0 ]]; then
cmd="python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client"
else
cmd="python3 "
fi
cmd="$cmd ${code_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
--gen-subset ${subset} --gen-subset ${subset}
--task speech_to_text --task speech_to_text
--path ${model_dir}/${dec_model} --path ${model_dir}/${dec_model}
--results-path ${model_dir} --results-path ${model_dir}
--batch-size ${batch_size}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--beam ${beam_size} --beam ${beam_size}
--skip-invalid-size-inputs-valid-test --skip-invalid-size-inputs-valid-test
--infer-ctc-weight ${infer_ctc_weight}
--lenpen ${len_penalty}" --lenpen ${len_penalty}"
if [[ ${ctc_infer} -eq 1 ]]; then if [[ ${ctc_infer} -eq 1 ]]; then
...@@ -452,6 +480,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -452,6 +480,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--target-lang ${tgt_lang}" --target-lang ${tgt_lang}"
fi fi
fi fi
if [[ ${ctc_self_ensemble} -eq 1 ]]; then
cmd="${cmd}
--ctc-self-ensemble"
fi
if [[ ${ctc_inter_logit} -ne 0 ]]; then
cmd="${cmd}
--ctc-inter-logit ${ctc_inter_logit}"
fi
if [[ ${infer_score} -eq 1 ]]; then if [[ ${infer_score} -eq 1 ]]; then
cmd="${cmd} cmd="${cmd}
--score-reference" --score-reference"
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
THIS_DIR="$( cd "$( dirname "$0" )" && pwd )" THIS_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd ${THIS_DIR} cd ${THIS_DIR}
export ST_ROOT=/xuchen/st
export NCCL_DEBUG=INFO
echo "nameserver 114.114.114.114" >> /etc/resolv.conf
pip3 install espnet -i https://pypi.tuna.tsinghua.edu.cn/simple
if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then
echo "default stage: env configure" echo "default stage: env configure"
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
THIS_DIR="$( cd "$( dirname "$0" )" && pwd )" THIS_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd ${THIS_DIR} cd ${THIS_DIR}
export ST_ROOT=/xuchen/st
export NCCL_DEBUG=INFO
echo "nameserver 114.114.114.114" >> /etc/resolv.conf
pip3 install espnet -i https://pypi.tuna.tsinghua.edu.cn/simple
if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then
echo "default stage: env configure" echo "default stage: env configure"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论