Commit 6cbfe851 by xuchen

implement the mixup method for speech-to-text

parent 67d8695f
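Note: in brief, this commit wires a mixup step into the PDS speech encoder. During training, the utterances in a batch are randomly paired, their hidden states are interpolated with a coefficient drawn from a symmetric Beta(β, β) distribution, and every loss in play (label-smoothed cross-entropy, CTC, intermediate CTC, target CTC) is evaluated against both partners' targets and blended with that same coefficient. Below is a condensed, self-contained sketch of the interpolation, mirroring the apply_mixup method added in the encoder diff further down (the function name and docstring here are illustrative, not part of the commit):

```python
import numpy as np
import torch
from torch.distributions import Beta

def mixup_sketch(x, encoder_padding_mask, beta=0.5):
    """x: (T, B, C) encoder states; encoder_padding_mask: (B, T) bool."""
    batch = x.size(1)
    indices = np.random.permutation(batch)
    if len(indices) % 2 != 0:
        # duplicate the last item so every utterance has a partner
        indices = np.append(indices, indices[-1])
    idx1 = torch.from_numpy(indices[0::2]).to(x.device)
    idx2 = torch.from_numpy(indices[1::2]).to(x.device)

    # one coefficient for the whole batch, from a symmetric Beta distribution
    coef = Beta(torch.tensor([beta]), torch.tensor([beta])).sample().to(x.device).type_as(x)
    x = coef * x[:, idx1] + (1 - coef) * x[:, idx2]

    # a mixed frame counts as padding if either source frame was padding
    encoder_padding_mask = encoder_padding_mask[idx1] | encoder_padding_mask[idx2]
    input_lengths = (~encoder_padding_mask).sum(-1)

    # the criteria later weight each partner's loss by coef and (1 - coef)
    mixup = {"coef": coef, "index1": idx1, "index2": idx2}
    return x, encoder_padding_mask, input_lengths, mixup
```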
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 18
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
@@ -38,18 +38,18 @@ task=speech_to_text
 vocab_type=unigram
 vocab_size=5000
 speed_perturb=0
-lcrm=0
+lcrm=1
 tokenizer=0
 use_raw_audio=0
-use_specific_dict=0
-specific_prefix=st
-specific_dir=${root_dir}/data/mustc/st
-asr_vocab_prefix=spm_unigram10000_st_share
+use_specific_dict=1
+specific_prefix=unified
+specific_dir=${root_dir}/data/iwslt2022/vocab
+asr_vocab_prefix=spm_en

 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/asr
-train_split=train
+train_split=train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted
 valid_split=dev
 test_split=tst-COMMON
 test_subset=tst-COMMON
...
 #! /bin/bash
-# Processing MuST-C Datasets
+# Processing IWSLT 2022 Datasets

 # Copyright 2021 Natural Language Processing Laboratory
 # Xu Chen (xuchenneu@163.com)
@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir} ]]; then
         mkdir -p ${data_dir}
     fi
+    if [[ ! -e ${data_dir}/data ]]; then
+        mkdir -p ${data_dir}/data
+    fi

     if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
         if [[ ${use_specific_dict} -eq 0 ]]; then
@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                 --tgt-vocab-type ${tgt_vocab_type}
                 --src-vocab-size ${src_vocab_size}
                 --tgt-vocab-size ${tgt_vocab_size}"
-            if [[ $share_dict -eq 1 ]]; then
-                cmd="$cmd
-                --share"
-            fi
-            echo -e "\033[34mRun command: \n${cmd} \033[0m"
-            [[ $eval -eq 1 ]] && eval ${cmd}
         else
             cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
             cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
+
+            cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+                --data-root ${org_data_dir}
+                --output-root ${data_dir}
+                --splits ${train_subset},${valid_subset},${trans_subset}
+                --src-lang ${src_lang}
+                --tgt-lang ${tgt_lang}
+                --src-vocab-prefix ${src_vocab_prefix}
+                --tgt-vocab-prefix ${tgt_vocab_prefix}"
         fi
-    fi
-    mkdir -p ${data_dir}/data
-    for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-    {
-        if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            text_dir=${org_data_dir}/data/${split}/txt
-        else
-            text_dir=${org_data_dir}/data/${split}
-        fi
-        src_text=${text_dir}/${split}.${src_lang}
-        tgt_text=${text_dir}/${split}.${tgt_lang}
-        cmd="cat ${src_text}"
-        if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${src_text}"
-        fi
-        cmd="${cmd}
-        | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
-        --output_format=piece
-        > ${data_dir}/data/${split}.${src_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-        cmd="spm_encode
-        --model ${data_dir}/${tgt_vocab_prefix}.model
-        --output_format=piece
-        < ${tgt_text}
-        > ${data_dir}/data/${split}.${tgt_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-    }&
-    done
-    wait
+        if [[ $share_dict -eq 1 ]]; then
+            cmd="$cmd
+            --share"
+        fi
+        if [[ ${lcrm} -eq 1 ]]; then
+            cmd="$cmd
+            --lowercase-src
+            --rm-punc-src"
+        fi
+        echo -e "\033[34mRun command: \n${cmd} \033[0m"
+        [[ $eval -eq 1 ]] && eval ${cmd}
    fi

 cmd="python ${code_dir}/fairseq_cli/preprocess.py
     --source-lang ${src_lang} --target-lang ${tgt_lang}
...
-arch: s2t_sate
+arch: pdss2t_transformer_s_8
+
+#pds-ctc: 0_1_1_0
+#intermedia-adapter: league
+#intermedia-ctc-weight: 0.1
+#encoder-attention-type: reduced
+#pds-attn-ds-ratios: 4_2_2_1
+#attention-reduced-method: pool
+#attention-reduced-q: True
+
+inter-mixup: True
+inter-mixup-layer: 0
+inter-mixup-beta: 0.5
+
+encoder-embed-dim: 384
+pds-stages: 4
+#ctc-layer: 15
+encoder-layers: 6
+pds-layers: 2_1_1_2
+pds-ratios: 2_2_1_2
+pds-fusion: True
+pds-fusion-method: all_conv
+pds-embed-dims: 192_256_256_384
+pds-ds-method: conv
+pds-embed-norm: True
+pds-position-embed: 1_1_1_1
+pds-kernel-sizes: 5_5_5_5
+pds-ffn-ratios: 8_8_8_4
+pds-attn-heads: 4_4_4_6
+
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -9,39 +39,16 @@ lr: 2e-3
 adam_betas: (0.9,0.98)
 ctc-weight: 0.3
-target-ctc-weight: 0.2
-target-ctc-layers: 3,6
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-subsampling-type: conv1d
-subsampling-layers: 2
-subsampling-filter: 1024
-subsampling-kernel: 5
-subsampling-stride: 2
-subsampling-norm: none
-subsampling-activation: glu
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
-encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-text-encoder-layers: 6
-decoder-layers: 6
-encoder-attention-heads: 4
+decoder-layers: 6
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-acoustic-encoder: transformer
-adapter: league
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
...
@@ -12,7 +12,7 @@ arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pds-stages: 4
 #ctc-layer: 12
-pds-layers: 3_3_3_3
+pds-layers: 4_2_2_4
 pds-ratios: 2_2_1_2
 pds-fusion: True
 pds-fusion-method: all_conv
...
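Note: the three inter-mixup keys in the YAML above are all a training config needs: inter-mixup switches the feature on, inter-mixup-layer picks the encoder layer index at which hidden states are interpolated, and inter-mixup-beta sets the Beta parameter. At 0.5 the Beta(0.5, 0.5) distribution is U-shaped, so sampled coefficients tend to sit near 0 or 1, keeping most mixed batches close to one of the two partners. A fourth option, inter-mixup-prob, registered in the model code further down, gates how often a batch is actually mixed.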
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 384
pds-stages: 4
#ctc-layer: 15
encoder-layers: 12
pds-layers: 4_3_3_2
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_384
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_4
pds-attn-heads: 4_4_4_6
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 18
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 200
pds-stages: 3
pds-layers: 4_4_4
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 200_200_200
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 12
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 15
#encoder-activation-fn: swish
#encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
@@ -6,27 +6,27 @@ encoder-type: pds
 #intermedia-ctc-weight: 1
 #intermedia-temperature: 5
-encoder-attention-type: rel_pos
+#encoder-attention-type: rel_pos
 #encoder-attention-type: reduced_rel_pos
 #pds-attn-ds-ratios: 4_2_2_1
 #attention-reduced-method: pool
 #attention-reduced-q: True

-encoder-embed-dim: 512
+encoder-embed-dim: 384
 pds-stages: 4
 #ctc-layer: 15
-encoder-layers: 10
-pds-layers: 3_2_2_3
+encoder-layers: 12
+pds-layers: 4_3_3_2
 pds-ratios: 2_2_1_2
 pds-fusion: True
 pds-fusion-method: all_conv
-pds-embed-dims: 256_384_384_512
+pds-embed-dims: 128_256_256_384
 pds-ds-method: conv
 pds-embed-norm: True
 pds-position-embed: 1_1_1_1
 pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_4_4_4
-pds-attn-heads: 4_6_6_8
+pds-ffn-ratios: 8_8_8_4
+pds-attn-heads: 4_4_4_8

 optimizer: adam
 clip-norm: 10.0
@@ -42,9 +42,4 @@ post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-encoder-activation-fn: swish
 #load-pretrained-encoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
pds-layers: 4_4_4
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 12
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 15
#encoder-activation-fn: swish
#encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
@@ -35,15 +35,17 @@ lang=${src_lang}-${tgt_lang}
 dataset=mustc
 task=translation
-vocab_type=unigram
-vocab_size=10000
+src_vocab_type=unigram
+tgt_vocab_type=unigram
+src_vocab_size=10000
+tgt_vocab_size=10000
 share_dict=1
 lcrm=0
 tokenizer=0

 use_specific_dict=1
 specific_prefix=st
-specific_dir=${root_dir}/data/mustc/st
+specific_dir=${root_dir}/data/${dataset}/st
 src_vocab_prefix=spm_unigram10000_st_share
 tgt_vocab_prefix=spm_unigram10000_st_share
@@ -80,17 +82,24 @@ len_penalty=1.0
 if [[ ${use_specific_dict} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_${specific_prefix}
     data_dir=${data_dir}/${specific_prefix}
+    mkdir -p ${data_dir}
 else
-    if [[ "${vocab_type}" == "char" ]]; then
-        vocab_name=${vocab_type}
-        exp_prefix=${exp_prefix}_${vocab_type}
+    if [[ "${tgt_vocab_type}" == "char" ]]; then
+        vocab_name=char
+        exp_prefix=${exp_prefix}_char
     else
-        vocab_name=${vocab_type}${vocab_size}
+        if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" -ne "${tgt_vocab_type}" ]]; then
+            src_vocab_name=${src_vocab_type}${src_vocab_size}
+            tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            vocab_name=${src_vocab_name}_${tgt_vocab_name}
+        else
+            vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            src_vocab_name=${vocab_name}
+            tgt_vocab_name=${vocab_name}
+        fi
     fi
     data_dir=${data_dir}/${vocab_name}
-    src_vocab_prefix=spm_${vocab_name}_${src_lang}
-    tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
+    src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
+    tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
     if [[ $share_dict -eq 1 ]]; then
         data_dir=${data_dir}_share
         src_vocab_prefix=spm_${vocab_name}_share
@@ -132,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir} ]]; then
         mkdir -p ${data_dir}
     fi
+    if [[ ! -e ${data_dir}/data ]]; then
+        mkdir -p ${data_dir}/data
+    fi

     if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
         if [[ ${use_specific_dict} -eq 0 ]]; then
@@ -141,51 +153,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                 --splits ${train_subset},${valid_subset},${trans_subset}
                 --src-lang ${src_lang}
                 --tgt-lang ${tgt_lang}
-                --vocab-type ${vocab_type}
-                --vocab-size ${vocab_size}"
-            if [[ $share_dict -eq 1 ]]; then
-                cmd="$cmd
-                --share"
-            fi
-            echo -e "\033[34mRun command: \n${cmd} \033[0m"
-            [[ $eval -eq 1 ]] && eval ${cmd}
+                --src-vocab-type ${src_vocab_type}
+                --tgt-vocab-type ${tgt_vocab_type}
+                --src-vocab-size ${src_vocab_size}
+                --tgt-vocab-size ${tgt_vocab_size}"
         else
             cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
             cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
+
+            cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+                --data-root ${org_data_dir}
+                --output-root ${data_dir}
+                --splits ${train_subset},${valid_subset},${trans_subset}
+                --src-lang ${src_lang}
+                --tgt-lang ${tgt_lang}
+                --src-vocab-prefix ${src_vocab_prefix}
+                --tgt-vocab-prefix ${tgt_vocab_prefix}"
         fi
-    fi
-    mkdir -p ${data_dir}/data
-    for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-    {
-        if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            txt_dir=${org_data_dir}/data/${split}/txt
-        else
-            txt_dir=${org_data_dir}/data/${split}
-        fi
-        cmd="cat ${txt_dir}/${split}.${src_lang}"
-        if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${org_data_dir}/data/${split}.${src_lang}"
-        fi
-        cmd="${cmd}
-        | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
-        --output_format=piece
-        > ${data_dir}/data/${split}.${src_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-        cmd="spm_encode
-        --model ${data_dir}/${tgt_vocab_prefix}.model
-        --output_format=piece
-        < ${txt_dir}/${split}.${tgt_lang}
-        > ${data_dir}/data/${split}.${tgt_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-    }&
-    done
-    wait
+        if [[ $share_dict -eq 1 ]]; then
+            cmd="$cmd
+            --share"
+        fi
+        if [[ ${lcrm} -eq 1 ]]; then
+            cmd="$cmd
+            --lowercase-src
+            --rm-punc-src"
+        fi
+        echo -e "\033[34mRun command: \n${cmd} \033[0m"
+        [[ $eval -eq 1 ]] && eval ${cmd}
    fi

 cmd="python ${code_dir}/fairseq_cli/preprocess.py
     --source-lang ${src_lang} --target-lang ${tgt_lang}
@@ -317,11 +313,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         mv tmp.log $log
         export CUDA_VISIBLE_DEVICES=${device}

-        cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
+        log=${model_dir}/train.log
+        cmd="nohup ${cmd} >> ${log} 2>&1 &"
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             sleep 2s
-            tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
+            tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
         fi
     fi
     wait
...
@@ -35,8 +35,10 @@ lang=${src_lang}-${tgt_lang}
 dataset=wmt16.en-de
 task=translation
-vocab_type=unigram
-vocab_size=32000
+src_vocab_type=unigram
+tgt_vocab_type=unigram
+src_vocab_size=10000
+tgt_vocab_size=10000
 share_dict=1
 lcrm=0
 tokenizer=1
@@ -81,17 +83,24 @@ len_penalty=1.0
 if [[ ${use_specific_dict} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_${specific_prefix}
     data_dir=${data_dir}/${specific_prefix}
+    mkdir -p ${data_dir}
 else
-    if [[ "${vocab_type}" == "char" ]]; then
-        vocab_name=${vocab_type}
-        exp_prefix=${exp_prefix}_${vocab_type}
+    if [[ "${tgt_vocab_type}" == "char" ]]; then
+        vocab_name=char
+        exp_prefix=${exp_prefix}_char
     else
-        vocab_name=${vocab_type}${vocab_size}
+        if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" -ne "${tgt_vocab_type}" ]]; then
+            src_vocab_name=${src_vocab_type}${src_vocab_size}
+            tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            vocab_name=${src_vocab_name}_${tgt_vocab_name}
+        else
+            vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            src_vocab_name=${vocab_name}
+            tgt_vocab_name=${vocab_name}
+        fi
     fi
     data_dir=${data_dir}/${vocab_name}
-    src_vocab_prefix=spm_${vocab_name}_${src_lang}
-    tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
+    src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
+    tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
     if [[ $share_dict -eq 1 ]]; then
         data_dir=${data_dir}_share
         src_vocab_prefix=spm_${vocab_name}_share
@@ -103,6 +112,9 @@ if [[ ${lcrm} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_lcrm
 fi
 if [[ ${tokenizer} -eq 1 ]]; then
+    train_subset=${train_subset}.tok
+    valid_subset=${valid_subset}.tok
+    trans_subset=${trans_subset}.tok
     data_dir=${data_dir}_tok
     exp_prefix=${exp_prefix}_tok
 fi
@@ -130,6 +142,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir} ]]; then
         mkdir -p ${data_dir}
     fi
+    if [[ ! -e ${data_dir}/data ]]; then
+        mkdir -p ${data_dir}/data
+    fi

     if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
         if [[ ${use_specific_dict} -eq 0 ]]; then
@@ -139,62 +154,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                 --splits ${train_subset},${valid_subset},${trans_subset}
                 --src-lang ${src_lang}
                 --tgt-lang ${tgt_lang}
-                --vocab-type ${vocab_type}
-                --vocab-size ${vocab_size}"
-            if [[ $share_dict -eq 1 ]]; then
-                cmd="$cmd
-                --share"
-            fi
-            if [[ ${tokenizer} -eq 1 ]]; then
-                cmd="$cmd
-                --tokenizer"
-            fi
-            echo -e "\033[34mRun command: \n${cmd} \033[0m"
-            [[ $eval -eq 1 ]] && eval ${cmd}
+                --src-vocab-type ${src_vocab_type}
+                --tgt-vocab-type ${tgt_vocab_type}
+                --src-vocab-size ${src_vocab_size}
+                --tgt-vocab-size ${tgt_vocab_size}"
         else
             cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
             cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
+
+            cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+                --data-root ${org_data_dir}
+                --output-root ${data_dir}
+                --splits ${train_subset},${valid_subset},${trans_subset}
+                --src-lang ${src_lang}
+                --tgt-lang ${tgt_lang}
+                --src-vocab-prefix ${src_vocab_prefix}
+                --tgt-vocab-prefix ${tgt_vocab_prefix}"
         fi
-    fi
-    mkdir -p ${data_dir}/data
-    for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-    {
-        if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            text_dir=${org_data_dir}/data/${split}/txt
-        else
-            text_dir=${org_data_dir}/data/${split}
-        fi
-        src_text=${text_dir}/${split}.${src_lang}
-        tgt_text=${text_dir}/${split}.${tgt_lang}
-        if [[ ${tokenizer} -eq 1 ]]; then
-            src_text=${text_dir}/${split}.tok.${src_lang}
-            tgt_text=${text_dir}/${split}.tok.${tgt_lang}
-        fi
-        cmd="cat ${src_text}"
-        if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${src_text}"
-        fi
-        cmd="${cmd}
-        | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
-        --output_format=piece
-        > ${data_dir}/data/${split}.${src_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-        cmd="spm_encode
-        --model ${data_dir}/${tgt_vocab_prefix}.model
-        --output_format=piece
-        < ${tgt_text}
-        > ${data_dir}/data/${split}.${tgt_lang}"
-        echo -e "\033[34mRun command: \n${cmd} \033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-    }&
-    done
-    wait
+        if [[ $share_dict -eq 1 ]]; then
+            cmd="$cmd
+            --share"
+        fi
+        if [[ ${lcrm} -eq 1 ]]; then
+            cmd="$cmd
+            --lowercase-src
+            --rm-punc-src"
+        fi
+        echo -e "\033[34mRun command: \n${cmd} \033[0m"
+        [[ $eval -eq 1 ]] && eval ${cmd}
    fi

 cmd="python ${code_dir}/fairseq_cli/preprocess.py
     --source-lang ${src_lang} --target-lang ${tgt_lang}
@@ -327,16 +315,14 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         export CUDA_VISIBLE_DEVICES=${device}
         log=${model_dir}/train.log
         cmd="nohup ${cmd} >> ${log} 2>&1 &"
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             sleep 2s
             tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
         fi
-        wait
-        echo -e " >> finish training \n"
     fi
+    wait

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: MT Decoding"
...
@@ -45,7 +45,7 @@ tokenizer=1
 use_specific_dict=1
 specific_prefix=unified
-specific_dir=${root_dir}/data/wmt20/vocab
+specific_dir=${root_dir}/data/${dataset}/vocab
 src_vocab_prefix=spm_en
 tgt_vocab_prefix=spm_zh
...
@@ -148,19 +148,50 @@ class CtcCriterion(FairseqCriterion):
             loss, logging_output = self.compute_ctc_loss(model, sample, net_output, logging_output)
         return loss, sample_size, logging_output

+    def get_loss(self, lprobs, targets_flat, input_lengths, transcript_lengths):
+        with torch.backends.cudnn.flags(enabled=False):
+            ctc_loss = self.ctc_loss(
+                lprobs,
+                targets_flat,
+                input_lengths,
+                transcript_lengths,
+            )
+        return ctc_loss
+
     def compute_ctc_loss(self, model, sample, net_output, logging_output):
         transcript = sample["transcript"]
-        if "ctc_padding_mask" in net_output:
-            non_padding_mask = ~net_output["ctc_padding_mask"][0]
-        else:
-            non_padding_mask = ~net_output["encoder_padding_mask"][0]
-        ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)
+        # if "ctc_padding_mask" in net_output:
+        #     non_padding_mask = ~net_output["ctc_padding_mask"][0]
+        # else:
+        #     non_padding_mask = ~net_output["encoder_padding_mask"][0]
+
+        mixup = False
+        if "mixup" in net_output and net_output["mixup"] is not None:
+            mixup = True
+            mixup_coef = net_output["mixup"]["coef"]
+            mixup_idx1 = net_output["mixup"]["index1"]
+            mixup_idx2 = net_output["mixup"]["index2"]
+
+        non_padding_mask = ~net_output["encoder_padding_mask"][0]
+        input_lengths = non_padding_mask.long().sum(-1)

         pad_mask = (transcript["tokens"] != self.pad_idx) & (
             transcript["tokens"] != self.eos_idx
         )
-        targets_flat = transcript["tokens"].masked_select(pad_mask)
-        transcript_lengths = pad_mask.sum(-1)
+        if mixup:
+            mask1 = pad_mask[mixup_idx1]
+            mask2 = pad_mask[mixup_idx2]
+            transcript_flat1 = transcript["tokens"][[mixup_idx1]].masked_select(mask1)
+            transcript_flat2 = transcript["tokens"][mixup_idx2].masked_select(mask2)
+            transcript_lengths1 = mask1.sum(-1)
+            transcript_lengths2 = mask2.sum(-1)
+            transcript_flat = [transcript_flat1, transcript_flat2]
+            transcript_lengths = [transcript_lengths1, transcript_lengths2]
+            loss_coef = [mixup_coef, 1 - mixup_coef]
+        else:
+            transcript_flat = [transcript["tokens"].masked_select(pad_mask)]
+            transcript_lengths = [pad_mask.sum(-1)]
+            loss_coef = [1]

         ctc_loss = 0
         ctc_entropy = 0
@@ -172,13 +203,9 @@ class CtcCriterion(FairseqCriterion):
             ).contiguous()  # (T, B, C) from the encoder
             lprobs.batch_first = False

-            with torch.backends.cudnn.flags(enabled=False):
-                ctc_loss = self.ctc_loss(
-                    lprobs,
-                    targets_flat,
-                    input_lengths,
-                    transcript_lengths,
-                )
+            for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
+                ctc_loss += self.get_loss(lprobs, flat, input_lengths, lengths) * coef
+
             if self.ctc_entropy > 0:
                 from torch.distributions import Categorical
                 # ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
@@ -201,23 +228,19 @@ class CtcCriterion(FairseqCriterion):
                 if type(out) == list:
                     inter_ctc_logit = out[0]
                     padding = ~out[1]
-                    input_lengths = padding.long().sum(-1)
+                    inter_input_lengths = padding.long().sum(-1)
                 else:
                     inter_ctc_logit = out
+                    inter_input_lengths = input_lengths

                 inter_lprobs = model.get_normalized_probs(
                     [inter_ctc_logit], log_probs=True
                 ).contiguous()  # (T, B, C) from the encoder
                 inter_lprobs.batch_first = False

-                with torch.backends.cudnn.flags(enabled=False):
-                    loss = self.ctc_loss(
-                        inter_lprobs,
-                        targets_flat,
-                        input_lengths,
-                        transcript_lengths,
-                    )
-                intermedia_ctc_loss += loss
+                for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
+                    intermedia_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef

             intermedia_ctc_loss /= intermedia_ctc_num
             logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
@@ -233,31 +256,40 @@ class CtcCriterion(FairseqCriterion):
         if self.target_ctc_weight > 0 and target_ctc_num > 0:
             target = sample["target"]
             pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
-            targets_flat = target.masked_select(pad_mask)
-            target_length = pad_mask.sum(-1)
+
+            if mixup:
+                mask1 = pad_mask[mixup_idx1]
+                mask2 = pad_mask[mixup_idx2]
+                target_flat1 = target.masked_select(mask1)
+                target_flat2 = target.masked_select(mask2)
+                transcript_lengths1 = mask1.sum(-1)
+                transcript_lengths2 = mask2.sum(-1)
+                target_flat = [target_flat1, target_flat2]
+                target_length = [transcript_lengths1, transcript_lengths2]
+                loss_coef = [mixup_coef, 1 - mixup_coef]
+            else:
+                target_flat = [target.masked_select(pad_mask)]
+                target_length = [pad_mask.sum(-1)]
+                loss_coef = [1]

             for i in range(target_ctc_num):
                 out = net_output["target_ctc_logits"][i]
                 if type(out) == list:
                     inter_ctc_logit = out[0]
                     padding = ~out[1]
-                    input_lengths = padding.long().sum(-1)
+                    tgt_input_lengths = padding.long().sum(-1)
                 else:
                     inter_ctc_logit = out
+                    tgt_input_lengths = input_lengths

-                inter_lprobs = model.get_normalized_probs(
+                tgt_inter_lprobs = model.get_normalized_probs(
                     [inter_ctc_logit], log_probs=True
                 ).contiguous()  # (T, B, C) from the encoder
-                inter_lprobs.batch_first = False
+                tgt_inter_lprobs.batch_first = False

-                with torch.backends.cudnn.flags(enabled=False):
-                    loss = self.ctc_loss(
-                        inter_lprobs,
-                        targets_flat,
-                        ctc_input_lengths,
-                        target_length,
-                    )
-                target_ctc_loss += loss
+                for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
+                    target_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef

             target_ctc_loss /= target_ctc_num
             logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
@@ -270,7 +302,6 @@ class CtcCriterion(FairseqCriterion):
                 if type(out) == list:
                     inter_ctc_logit = out[0]
                     padding = ~out[1]
-                    input_lengths = padding.long().sum(-1)
                 else:
                     inter_ctc_logit = out
@@ -299,11 +330,26 @@ class CtcCriterion(FairseqCriterion):
         logging_output["all_ctc_loss"] = utils.item(loss.data)

+        if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
+            logger.warning("Illegal loss %f!" % loss)
+            if self.ctc_weight != 0:
+                logger.warning("CTC loss %f!" % ctc_loss)
+            if self.intermedia_ctc_weight != 0:
+                logger.warning("Intermedia CTC loss %f!" % intermedia_ctc_loss)
+            if self.target_ctc_weight != 0:
+                logger.warning("Target CTC loss %f!" % target_ctc_loss)
+
         if not model.training and self.ctc_weight > 0:
             import editdistance

             with torch.no_grad():
                 lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
+                target = sample["transcript"]["tokens"] if "transcript" in sample else sample["target"]
+                if mixup:
+                    idx = mixup_idx1
+                    if mixup_coef < 0.5:
+                        idx = mixup_idx2
+                    target = target[idx]

                 c_err = 0
                 c_len = 0
@@ -312,7 +358,7 @@ class CtcCriterion(FairseqCriterion):
                 wv_errs = 0
                 for lp, t, inp_l in zip(
                     lprobs_t,
-                    sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
+                    target,
                     input_lengths,
                 ):
                     lp = lp[:inp_l].unsqueeze(0)
@@ -398,7 +444,7 @@ class CtcCriterion(FairseqCriterion):
         )

         if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
-            logger.error("Illegal loss %f!" % all_ctc_loss_sum)
+            logger.warning("Illegal loss %f!" % all_ctc_loss_sum)
         if all_ctc_loss_sum > 0:
             if "loss" not in logging_outputs[0]:
                 metrics.log_scalar(
...
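Note: the same pattern is applied to the plain, intermediate, and target CTC losses above. Under mixup each mixed utterance carries two reference transcripts, so the new get_loss helper is called once per partner and the results are blended as coef * loss1 + (1 - coef) * loss2. For the validation-time WER/CER accounting, only one reference per utterance is practical, so the code falls back to the dominant partner: index1 when the coefficient is at least 0.5, otherwise index2.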
@@ -104,10 +104,40 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         else:
             lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
             target = target[self.ignore_prefix_size :, :].contiguous()
-        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+        if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
+            mixup = net_output[1]["mixup"]
+            idx1 = mixup["index1"]
+            idx2 = mixup["index2"]
+            target1 = target[idx1].view(-1)
+            target2 = target[idx2].view(-1)
+            target = [target1, target2]
+        else:
+            target = target.view(-1)
+        return lprobs.view(-1, lprobs.size(-1)), target

     def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        loss, nll_loss = label_smoothed_nll_loss(
-            lprobs,
-            target,
+        if type(target) == list:
+            assert "mixup" in net_output[1] and net_output[1]["mixup"] is not None
+
+            coef = net_output[1]["mixup"]["coef"]
+            loss1, nll_loss1 = label_smoothed_nll_loss(
+                lprobs,
+                target[0],
+                self.eps,
+                ignore_index=self.padding_idx,
+                reduce=reduce,
+            )
+            loss2, nll_loss2 = label_smoothed_nll_loss(
+                lprobs,
+                target[1],
+                self.eps,
+                ignore_index=self.padding_idx,
+                reduce=reduce,
+            )
+            loss = coef * loss1 + (1 - coef) * loss2
+            nll_loss = coef * nll_loss1 + (1 - coef) * nll_loss2
+        else:
+            loss, nll_loss = label_smoothed_nll_loss(
+                lprobs,
+                target,
@@ -119,6 +149,15 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

     def compute_accuracy(self, model, net_output, sample):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        mask = target.ne(self.padding_idx)
-        n_correct = torch.sum(
-            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+        if type(target) == list:
+            n_correct = total = 0
+            for item in target:
+                mask = item.ne(self.padding_idx)
+                n_correct += torch.sum(
+                    lprobs.argmax(1).masked_select(mask).eq(item.masked_select(mask))
+                )
+                total += torch.sum(mask)
+        else:
+            mask = target.ne(self.padding_idx)
+            n_correct = torch.sum(
+                lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
...
@@ -53,6 +53,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         """
         src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
         encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+
+        use_mixup = False
+        if "mixup" in encoder_out and encoder_out["mixup"] is not None:
+            use_mixup = True
+
         net_output = model.decoder(
             prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
         )
@@ -61,11 +66,18 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         sample_size = (
             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )
+
+        n_tokens = sample["ntokens"]
+        n_sentences = sample["target"].size(0)
+        if use_mixup:
+            sample_size //= 2
+            n_tokens //= 2
+            n_sentences //= 2
+
         logging_output = {
             "trans_loss": utils.item(loss.data) if reduce else loss.data,
             "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
-            "ntokens": sample["ntokens"],
-            "nsentences": sample["target"].size(0),
+            "ntokens": n_tokens,
+            "nsentences": n_sentences,
             "sample_size": sample_size,
         }
...
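Note: the halving of sample_size, ntokens, and nsentences keeps the logged averages consistent: apply_mixup collapses each pair of utterances into one mixed sequence, so the model effectively processes half as many sentences as the raw sample reports, and normalizing by the original counts would distort the per-token loss.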
 import logging
 import math
 from functools import reduce
+import numpy as np
 import torch
 import torch.nn as nn
@@ -586,6 +587,30 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=float,
             help="temperature of the intermedia ctc probability",
         )
+        # mixup
+        parser.add_argument(
+            "--inter-mixup",
+            action="store_true",
+            help="use mixup or not",
+        )
+        parser.add_argument(
+            "--inter-mixup-layer",
+            default=None,
+            type=int,
+            help="the layers for mixup",
+        )
+        parser.add_argument(
+            "--inter-mixup-beta",
+            default=0.5,
+            type=float,
+            help="the coefficient beta for mixup",
+        )
+        parser.add_argument(
+            "--inter-mixup-prob",
+            default=1,
+            type=float,
+            help="the probability to apply mixup",
+        )
         pass

     @classmethod
@@ -633,10 +658,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
         self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
         self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
-        self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
-        self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
-        self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
+        self.pds_attn_ds_ratios = \
+            [int(n) for n in args.pds_attn_ds_ratios.split("_")] if args.pds_attn_ds_ratios is not None else None
+        self.pds_conv_strides = \
+            [int(n) for n in args.pds_conv_strides.split("_")] if args.pds_conv_strides is not None else None
+        self.pds_attn_strides = \
+            [int(n) for n in args.pds_attn_strides.split("_")] if args.pds_attn_strides is not None else None

         # fusion
         self.pds_fusion = args.pds_fusion
@@ -679,8 +707,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             ffn_ratio = self.pds_ffn_ratios[i]
             num_head = self.pds_attn_heads[i]
             attn_ds_ratio = self.pds_attn_ds_ratios[i]  # if self.attn_type == "reduced" else -1
-            conv_stride = self.pds_conv_strides[i]
-            attn_stride = self.pds_attn_strides[i]
+            conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
+            attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
             if conv_stride != 1 or attn_stride != 1:
                 expand_embed_dim = embed_dim if i == self.pds_stages - 1 else self.pds_embed_dims[i + 1]
             else:
@@ -860,7 +888,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     dropout=args.dropout,
                     need_layernorm=True if self.inter_ctc else False)

-                if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
+                if task.source_dictionary == task.target_dictionary and \
+                        embed_tokens is not None and embed_dim == embed_tokens.embedding_dim:
                     self.ctc.ctc_projection.weight = embed_tokens.weight
             else:
                 self.ctc = inter_ctc_module
@@ -871,6 +900,18 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             self.layer_norm = None

         self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)

+        # mixup
+        self.mixup = getattr(args, "inter_mixup", False)
+        if self.mixup:
+            self.mixup_layer = args.inter_mixup_layer
+            self.mixup_prob = getattr(args, "inter_mixup_prob", 1.0)
+            beta = args.inter_mixup_beta
+            from torch.distributions import Beta
+            self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
+            logger.info("Use mixup in layer %d with beta %f." % (self.mixup_layer, beta))
+
+        # gather cosine similarity
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
         self.dis = 2
         self.cos_sim = dict()
@@ -893,6 +934,32 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             self.cos_sim[idx] = []
         self.cos_sim[idx].append(float(sim))

+    def apply_mixup(self, x, encoder_padding_mask):
+        batch = x.size(1)
+        indices = np.random.permutation(batch)
+        if len(indices) % 2 != 0:
+            indices = np.append(indices, (indices[-1]))
+        idx1 = torch.from_numpy(indices[0::2]).to(x.device)
+        idx2 = torch.from_numpy(indices[1::2]).to(x.device)
+
+        x1 = x[:, idx1]
+        x2 = x[:, idx2]
+        coef = self.beta.sample().to(x.device).type_as(x)
+        x = (coef * x1 + (1 - coef) * x2)
+
+        pad1 = encoder_padding_mask[idx1]
+        pad2 = encoder_padding_mask[idx2]
+        encoder_padding_mask = pad1 + pad2
+        input_lengths = (~encoder_padding_mask).sum(-1)
+
+        mixup = {
+            "coef": coef,
+            "index1": idx1,
+            "index2": idx2,
+        }
+        return x, encoder_padding_mask, input_lengths, mixup
+
     def forward(self, src_tokens, src_lengths):
         batch = src_tokens.size(0)
@@ -908,6 +975,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
             x = torch.cat([x, padding_for_pds], dim=0)

+        encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))
+
         # gather cosine similarity
         cos_sim_idx = -1
         dis = self.dis
@@ -916,6 +985,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         layer_idx = 0
         ctc_logit = None
+        mixup = None
         prev_state = []
         prev_padding = []
         intermedia_ctc_logits = []
@@ -926,6 +996,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             ctc = getattr(self, f"ctc{i + 1}")
             adapter = getattr(self, f"adapter{i + 1}")

+            if self.training and self.mixup and layer_idx == self.mixup_layer:
+                x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
+
             x, input_lengths = downsampling(x, input_lengths)
             encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))
@@ -975,6 +1048,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     cos_sim_idx += 1
                     self.add_to_dict(x, dis, cos_sim_idx)

+                if self.training and self.mixup and layer_idx == self.mixup_layer:
+                    if torch.rand(1) < self.mixup_prob:
+                        x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
+
                 if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                     ctc_logit = self.ctc(x.clone())
@@ -1027,6 +1104,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
             "intermedia_ctc_logits": intermedia_ctc_logits,  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
+            "mixup": mixup,
             "encoder_embedding": [],  # B x T x C
             "encoder_states": [],  # List[T x B x C]
             "src_tokens": [],
@@ -1153,6 +1231,12 @@ def base_architecture(args):
     args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
     args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)

+    # mixup
+    args.inter_mixup = getattr(args, "inter_mixup", False)
+    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
+    args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
+    args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 0.5)
+

 def set_pds_base_8(args):
     args.pds_stages = getattr(args, "pds_stages", 4)
...
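Note: in the encoder forward pass above, mixup is hooked in at two points: once at the top of a PDS stage (before downsampling) and once inside the stage, where it is additionally gated by inter-mixup-prob via torch.rand(1). Both hooks check self.training, so inference paths are untouched, and the resulting mixup dict travels to the criteria through the encoder output.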
@@ -553,7 +553,8 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_ctc else False)

-            if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
+            if task.source_dictionary == task.target_dictionary and \
+                    embed_tokens is not None and dim == embed_tokens.embedding_dim:
                 self.ctc.ctc_projection.weight = embed_tokens.weight

         self.interleaved_dropout = getattr(args, "interleave_dropout", None)
@@ -769,7 +770,7 @@ class TransformerDecoderScriptable(TransformerDecoder):
         alignment_heads: Optional[int] = None,
     ):
         # call scriptable method from parent class
-        x, _ = self.extract_features_scriptable(
+        x, extra = self.extract_features_scriptable(
             prev_output_tokens,
             encoder_out,
             incremental_state,
@@ -777,7 +778,7 @@ class TransformerDecoderScriptable(TransformerDecoder):
             alignment_layer,
             alignment_heads,
         )
-        return x, None
+        return x, extra

     def get_normalized_probs_scriptable(
         self,
...
@@ -1049,6 +1049,23 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
             self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

+        mixup = None
+        if "mixup" in encoder_out and encoder_out["mixup"] is not None:
+            mixup = encoder_out["mixup"]
+
+            coef = mixup["coef"]
+            idx1 = mixup["index1"]
+            idx2 = mixup["index2"]
+
+            x1 = x[:, idx1]
+            x2 = x[:, idx2]
+            x = coef * x1 + (1 - coef) * x2
+
+            if self_attn_padding_mask is not None:
+                pad1 = self_attn_padding_mask[idx1]
+                pad2 = self_attn_padding_mask[idx2]
+                self_attn_padding_mask = pad1 + pad2
+
         # decoder layers
         avg_attn = None
         attn: Optional[Tensor] = None
@@ -1132,7 +1149,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if self.project_out_dim is not None:
             x = self.project_out_dim(x)

-        return x, {"attn": [attn], "inner_states": inner_states}
+        return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}

     def output_layer(self, features):
         """Project features to the vocabulary size."""
...
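Note: the decoder mirrors the encoder's pairing. When the encoder output carries a mixup dict, the target-side states and their self-attention padding mask are permuted with index1/index2 and interpolated with the same coefficient before the decoder layers run, so decoder and encoder batches stay aligned; passing mixup back out through the extra dict (together with the TransformerDecoderScriptable change above) is what lets the criterion split the loss per partner.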
@@ -73,11 +73,16 @@ class ESPNETMultiHeadedAttention(nn.Module):
         if mask is not None:
             scores = scores.masked_fill(
                 mask.unsqueeze(1).unsqueeze(2).to(bool),
-                # -1e8 if scores.dtype == torch.float32 else -1e4
-                float("-inf"),  # (batch, head, time1, time2)
+                -1e8 if scores.dtype == torch.float32 else -1e4,
+                # float("-inf"),  # (batch, head, time1, time2)
             )
+            scores = scores.clamp(min=-1e8 if scores.dtype == torch.float32 else -1e4,
+                                  max=1e8 if scores.dtype == torch.float32 else 1e4)

         self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)  # (batch, head, time1, time2)
-        # self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+        if torch.isnan(self.attn).any():
+            import logging
+            logging.error("Tensor attention scores has nan.")
+
         p_attn = self.dropout(self.attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         x = (
...