Commit 6cbfe851 by xuchen

implement the mixup method for speech-to-text

parent 67d8695f
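
A note on the method for readers of this diff: mixup (Zhang et al., 2018) interpolates pairs of training examples with a coefficient drawn from a Beta distribution and weights the loss against both label sets accordingly. This commit applies it to encoder features in the speech-to-text models. A minimal sketch of the idea, as a hypothetical standalone example rather than this repo's API:

import torch
from torch.distributions import Beta

def mixup_features(x, beta=0.5):
    """Mix a T x B x C feature batch with a shuffled copy of itself."""
    perm = torch.randperm(x.size(1))
    coef = Beta(torch.tensor([beta]), torch.tensor([beta])).sample().type_as(x)
    return coef * x + (1 - coef) * x[:, perm], perm, coef

mixed, perm, coef = mixup_features(torch.randn(50, 8, 256))
# training then weights each loss term by coef and (1 - coef), pairing the
# original targets with the targets gathered via perm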
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 18
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
......@@ -38,18 +38,18 @@ task=speech_to_text
vocab_type=unigram
vocab_size=5000
speed_perturb=0
lcrm=0
lcrm=1
tokenizer=0
use_raw_audio=0
use_specific_dict=0
specific_prefix=st
specific_dir=${root_dir}/data/mustc/st
asr_vocab_prefix=spm_unigram10000_st_share
use_specific_dict=1
specific_prefix=unified
specific_dir=${root_dir}/data/iwslt2022/vocab
asr_vocab_prefix=spm_en
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/asr
train_split=train
train_split=train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted
valid_split=dev
test_split=tst-COMMON
test_subset=tst-COMMON
......
#! /bin/bash
# Processing MuST-C Datasets
# Processing IWSLT 2022 Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
......@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
if [[ ! -e ${data_dir}/data ]]; then
mkdir -p ${data_dir}/data
fi
if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
if [[ ${use_specific_dict} -eq 0 ]]; then
......@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--tgt-vocab-type ${tgt_vocab_type}
--src-vocab-size ${src_vocab_size}
--tgt-vocab-size ${tgt_vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
else
cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
fi
fi
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{
if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
text_dir=${org_data_dir}/data/${split}/txt
else
text_dir=${org_data_dir}/data/${split}
cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--src-vocab-prefix ${src_vocab_prefix}
--tgt-vocab-prefix ${tgt_vocab_prefix}"
fi
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
src_text=${text_dir}/${split}.${src_lang}
tgt_text=${text_dir}/${split}.${tgt_lang}
cmd="cat ${src_text}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_text}"
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${data_dir}/data/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_text}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
fi
cmd="python ${code_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang}
......
arch: s2t_sate
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
inter-mixup: True
inter-mixup-layer: 0
inter-mixup-beta: 0.5
encoder-embed-dim: 384
pds-stages: 4
#ctc-layer: 15
encoder-layers: 6
pds-layers: 2_1_1_2
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_384
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_4
pds-attn-heads: 4_4_4_6
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
......@@ -9,39 +39,16 @@ lr: 2e-3
adam_betas: (0.9,0.98)
ctc-weight: 0.3
target-ctc-weight: 0.2
target-ctc-layers: 3,6
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-layers: 6
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
acoustic-encoder: transformer
adapter: league
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
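
The three inter-mixup options in the config above are new in this commit: inter-mixup switches the method on, inter-mixup-layer selects the encoder layer index at which features are mixed (0 presumably mixes before the first stage), and inter-mixup-beta parameterizes the Beta(β, β) distribution from which the mixing coefficient is sampled (see the new --inter-mixup-* arguments further down). A hedged aside: β = 0.5 concentrates mass near 0 and 1, so most batches are only mildly mixed:

import torch
from torch.distributions import Beta
print(Beta(torch.tensor([0.5]), torch.tensor([0.5])).sample((5,)).view(-1))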
......@@ -12,7 +12,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-layers: 4_2_2_4
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
......
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 384
pds-stages: 4
#ctc-layer: 15
encoder-layers: 12
pds-layers: 4_3_3_2
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_384
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_4
pds-attn-heads: 4_4_4_6
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 18
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 200
pds-stages: 3
pds-layers: 4_4_4
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 200_200_200
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 12
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 15
#encoder-activation-fn: swish
#encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
......@@ -6,27 +6,27 @@ encoder-type: pds
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
encoder-embed-dim: 384
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
encoder-layers: 12
pds-layers: 4_3_3_2
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-embed-dims: 128_256_256_384
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
pds-ffn-ratios: 8_8_8_4
pds-attn-heads: 4_4_4_8
optimizer: adam
clip-norm: 10.0
......@@ -42,9 +42,4 @@ post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
pds-layers: 4_4_4
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 12
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 15
#encoder-activation-fn: swish
#encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
......@@ -35,15 +35,17 @@ lang=${src_lang}-${tgt_lang}
dataset=mustc
task=translation
vocab_type=unigram
vocab_size=10000
src_vocab_type=unigram
tgt_vocab_type=unigram
src_vocab_size=10000
tgt_vocab_size=10000
share_dict=1
lcrm=0
tokenizer=0
use_specific_dict=1
specific_prefix=st
specific_dir=${root_dir}/data/mustc/st
specific_dir=${root_dir}/data/${dataset}/st
src_vocab_prefix=spm_unigram10000_st_share
tgt_vocab_prefix=spm_unigram10000_st_share
......@@ -80,17 +82,24 @@ len_penalty=1.0
if [[ ${use_specific_dict} -eq 1 ]]; then
exp_prefix=${exp_prefix}_${specific_prefix}
data_dir=${data_dir}/${specific_prefix}
mkdir -p ${data_dir}
else
if [[ "${vocab_type}" == "char" ]]; then
vocab_name=${vocab_type}
exp_prefix=${exp_prefix}_${vocab_type}
if [[ "${tgt_vocab_type}" == "char" ]]; then
vocab_name=char
exp_prefix=${exp_prefix}_char
else
if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" != "${tgt_vocab_type}" ]]; then
src_vocab_name=${src_vocab_type}${src_vocab_size}
tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
vocab_name=${src_vocab_name}_${tgt_vocab_name}
else
vocab_name=${vocab_type}${vocab_size}
vocab_name=${tgt_vocab_type}${tgt_vocab_size}
src_vocab_name=${vocab_name}
tgt_vocab_name=${vocab_name}
fi
fi
data_dir=${data_dir}/${vocab_name}
src_vocab_prefix=spm_${vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
if [[ $share_dict -eq 1 ]]; then
data_dir=${data_dir}_share
src_vocab_prefix=spm_${vocab_name}_share
......@@ -132,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
if [[ ! -e ${data_dir}/data ]]; then
mkdir -p ${data_dir}/data
fi
if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
if [[ ${use_specific_dict} -eq 0 ]]; then
......@@ -141,51 +153,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
--src-vocab-type ${src_vocab_type}
--tgt-vocab-type ${tgt_vocab_type}
--src-vocab-size ${src_vocab_size}
--tgt-vocab-size ${tgt_vocab_size}"
else
cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
fi
fi
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{
if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
txt_dir=${org_data_dir}/data/${split}/txt
else
txt_dir=${org_data_dir}/data/${split}
cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--src-vocab-prefix ${src_vocab_prefix}
--tgt-vocab-prefix ${tgt_vocab_prefix}"
fi
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
cmd="cat ${txt_dir}/${split}.${src_lang}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${org_data_dir}/data/${split}.${src_lang}"
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${data_dir}/data/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${txt_dir}/${split}.${tgt_lang}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
fi
cmd="python ${code_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang}
......@@ -317,11 +313,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi
fi
wait
......
......@@ -35,8 +35,10 @@ lang=${src_lang}-${tgt_lang}
dataset=wmt16.en-de
task=translation
vocab_type=unigram
vocab_size=32000
src_vocab_type=unigram
tgt_vocab_type=unigram
src_vocab_size=10000
tgt_vocab_size=10000
share_dict=1
lcrm=0
tokenizer=1
......@@ -81,17 +83,24 @@ len_penalty=1.0
if [[ ${use_specific_dict} -eq 1 ]]; then
exp_prefix=${exp_prefix}_${specific_prefix}
data_dir=${data_dir}/${specific_prefix}
mkdir -p ${data_dir}
else
if [[ "${vocab_type}" == "char" ]]; then
vocab_name=${vocab_type}
exp_prefix=${exp_prefix}_${vocab_type}
if [[ "${tgt_vocab_type}" == "char" ]]; then
vocab_name=char
exp_prefix=${exp_prefix}_char
else
if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" != "${tgt_vocab_type}" ]]; then
src_vocab_name=${src_vocab_type}${src_vocab_size}
tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
vocab_name=${src_vocab_name}_${tgt_vocab_name}
else
vocab_name=${vocab_type}${vocab_size}
vocab_name=${tgt_vocab_type}${tgt_vocab_size}
src_vocab_name=${vocab_name}
tgt_vocab_name=${vocab_name}
fi
fi
data_dir=${data_dir}/${vocab_name}
src_vocab_prefix=spm_${vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
if [[ $share_dict -eq 1 ]]; then
data_dir=${data_dir}_share
src_vocab_prefix=spm_${vocab_name}_share
......@@ -103,6 +112,9 @@ if [[ ${lcrm} -eq 1 ]]; then
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${tokenizer} -eq 1 ]]; then
train_subset=${train_subset}.tok
valid_subset=${valid_subset}.tok
trans_subset=${trans_subset}.tok
data_dir=${data_dir}_tok
exp_prefix=${exp_prefix}_tok
fi
......@@ -130,6 +142,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
if [[ ! -e ${data_dir}/data ]]; then
mkdir -p ${data_dir}/data
fi
if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
if [[ ${use_specific_dict} -eq 0 ]]; then
......@@ -139,62 +154,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
if [[ ${tokenizer} -eq 1 ]]; then
cmd="$cmd
--tokenizer"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
--src-vocab-type ${src_vocab_type}
--tgt-vocab-type ${tgt_vocab_type}
--src-vocab-size ${src_vocab_size}
--tgt-vocab-size ${tgt_vocab_size}"
else
cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
fi
fi
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{
if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
text_dir=${org_data_dir}/data/${split}/txt
else
text_dir=${org_data_dir}/data/${split}
cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--src-vocab-prefix ${src_vocab_prefix}
--tgt-vocab-prefix ${tgt_vocab_prefix}"
fi
src_text=${text_dir}/${split}.${src_lang}
tgt_text=${text_dir}/${split}.${tgt_lang}
if [[ ${tokenizer} -eq 1 ]]; then
src_text=${text_dir}/${split}.tok.${src_lang}
tgt_text=${text_dir}/${split}.tok.${tgt_lang}
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
cmd="cat ${src_text}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_text}"
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${data_dir}/data/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_text}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
fi
cmd="python ${code_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang}
......@@ -327,16 +315,14 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
export CUDA_VISIBLE_DEVICES=${device}
log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi
wait
echo -e " >> finish training \n"
fi
wait
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: MT Decoding"
......
......@@ -45,7 +45,7 @@ tokenizer=1
use_specific_dict=1
specific_prefix=unified
specific_dir=${root_dir}/data/wmt20/vocab
specific_dir=${root_dir}/data/${dataset}/vocab
src_vocab_prefix=spm_en
tgt_vocab_prefix=spm_zh
......
......@@ -148,19 +148,50 @@ class CtcCriterion(FairseqCriterion):
loss, logging_output = self.compute_ctc_loss(model, sample, net_output, logging_output)
return loss, sample_size, logging_output
def get_loss(self, lprobs, targets_flat, input_lengths, transcript_lengths):
with torch.backends.cudnn.flags(enabled=False):
ctc_loss = self.ctc_loss(
lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
return ctc_loss
def compute_ctc_loss(self, model, sample, net_output, logging_output):
transcript = sample["transcript"]
if "ctc_padding_mask" in net_output:
non_padding_mask = ~net_output["ctc_padding_mask"][0]
else:
# if "ctc_padding_mask" in net_output:
# non_padding_mask = ~net_output["ctc_padding_mask"][0]
# else:
# non_padding_mask = ~net_output["encoder_padding_mask"][0]
mixup = False
if "mixup" in net_output and net_output["mixup"] is not None:
mixup = True
mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
non_padding_mask = ~net_output["encoder_padding_mask"][0]
ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (transcript["tokens"] != self.pad_idx) & (
transcript["tokens"] != self.eos_idx
)
targets_flat = transcript["tokens"].masked_select(pad_mask)
transcript_lengths = pad_mask.sum(-1)
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
transcript_flat1 = transcript["tokens"][mixup_idx1].masked_select(mask1)
transcript_flat2 = transcript["tokens"][mixup_idx2].masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
transcript_flat = [transcript_flat1, transcript_flat2]
transcript_lengths = [transcript_lengths1, transcript_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
else:
transcript_flat = [transcript["tokens"].masked_select(pad_mask)]
transcript_lengths = [pad_mask.sum(-1)]
loss_coef = [1]
ctc_loss = 0
ctc_entropy = 0
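
With mixup, each mixed utterance carries two transcripts, so the loop in the hunk below evaluates the CTC loss once per flattened transcript set and combines the terms with weights coef and 1 - coef, via the new get_loss helper. A hedged sketch of that computation with dummy shapes:

import torch
import torch.nn.functional as F

lprobs = torch.randn(50, 4, 100).log_softmax(-1)        # T x B x V log-probs
input_lengths = torch.full((4,), 50, dtype=torch.long)  # all frames valid
flat1 = torch.randint(1, 100, (40,))                    # half 1 transcripts, flattened
flat2 = torch.randint(1, 100, (36,))                    # half 2 transcripts, flattened
lengths1 = torch.tensor([10, 10, 10, 10])
lengths2 = torch.tensor([9, 9, 9, 9])
coef = 0.7                                              # Beta(beta, beta) sample in training

ctc_loss = 0
for flat, lengths, c in zip([flat1, flat2], [lengths1, lengths2], [coef, 1 - coef]):
    ctc_loss += F.ctc_loss(lprobs, flat, input_lengths, lengths,
                           blank=0, reduction="sum", zero_infinity=True) * c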
......@@ -172,13 +203,9 @@ class CtcCriterion(FairseqCriterion):
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
ctc_loss = self.ctc_loss(
lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
ctc_loss += self.get_loss(lprobs, flat, input_lengths, lengths) * coef
if self.ctc_entropy > 0:
from torch.distributions import Categorical
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
......@@ -201,23 +228,19 @@ class CtcCriterion(FairseqCriterion):
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
inter_input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_input_lengths = input_lengths
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
intermedia_ctc_loss += loss
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
intermedia_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
......@@ -233,31 +256,40 @@ class CtcCriterion(FairseqCriterion):
if self.target_ctc_weight > 0 and target_ctc_num > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
targets_flat = target.masked_select(pad_mask)
target_length = pad_mask.sum(-1)
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
target_flat1 = target.masked_select(mask1)
target_flat2 = target.masked_select(mask2)
target_lengths1 = mask1.sum(-1)
target_lengths2 = mask2.sum(-1)
target_flat = [target_flat1, target_flat2]
target_length = [target_lengths1, target_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
else:
target_flat = [target.masked_select(pad_mask)]
target_length = [pad_mask.sum(-1)]
loss_coef = [1]
for i in range(target_ctc_num):
out = net_output["target_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
tgt_input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
tgt_input_lengths = input_lengths
inter_lprobs = model.get_normalized_probs(
tgt_inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
ctc_input_lengths,
target_length,
)
target_ctc_loss += loss
target_ctc_loss /= target_ctc_num
logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
......@@ -270,7 +302,6 @@ class CtcCriterion(FairseqCriterion):
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
......@@ -299,11 +330,26 @@ class CtcCriterion(FairseqCriterion):
logging_output["all_ctc_loss"] = utils.item(loss.data)
if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss)
if self.intermedia_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % intermedia_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
if not model.training and self.ctc_weight > 0:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
target = sample["transcript"]["tokens"] if "transcript" in sample else sample["target"]
if mixup:
idx = mixup_idx1
if mixup_coef < 0.5:
idx = mixup_idx2
target = target[idx]
c_err = 0
c_len = 0
......@@ -312,7 +358,7 @@ class CtcCriterion(FairseqCriterion):
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
target,
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
......@@ -398,7 +444,7 @@ class CtcCriterion(FairseqCriterion):
)
if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
logger.error("Illegal loss %f!" % all_ctc_loss_sum)
logger.warning("Illegal loss %f!" % all_ctc_loss_sum)
if all_ctc_loss_sum > 0:
if "loss" not in logging_outputs[0]:
metrics.log_scalar(
......
......@@ -104,10 +104,40 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
target1 = target[idx1].view(-1)
target2 = target[idx2].view(-1)
target = [target1, target2]
else:
target = target.view(-1)
return lprobs.view(-1, lprobs.size(-1)), target
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
if type(target) == list:
assert "mixup" in net_output[1] and net_output[1]["mixup"] is not None
coef = net_output[1]["mixup"]["coef"]
loss1, nll_loss1 = label_smoothed_nll_loss(
lprobs,
target[0],
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
loss2, nll_loss2 = label_smoothed_nll_loss(
lprobs,
target[1],
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
loss = coef * loss1 + (1 - coef) * loss2
nll_loss = coef * nll_loss1 + (1 - coef) * nll_loss2
else:
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
......@@ -119,6 +149,15 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
if type(target) == list:
n_correct = total = 0
for item in target:
mask = item.ne(self.padding_idx)
n_correct += torch.sum(
lprobs.argmax(1).masked_select(mask).eq(item.masked_select(mask))
)
total += torch.sum(mask)
else:
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
......
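
Under mixup, the label-smoothed cross-entropy above becomes a convex combination of the losses against the two target sets, and accuracy is accumulated over both sets with the same padding mask logic. A hedged sketch, using plain NLL in place of the label-smoothed variant:

import torch
import torch.nn.functional as F

lprobs = torch.randn(12, 100).log_softmax(-1)  # (B*T) x V decoder log-probs
target1 = torch.randint(0, 100, (12,))         # targets gathered with index1
target2 = torch.randint(0, 100, (12,))         # targets gathered with index2
coef = 0.7
loss = coef * F.nll_loss(lprobs, target1) + (1 - coef) * F.nll_loss(lprobs, target2)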
......@@ -53,6 +53,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
"""
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
use_mixup = True
net_output = model.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
......@@ -61,11 +66,18 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
n_tokens = sample["ntokens"]
n_sentences = sample["target"].size(0)
if use_mixup:
sample_size //= 2
n_tokens //= 2
n_sentences //= 2
logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data,
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"ntokens": n_tokens,
"nsentences": n_sentences,
"sample_size": sample_size,
}
......
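
Halving ntokens, nsentences, and sample_size here presumably keeps the logged statistics consistent with the mixed batch: apply_mixup pairs samples and halves the effective batch size, while the sample dict still holds the counts of the original, unmixed batch.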
import logging
import math
from functools import reduce
import numpy as np
import torch
import torch.nn as nn
......@@ -586,6 +587,30 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=float,
help="temperature of the intermedia ctc probability",
)
# mixup
parser.add_argument(
"--inter-mixup",
action="store_true",
help="use mixup or not",
)
parser.add_argument(
"--inter-mixup-layer",
default=None,
type=int,
help="the layers for mixup",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float,
help="the coefficient beta for mixup",
)
parser.add_argument(
"--inter-mixup-prob",
default=1,
type=float,
help="the probability to apply mixup",
)
pass
@classmethod
......@@ -633,10 +658,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
self.pds_attn_ds_ratios = \
[int(n) for n in args.pds_attn_ds_ratios.split("_")] if args.pds_attn_ds_ratios is not None else None
self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
self.pds_conv_strides = \
[int(n) for n in args.pds_conv_strides.split("_")] if args.pds_conv_strides is not None else None
self.pds_attn_strides = \
[int(n) for n in args.pds_attn_strides.split("_")] if args.pds_attn_strides is not None else None
# fusion
self.pds_fusion = args.pds_fusion
......@@ -679,8 +707,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] # if self.attn_type == "reduced" else -1
conv_stride = self.pds_conv_strides[i]
attn_stride = self.pds_attn_strides[i]
conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
if conv_stride != 1 or attn_stride != 1:
expand_embed_dim = embed_dim if i == self.pds_stages - 1 else self.pds_embed_dims[i + 1]
else:
......@@ -860,7 +888,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
if task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and embed_dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
else:
self.ctc = inter_ctc_module
......@@ -871,6 +900,18 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.layer_norm = None
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
# mixup
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_layer = args.inter_mixup_layer
self.mixup_prob = getattr(args, "inter_mixup_prob", 1.0)
beta = args.inter_mixup_beta
from torch.distributions import Beta
self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
logger.info("Use mixup in layer %d with beta %f." % (self.mixup_layer, beta))
# gather cosine similarity
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.dis = 2
self.cos_sim = dict()
......@@ -893,6 +934,32 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.cos_sim[idx] = []
self.cos_sim[idx].append(float(sim))
def apply_mixup(self, x, encoder_padding_mask):
batch = x.size(1)
indices = np.random.permutation(batch)
if len(indices) % 2 != 0:
indices = np.append(indices, (indices[-1]))
idx1 = torch.from_numpy(indices[0::2]).to(x.device)
idx2 = torch.from_numpy(indices[1::2]).to(x.device)
x1 = x[:, idx1]
x2 = x[:, idx2]
coef = self.beta.sample().to(x.device).type_as(x)
x = (coef * x1 + (1 - coef) * x2)
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 + pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
"coef": coef,
"index1": idx1,
"index2": idx2,
}
return x, encoder_padding_mask, input_lengths, mixup
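
A hedged standalone demo of apply_mixup's semantics on dummy tensors (this mirrors the function above rather than calling the repo's API; note bool addition pad1 + pad2 acts as elementwise OR):

import numpy as np
import torch
from torch.distributions import Beta

T, B, C = 6, 4, 8
x = torch.randn(T, B, C)                       # encoder features
pad = torch.zeros(B, T, dtype=torch.bool)      # padding mask, True = padded
pad[0, 4:] = True

indices = np.random.permutation(B)             # an odd batch would repeat its last index
idx1 = torch.from_numpy(indices[0::2])
idx2 = torch.from_numpy(indices[1::2])
coef = Beta(torch.tensor([0.5]), torch.tensor([0.5])).sample().type_as(x)
mixed = coef * x[:, idx1] + (1 - coef) * x[:, idx2]   # batch is halved: B -> B // 2
mixed_pad = pad[idx1] | pad[idx2]              # padding union, same as pad1 + pad2
lengths = (~mixed_pad).sum(-1)                 # a mixed frame counts only if valid in both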
def forward(self, src_tokens, src_lengths):
batch = src_tokens.size(0)
......@@ -908,6 +975,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
x = torch.cat([x, padding_for_pds], dim=0)
encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))
# gather cosine similarity
cos_sim_idx = -1
dis = self.dis
......@@ -916,6 +985,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
layer_idx = 0
ctc_logit = None
mixup = None
prev_state = []
prev_padding = []
intermedia_ctc_logits = []
......@@ -926,6 +996,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ctc = getattr(self, f"ctc{i + 1}")
adapter = getattr(self, f"adapter{i + 1}")
if self.training and self.mixup and layer_idx == self.mixup_layer:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
x, input_lengths = downsampling(x, input_lengths)
encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))
......@@ -975,6 +1048,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
......@@ -1027,6 +1104,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
......@@ -1153,6 +1231,12 @@ def base_architecture(args):
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 0.5)
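
A side note on these defaults: argparse already supplies default=1 for --inter-mixup-prob, so the 0.5 fallback here appears to take effect only when the args namespace was built without the CLI parser; when training from the command line, the parser default of 1 wins.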
def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
......
......@@ -553,7 +553,8 @@ class S2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
if task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_dropout = getattr(args, "interleave_dropout", None)
......@@ -769,7 +770,7 @@ class TransformerDecoderScriptable(TransformerDecoder):
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
x, extra = self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
......@@ -777,7 +778,7 @@ class TransformerDecoderScriptable(TransformerDecoder):
alignment_layer,
alignment_heads,
)
return x, None
return x, extra
def get_normalized_probs_scriptable(
self,
......
......@@ -1049,6 +1049,23 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
x1 = x[:, idx1]
x2 = x[:, idx2]
x = coef * x1 + (1 - coef) * x2
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
# decoder layers
avg_attn = None
attn: Optional[Tensor] = None
......@@ -1132,7 +1149,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}
def output_layer(self, features):
"""Project features to the vocabulary size."""
......
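
On the decoder side, the same mixup indices and coefficient are applied to the target embeddings and to the self-attention padding mask, so the decoder batch lines up with the mixed encoder output; the extra dict returned by extract_features now also carries the mixup info (and TransformerDecoderScriptable forwards it instead of None) so the criteria can recover the indices and coefficient.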
......@@ -73,11 +73,16 @@ class ESPNETMultiHeadedAttention(nn.Module):
if mask is not None:
scores = scores.masked_fill(
mask.unsqueeze(1).unsqueeze(2).to(bool),
# -1e8 if scores.dtype == torch.float32 else -1e4
float("-inf"), # (batch, head, time1, time2)
-1e8 if scores.dtype == torch.float32 else -1e4
# float("-inf"), # (batch, head, time1, time2)
)
scores = scores.clamp(min=-1e8 if scores.dtype == torch.float32 else -1e4,
max=1e8 if scores.dtype == torch.float32 else 1e4)
self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2)
# self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
if torch.isnan(self.attn).any():
import logging
logging.error("Tensor attention scores has nan.")
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
......
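
The change above swaps float("-inf") masking for a large finite negative constant and clamps the scores; a plausible motivation is that a row that is entirely masked to -inf turns softmax into NaN, which the added NaN check would then report. A quick demonstration:

import torch
print(torch.softmax(torch.full((4,), float("-inf")), dim=-1))  # tensor([nan, nan, nan, nan])
print(torch.softmax(torch.full((4,), -1e8), dim=-1))           # uniform and finite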