Commit 0bd92062 by xuchen

Optimize the shell scripts for IWSLT2022 En-Zh; implement the Efficient Conformer method

parent 55702466
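The new configuration keys and command-line options introduced below (e.g. encoder-attention-type: rel_pos, attention-reduced-method: pool, --attention-reduced-q, --pds-attn-ds-ratios) follow the Efficient Conformer idea of down-sampling the attention inputs inside the progressively down-sampled (PDS) encoder. As a rough, hedged illustration only — the class name and pooling details below are assumptions, not this repository's implementation — the reduced attention can be sketched in PyTorch as:

import torch
import torch.nn as nn

class ReducedSelfAttention(nn.Module):
    """Pool keys/values (and optionally queries) by ds_ratio before attention (illustrative sketch)."""
    def __init__(self, embed_dim, num_heads, ds_ratio=2, reduce_q=False):
        super().__init__()
        self.ds_ratio = ds_ratio
        self.reduce_q = reduce_q
        # "pool" reduction method: average pooling along the time axis
        self.pool = nn.AvgPool1d(kernel_size=ds_ratio, stride=ds_ratio, ceil_mode=True)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # x: (T, B, C), the layout used by fairseq encoders
        kv = x
        if self.ds_ratio > 1:
            kv = self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)  # pool over the time axis
        q = kv if self.reduce_q else x
        out, _ = self.attn(q, kv, kv)
        return out

layer = ReducedSelfAttention(embed_dim=256, num_heads=4, ds_ratio=2)
print(layer(torch.randn(100, 8, 256)).shape)  # torch.Size([100, 8, 256])

Down-sampling the keys/values halves the attention score matrix at ds_ratio=2 while keeping the output length unchanged; reducing the queries as well (attention-reduced-q) also shortens the output, which is what the per-stage pds-attn-ds-ratios appear to control.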
...@@ -129,8 +129,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
...@@ -23,10 +23,10 @@ asr_vocab_prefix=spm_unigram10000_st_share
 src_lang=en
 tgt_lang=zh
-subsets=(train_covost)
+subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON)
 mkdir -p $data_dir
-splits=$(echo ${subsets[*]} | sed 's/ /_/g')
+splits=$(echo ${subsets[*]} | sed 's/ /,/g')
 cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
     --data-root ${org_data_dir}
     --output-root ${data_dir}
......
-train-subset: train
+#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh
+train-subset: train_mustc_enzh
 valid-subset: dev
 max-epoch: 100
-max-update: 100000
+max-update: 1000000
 patience: 20
 best_checkpoint_metric: loss
 maximize_best_checkpoint_metric: False
......
...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
......
...@@ -30,10 +30,10 @@ pwd_dir=$PWD
 # dataset
 src_lang=en
-tgt_lang=de
+tgt_lang=zh
 lang=${src_lang}-${tgt_lang}
-dataset=mustc
+dataset=iwslt2022
 task=speech_to_text
 vocab_type=unigram
 vocab_size=5000
...@@ -42,7 +42,7 @@ lcrm=0
 tokenizer=0
 use_raw_audio=0
-use_specific_dict=1
+use_specific_dict=0
 specific_prefix=st
 specific_dir=${root_dir}/data/mustc/st
 asr_vocab_prefix=spm_unigram10000_st_share
...@@ -125,8 +125,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
...@@ -4,7 +4,7 @@
 gpu_num=8
 update_freq=1
-max_tokens=40000
+max_tokens=80000
 extra_tag=
 extra_parameter=
...@@ -13,11 +13,11 @@ extra_parameter=
 exp_tag=
-config_list=(base ctc)
-config_list=(purectc)
+#config_list=(base ctc)
+#config_list=(purectc)
 #config_list=(base conformer)
-#config_list=(pds_base_16)
+config_list=(pds_base_8 ctc)
 #config_list=(pds_base_16 conformer rpr)
 # exp full name
......
...@@ -30,15 +30,17 @@ pwd_dir=$PWD
 # dataset
 src_lang=en
-tgt_lang=de
+tgt_lang=zh
 lang=${src_lang}-${tgt_lang}
-dataset=mustc
+dataset=iwslt2022
 task=translation
-vocab_type=unigram
-vocab_size=10000
-share_dict=1
-lcrm=0
+src_vocab_type=unigram
+tgt_vocab_type=unigram
+src_vocab_size=32000
+tgt_vocab_size=32000
+share_dict=0
+lcrm=1
 tokenizer=0
 use_specific_dict=1
...@@ -49,7 +51,7 @@ tgt_vocab_prefix=spm_unigram10000_st_share
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt
-train_subset=train
+train_subset=train_mustc_enzh
 valid_subset=dev
 trans_subset=tst-COMMON
 test_subset=test
...@@ -82,15 +84,23 @@ if [[ ${use_specific_dict} -eq 1 ]]; then
     data_dir=${data_dir}/${specific_prefix}
     mkdir -p ${data_dir}
 else
-    if [[ "${vocab_type}" == "char" ]]; then
-        vocab_name=${vocab_type}
-        exp_prefix=${exp_prefix}_${vocab_type}
+    if [[ "${tgt_vocab_type}" == "char" ]]; then
+        vocab_name=char
+        exp_prefix=${exp_prefix}_char
     else
-        vocab_name=${vocab_type}${vocab_size}
+        if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" -ne "${tgt_vocab_type}" ]]; then
+            src_vocab_name=${src_vocab_type}${src_vocab_size}
+            tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            vocab_name=${src_vocab_name}_${tgt_vocab_name}
+        else
+            vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            src_vocab_name=${vocab_name}
+            tgt_vocab_name=${vocab_name}
+        fi
     fi
     data_dir=${data_dir}/${vocab_name}
-    src_vocab_prefix=spm_${vocab_name}_${src_lang}
-    tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
+    src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
+    tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
     if [[ $share_dict -eq 1 ]]; then
         data_dir=${data_dir}_share
         src_vocab_prefix=spm_${vocab_name}_share
...@@ -141,8 +151,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --splits ${train_subset},${valid_subset},${trans_subset}
         --src-lang ${src_lang}
         --tgt-lang ${tgt_lang}
-        --vocab-type ${vocab_type}
-        --vocab-size ${vocab_size}"
+        --src-vocab-type ${src_vocab_type}
+        --tgt-vocab-type ${tgt_vocab_type}
+        --src-vocab-size ${src_vocab_size}
+        --tgt-vocab-size ${tgt_vocab_size}"
     if [[ $share_dict -eq 1 ]]; then
         cmd="$cmd
         --share"
......
-train-subset: train
+#train-subset: train_mustc_enzh,train_covost
+train-subset: train_mustc_enzh
 valid-subset: dev
 max-epoch: 100
......
...@@ -29,7 +29,7 @@ acoustic-encoder: pds
 adapter: league
 encoder-embed-dim: 256
-ctc-layer: 12
+#ctc-layer: 12
 pds-stages: 4
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
......
...@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then
     exp_name=$1
 fi
-sacrebleu=1
-n_average=10
+sacrebleu=0
+n_average=1
 beam_size=5
 len_penalty=1.0
 max_tokens=80000
......
...@@ -30,29 +30,29 @@ pwd_dir=$PWD
 # dataset
 src_lang=en
-tgt_lang=de
+tgt_lang=zh
 lang=${src_lang}-${tgt_lang}
-dataset=mustc
+dataset=iwslt2022
 task=speech_to_text
 vocab_type=unigram
 asr_vocab_size=5000
 vocab_size=10000
-share_dict=1
+share_dict=0
 speed_perturb=0
-lcrm=0
+lcrm=1
 tokenizer=0
 use_raw_audio=0
-use_specific_dict=0
-specific_prefix=valid
-specific_dir=${root_dir}/data/mustc/st
-asr_vocab_prefix=spm_unigram10000_st_share
-st_vocab_prefix=spm_unigram10000_st_share
+use_specific_dict=1
+specific_prefix=asr
+specific_dir=${root_dir}/data/${dataset}/asr
+asr_vocab_prefix=spm_unigram5000_asr
+st_vocab_prefix=
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/st
-train_split=train
+train_split=train_mustc_enzh
 valid_split=dev
 test_split=tst-COMMON
 test_subset=tst-COMMON
...@@ -133,8 +133,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     # create ASR vocabulary if necessary
...@@ -147,8 +147,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
             --splits ${valid_split},${test_split},${train_split}
             --vocab-type ${vocab_type}
             --vocab-size ${asr_vocab_size}"
-        [[ $eval -eq 1 && ${share_dict} -ne 1 && ${use_specific_dict} -ne 1 ]] && (echo -e "\033[34mRun command: \n${cmd} \033[0m" && eval $cmd)
-        asr_prefix=spm_${vocab_type}${asr_vocab_size}_asr
+        if [[ $eval -eq 1 && ${share_dict} -ne 1 && ${use_specific_dict} -ne 1 ]]; then
+            echo -e "\033[34mRun command: \n${cmd} \033[0m"
+            eval $cmd
+            asr_vocab_prefix=spm_${vocab_type}${asr_vocab_size}_asr
+            cp ${data_dir}/asr4st/${asr_vocab_prefix}* ${data_dir}
+        fi
     echo "stage 0: ST Data Preparation"
     cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
...@@ -167,25 +171,21 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         cmd="$cmd
         --raw"
     fi
-    if [[ ${use_specific_dict} -eq 1 ]]; then
-        cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}
-        cp -r ${specific_dir}/${st_vocab_prefix}.* ${data_dir}
-        if [[ $share_dict -eq 1 ]]; then
-            cmd="$cmd
-            --share
-            --st-spm-prefix ${st_vocab_prefix}"
-        else
-            cmd="$cmd
-            --st-spm-prefix ${st_vocab_prefix}
-            --asr-prefix ${asr_vocab_prefix}"
-        fi
-    else
-        if [[ $share_dict -eq 1 ]]; then
-            cmd="$cmd
-            --share"
-        else
-            cmd="$cmd
-            --asr-prefix ${asr_prefix}"
-        fi
+    if [[ $share_dict -eq 1 ]]; then
+        cmd="$cmd
+        --share"
+    else
+        cmd="$cmd
+        --asr-prefix ${asr_vocab_prefix}"
+    fi
+    if [[ ${use_specific_dict} -eq 1 ]]; then
+        if [[ ${share_dict} -eq 0 && -n ${asr_vocab_prefix} ]]; then
+            cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}
+        fi
+        if [[ -n ${st_vocab_prefix} ]]; then
+            cp -r ${specific_dir}/${st_vocab_prefix}.* ${data_dir}
+            cmd="$cmd
+            --st-spm-prefix ${st_vocab_prefix}"
+        fi
     fi
     if [[ ${speed_perturb} -eq 1 ]]; then
......
...@@ -14,13 +14,13 @@ extra_parameter=
 exp_tag=
 #config_list=(base)
-config_list=(ctc)
-#config_list=(sate_ctc)
+#config_list=(sate ctc)
 #config_list=(ctc conformer rpr)
 #config_list=(base sate)
-#config_list=(pds_base)
-#config_list=(pds_base conformer)
+config_list=(sate_pds ctc)
+#config_list=(pds_base_8)
+#config_list=(pds_base_8 conformer)
 # exp full name
 exp_name=
......
...@@ -2,6 +2,17 @@ arch: s2t_ctc
 encoder-type: pds
 #arch: pdss2t_transformer_s_8
+#pds-ctc: 0_1_1_0
+#intermedia-adapter: league
+#intermedia-ctc-weight: 1
+#encoder-attention-type: transfer
+#relative-pos-enc: True
+encoder-attention-type: rel_pos
+#pds-attn-ds-ratios: 4_2_1_1
+#attention-reduced-method: pool
+#attention-reduced-q: True
 encoder-embed-dim: 256
 pds-stages: 4
 ctc-layer: 12
......
 #! /bin/bash
-# Processing MuST-C Datasets
+# Processing LibriSpeech En-Fr Datasets
 # Copyright 2021 Natural Language Processing Laboratory
 # Xu Chen (xuchenneu@163.com)
...@@ -124,8 +124,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
...@@ -6,7 +6,6 @@ gpu_num=2
 update_freq=1
 max_tokens=40000
 extra_tag=
 extra_parameter=
 #extra_tag="${extra_tag}"
...@@ -15,10 +14,9 @@ extra_parameter=
 exp_tag=
 #config_list=(base)
-#config_list=(ctc)
 #config_list=(base conformer)
-#config_list=(pds_base_16)
+#config_list=(pds_base_8)
 config_list=(pds_base_8 conformer rpr)
 # exp full name
......
+#encoder-attention-type: rel_selfattn
 encoder-attention-type: relative
 decoder-attention-type: relative
-max-encoder-relative-length: 20
-max-decoder-relative-length: 20
+max-encoder-relative-length: 8
+max-decoder-relative-length: 8
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 8_4_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 176
pds-stages: 4
ctc-layer: 16
pds-layers: 4_4_4_4
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 176_176_176_176
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 4_4_4_4
pds-attn-heads: 4_4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
...@@ -11,12 +11,12 @@ extra_parameter=
 #extra_tag="${extra_tag}"
 #extra_parameter="${extra_parameter} "
-#exp_tag=
+exp_tag=
 #config_list=(base)
-#config_list=(ctc)
-#config_list=(ctc conformer rpr)
-config_list=(base conformer rpr)
+#config_list=(base conformer)
+#config_list=(ConformerCTCSmall)
+config_list=(purectc_pds_base_16)
 #config_list=(pds_base)
 #config_list=(pds_big)
 #config_list=(pds_deep)
......
 arch: pdss2t_transformer_s_8
+#pds-ctc: 0_1_1_0
+#intermedia-adapter: league
+#intermedia-ctc-weight: 0.1
+#encoder-attention-type: reduced
+#pds-attn-ds-ratios: 4_2_2_1
+#attention-reduced-method: pool
+#attention-reduced-q: True
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
......
 arch: s2t_ctc
 encoder-type: pds
+#pds-ctc: 0_1_1_0
+#intermedia-adapter: league
+#intermedia-ctc-weight: 1
+#encoder-attention-type: reduced
+#pds-attn-ds-ratios: 8_4_2_1
+#attention-reduced-method: pool
+#attention-reduced-q: True
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
...@@ -26,17 +35,12 @@ lr: 2e-3
 adam_betas: (0.9,0.98)
 criterion: ctc
-post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-decoder-layers: 6
 encoder-attention-heads: 4
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
 #load-pretrained-encoder-from:
-#load-pretrained-decoder-from:
\ No newline at end of file
-encoder-attention-type: rel_selfattn
+encoder-attention-type: rel_pos
 #encoder-attention-type: relative
 #max-encoder-relative-length: 100
...@@ -125,8 +125,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
...@@ -17,8 +17,8 @@ config_list=(base ctc)
 config_list=(purectc)
 #config_list=(base conformer)
-#config_list=(pds_base_16)
-#config_list=(pds_base_16 conformer rpr)
+config_list=(pds_base_8)
+config_list=(purectc_pds_base_8)
 # exp full name
 exp_name=
......
...@@ -7,7 +7,7 @@ update_freq=1
 max_tokens=8192
 exp_tag=baseline
-config_list=(base)
+config_list=(small)
 # exp full name
 exp_name=
......
...@@ -2,5 +2,6 @@ ctc-weight: 0.2
 intermedia-ctc-layers: 6,9
 intermedia-adapter: league
 intermedia-ctc-weight: 0.1
+#intermedia-drop-prob: 0.2
 ctc-self-distill-weight: 0
 post-process: sentencepiece
\ No newline at end of file
 arch: pdss2t_transformer_s_8
-pds-ctc: 1_1_1_1
-intermedia-adapter: league
-intermedia-ctc-weight: 0.15
+#pds-ctc: 0_1_1_0
+#intermedia-adapter: league
+#intermedia-ctc-weight: 0.1
+#encoder-attention-type: reduced
+#pds-attn-ds-ratios: 4_2_1_1
+#attention-reduced-method: pool
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
-pds-fusion: True
+#pds-fusion: True
 pds-fusion-method: all_conv
 pds-embed-dims: 256_256_256_256
 pds-ds-method: conv
......
-encoder-attention-type: rel_selfattn
+encoder-attention-type: rel_pos
+#encoder-attention-type: rel_pos_legacy
+#encoder-attention-type: rel_selfattn
 #encoder-attention-type: relative
 #decoder-attention-type: relative
 #max-encoder-relative-length: 100
......
...@@ -133,8 +133,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ${speed_perturb} -eq 1 ]]; then
         feature_zip=fbank80_sp.zip
     fi
-    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
-        ln -s ${data_dir}/../feature_zip ${data_dir}
+    if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
+        ln -s ${data_dir}/../${feature_zip} ${data_dir}
     fi
     # create ASR vocabulary if necessary
......
...@@ -14,13 +14,12 @@ extra_parameter=
 exp_tag=
 #config_list=(base)
-config_list=(ctc)
-#config_list=(sate_ctc)
-#config_list=(ctc conformer rpr)
-#config_list=(base sate)
-#config_list=(pds_base)
+#config_list=(base ctc conformer)
+#config_list=(sate ctc)
+#config_list=(pds_base_8)
 #config_list=(pds_base conformer)
+#config_list=(sate_pds ctc)
 # exp full name
 exp_name=
......
...@@ -327,14 +327,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     export CUDA_VISIBLE_DEVICES=${device}
     log=${model_dir}/train.log
-    if [[ -e ${log} ]]; then
-        for i in `seq 1 100`; do
-            if [ ! -e ${log}.${i} ]; then
-                log=${log}.${i}
-                break
-            fi
-        done
-    fi
     cmd="nohup ${cmd} >> ${log} 2>&1 &"
     if [[ $eval -eq 1 ]]; then
......
set -e
eval=1
lcrm=0
src_lang=en
tgt_lang=zh
tokenize=1
splits=(tst-COMMON test11)
dataset=wmt20
root_dir=~/st/Fairseq-S2T
data_dir=/home/xuchen/st/data/$dataset/data
vocab_dir=/home/xuchen/st/data/$dataset/mt/unigram32000_tok
dest_dir=$vocab_dir
src_vocab_prefix=spm_unigram32000_en
tgt_vocab_prefix=spm_unigram32000_zh
for split in ${splits[@]}; do
src_file=${data_dir}/${split}/${split}.${src_lang}
tgt_file=${data_dir}/${split}/${split}.${tgt_lang}
if [[ ${tokenize} -eq 1 ]]; then
src_tok_file=${data_dir}/${split}.tok/${split}.tok.${src_lang}
tgt_tok_file=${data_dir}/${split}.tok/${split}.tok.${tgt_lang}
if [[ ! -f ${src_tok_file} ]]; then
cmd="tokenizer.perl -l ${src_lang} --threads 8 -no-escape < ${src_file} > ${src_tok_file}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
fi
if [[ ! -f ${tgt_tok_file} ]]; then
cmd="tokenizer.perl -l ${tgt_lang} --threads 8 -no-escape < ${tgt_file} > ${tgt_tok_file}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
fi
src_file=${src_tok_file}
tgt_file=${tgt_tok_file}
fi
cmd="cat ${src_file}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_file}"
fi
cmd="${cmd}
| spm_encode --model ${vocab_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${src_file}.spm"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${vocab_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_file}
> ${tgt_file}.spm"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
src_file=${src_file}.spm
tgt_file=${tgt_file}.spm
mkdir -p ${dest_dir}/final
cmd="cp ${src_file} ${dest_dir}/final/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="cp ${tgt_file} ${dest_dir}/final/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
done
n_set=${#splits[*]}
for ((i=0;i<$n_set;i++)); do
dataset[$i]=${dest_dir}/final/${splits[$i]}
done
pref=`echo ${dataset[*]} | sed 's/ /,/g'`
cmd="python ${root_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang}
--target-lang ${tgt_lang}
--testpref ${pref}
--destdir ${dest_dir}/data-bin
--srcdict ${vocab_dir}/${src_vocab_prefix}.txt
--tgtdict ${vocab_dir}/${tgt_vocab_prefix}.txt
--workers 64"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 2e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 4000
lr: 7e-4
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: False
decoder-normalize-before: False
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
train-subset: train
valid-subset: valid
max-epoch: 20
max-update: 100000
patience: 5
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 5
num-workers: 8
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
max-source-positions: 512
arch: transformer_wmt_en_de_big_t2t
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 7e-4
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 1024
encoder-ffn-embed-dim: 4096
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 16
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer_wmt_en_de_big
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 4000
lr: 5e-4
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: False
decoder-normalize-before: False
encoder-embed-dim: 1024
encoder-ffn-embed-dim: 4096
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 16
decoder-embed-dim: 1024
decoder-ffn-embed-dim: 4096
decoder-attention-heads: 16
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 16000
lr: 2e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
use-enc-dlcl: True
use-dec-dlcl: True
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 8
max-decoder-relative-length: 8
#! /bin/bash
gpu_num=1
data_dir=
test_subset=(test)
exp_name=
if [ "$#" -eq 1 ]; then
exp_name=$1
fi
sacrebleu=0
n_average=5
beam_size=4
len_penalty=0.6
max_tokens=80000
dec_model=checkpoint_best.pt
cmd="./run.sh
--stage 2
--stop_stage 2
--gpu_num ${gpu_num}
--exp_name ${exp_name}
--sacrebleu ${sacrebleu}
--n_average ${n_average}
--beam_size ${beam_size}
--len_penalty ${len_penalty}
--max_tokens ${max_tokens}
--dec_model ${dec_model}
"
if [[ -n ${data_dir} ]]; then
cmd="$cmd --data_dir ${data_dir}"
fi
if [[ -n ${test_subset} ]]; then
test_subset=`echo ${test_subset[*]} | sed 's/ /,/g'`
cmd="$cmd --test_subset ${test_subset}"
fi
echo $cmd
eval $cmd
import sys
import string
in_file = sys.argv[1]
with open(in_file, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip().lower()
for w in string.punctuation:
line = line.replace(w, "")
line = line.replace(" ", "")
print(line)
gpu_num=4
cmd="sh train.sh"
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
sleep 60s
else
echo "Run $cmd"
eval $cmd
sleep 10s
exit
fi
done
#!/usr/bin/env perl
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
# $Id$
use warnings;
use strict;
my $lowercase = 0;
if ($ARGV[0] eq "-lc") {
$lowercase = 1;
shift;
}
my $stem = $ARGV[0];
if (!defined $stem) {
print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
print STDERR "Reads the references from reference or reference0, reference1, ...\n";
exit(1);
}
$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
my @REF;
my $ref=0;
while(-e "$stem$ref") {
&add_to_ref("$stem$ref",\@REF);
$ref++;
}
&add_to_ref($stem,\@REF) if -e $stem;
die("ERROR: could not find reference file $stem") unless scalar @REF;
# add additional references explicitly specified on the command line
shift;
foreach my $stem (@ARGV) {
&add_to_ref($stem,\@REF) if -e $stem;
}
sub add_to_ref {
my ($file,$REF) = @_;
my $s=0;
if ($file =~ /.gz$/) {
open(REF,"gzip -dc $file|") or die "Can't read $file";
} else {
open(REF,$file) or die "Can't read $file";
}
while(<REF>) {
chop;
push @{$$REF[$s++]}, $_;
}
close(REF);
}
my(@CORRECT,@TOTAL,$length_translation,$length_reference);
my $s=0;
while(<STDIN>) {
chop;
$_ = lc if $lowercase;
my @WORD = split;
my %REF_NGRAM = ();
my $length_translation_this_sentence = scalar(@WORD);
my ($closest_diff,$closest_length) = (9999,9999);
foreach my $reference (@{$REF[$s]}) {
# print "$s $_ <=> $reference\n";
$reference = lc($reference) if $lowercase;
my @WORD = split(' ',$reference);
my $length = scalar(@WORD);
my $diff = abs($length_translation_this_sentence-$length);
if ($diff < $closest_diff) {
$closest_diff = $diff;
$closest_length = $length;
# print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
} elsif ($diff == $closest_diff) {
$closest_length = $length if $length < $closest_length;
# from two references with the same closeness to me
# take the *shorter* into account, not the "first" one.
}
for(my $n=1;$n<=4;$n++) {
my %REF_NGRAM_N = ();
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
my $ngram = "$n";
for(my $w=0;$w<$n;$w++) {
$ngram .= " ".$WORD[$start+$w];
}
$REF_NGRAM_N{$ngram}++;
}
foreach my $ngram (keys %REF_NGRAM_N) {
if (!defined($REF_NGRAM{$ngram}) ||
$REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
$REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
# print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
}
}
}
}
$length_translation += $length_translation_this_sentence;
$length_reference += $closest_length;
for(my $n=1;$n<=4;$n++) {
my %T_NGRAM = ();
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
my $ngram = "$n";
for(my $w=0;$w<$n;$w++) {
$ngram .= " ".$WORD[$start+$w];
}
$T_NGRAM{$ngram}++;
}
foreach my $ngram (keys %T_NGRAM) {
$ngram =~ /^(\d+) /;
my $n = $1;
# my $corr = 0;
# print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
$TOTAL[$n] += $T_NGRAM{$ngram};
if (defined($REF_NGRAM{$ngram})) {
if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
$CORRECT[$n] += $T_NGRAM{$ngram};
# $corr = $T_NGRAM{$ngram};
# print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
}
else {
$CORRECT[$n] += $REF_NGRAM{$ngram};
# $corr = $REF_NGRAM{$ngram};
# print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
}
}
# $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
# print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
}
}
$s++;
}
my $brevity_penalty = 1;
my $bleu = 0;
my @bleu=();
for(my $n=1;$n<=4;$n++) {
if (defined ($TOTAL[$n])){
$bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
# print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
}else{
$bleu[$n]=0;
}
}
if ($length_reference==0){
printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
exit(1);
}
if ($length_translation<$length_reference) {
$brevity_penalty = exp(1-$length_reference/$length_translation);
}
$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
my_log( $bleu[2] ) +
my_log( $bleu[3] ) +
my_log( $bleu[4] ) ) / 4) ;
printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
100*$bleu,
100*$bleu[1],
100*$bleu[2],
100*$bleu[3],
100*$bleu[4],
$brevity_penalty,
$length_translation / $length_reference,
$length_translation,
$length_reference;
sub my_log {
return -9999999999 unless $_[0];
return log($_[0]);
}
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefned-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
#!/usr/bin/env perl
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
use warnings;
use strict;
#binmode(STDIN, ":utf8");
#binmode(STDOUT, ":utf8");
while(<STDIN>) {
s/,/,/g;
s/。 */. /g;
s/、/,/g;
s/”/"/g;
s/“/"/g;
s/∶/:/g;
s/:/:/g;
s/?/\?/g;
s/《/"/g;
s/》/"/g;
s/)/\)/g;
s/!/\!/g;
s/(/\(/g;
s/;/;/g;
s/1/"/g;
s/」/"/g;
s/「/"/g;
s/0/0/g;
s/3/3/g;
s/2/2/g;
s/5/5/g;
s/6/6/g;
s/9/9/g;
s/7/7/g;
s/8/8/g;
s/4/4/g;
s/. */. /g;
s/~/\~/g;
s/’/\'/g;
s/…/\.\.\./g;
s/━/\-/g;
s/〈/\</g;
s/〉/\>/g;
s/【/\[/g;
s/】/\]/g;
s/%/\%/g;
print $_;
}
get_devices(){
gpu_num=$1
use_cpu=$2
device=()
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
if [[ $use_cpu -eq 1 ]]; then
device=(-1)
else
sleep 60s
fi
else
break
fi
done
echo ${device[*]} | sed 's/ /,/g'
return $?
}
#! /bin/bash
# calculate wmt14 en-de multi-bleu score
if [ $# -ne 1 ]; then
echo "usage: $0 GENERATE_PY_OUTPUT"
exit 1
fi
echo -e "\n RUN >> "$0
requirement_scripts=(detokenizer.perl replace-unicode-punctuation.perl tokenizer.perl multi-bleu.perl)
for script in ${requirement_scripts[@]}; do
if ! which ${script} > /dev/null; then
echo "Error: it seems that moses is not installed or exported int the environment variables." >&2
return 1
fi
done
detokenizer=detokenizer.perl
replace_unicode_punctuation=replace-unicode-punctuation.perl
tokenizer=tokenizer.perl
multi_bleu=multi-bleu.perl
GEN=$1
SYS=$GEN.sys
REF=$GEN.ref
cat $GEN | cut -f 3 > $REF
cat $GEN | cut -f 4 > $SYS
#detokenize the decodes file to format the manner to do tokenize
$detokenizer -l de < $SYS > $SYS.dtk
$detokenizer -l de < $REF > $REF.dtk
#replace unicode
$replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
$replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
#tokenize the decodes file by moses tokenizer.perl
$tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
$tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
#"rich-text format" --> rich ##AT##-##AT## text format.
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $SYS.dtk.punc.tok > $SYS.dtk.punc.tok.atat
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $REF.dtk.punc.tok > $REF.dtk.punc.tok.atat
$multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
rm -f $SYS.dtk $SYS.dtk.punc $SYS.dtk.punc.tok $REF.dtk $REF.dtk.punc $REF.dtk.punc.tok
\ No newline at end of file
#! /bin/bash
# training the model
gpu_num=8
update_freq=2
max_tokens=8192
exp_tag=baseline
#config_list=(base)
config_list=(deep)
# exp full name
exp_name=
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh
--stage 1
--stop_stage 1
--gpu_num ${gpu_num}
--update_freq ${update_freq}
--train_config ${train_config}
--max_tokens ${max_tokens}
"
if [[ -n ${exp_name} ]]; then
cmd="$cmd --exp_name ${exp_name}"
fi
if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}"
fi
if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo ${cmd}
eval ${cmd}
...@@ -112,7 +112,7 @@ class AudioDataset(Dataset):
             if self.mode == "easy":
                 real_idx = 0
                 for idx, v in segments.items():
-                    audio_name = v["audio"]
+                    audio_name = f"{split}_{v['audio']}"
                     v["audio"] = (wav_root / v["audio"].strip()).as_posix() + ".wav"
                     if self.speed_perturb is not None:
                         for perturb in self.speed_perturb:
...@@ -137,8 +137,8 @@ class AudioDataset(Dataset):
                 for i, segment in enumerate(seg_group):
                     offset = int(float(segment["offset"]) * sample_rate)
                     n_frames = int(float(segment["duration"]) * sample_rate)
-                    # _id = f"{split}_{wav_path.stem}_{i}"
-                    _id = f"{wav_path.stem}_{i}"
+                    _id = f"{split}_{wav_path.stem}_{i}"
+                    # _id = f"{wav_path.stem}_{i}"
                     item = dict()
                     item["audio"] = wav_path.as_posix()
...@@ -237,7 +237,7 @@ def process(args):
     if not Path.exists(zip_path) or args.overwrite:
         gen_feature_flag = True
-    if True and gen_feature_flag:
+    if gen_feature_flag:
         if args.speed_perturb:
             feature_root = output_root / "fbank80_sp"
         else:
...@@ -264,12 +264,8 @@ def process(args):
             utt_id = item['id']
             features_path = (feature_root / f"{utt_id}.npy").as_posix()
-            tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
-            if os.path.exists(tag_features_path):
-                continue
-            if os.path.exists(features_path) and not os.path.exists(tag_features_path):
-                shutil.move(features_path, tag_features_path)
+            if os.path.exists(features_path):
                 continue
             waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
......
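The changes above prepend the split name to every utterance and feature ID, so that merging several subsets (train_covost, train_mustc_enzh, ...) into one feature directory cannot overwrite clips whose stems collide across corpora, and the old rename-on-the-fly fallback (tag_features_path) is no longer needed. A tiny, purely illustrative sketch of the naming scheme (function name and example values are hypothetical):

def utt_id(split: str, stem: str, index: int) -> str:
    # split-qualified IDs stay unique even when subsets share audio stems
    return f"{split}_{stem}_{index}"

print(utt_id("train_mustc_enzh", "ted_1096", 3))  # train_mustc_enzh_ted_1096_3
print(utt_id("train_covost", "ted_1096", 3))      # train_covost_ted_1096_3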
...@@ -96,16 +96,19 @@ def process(args):
         tgt_train_text.extend(manifest["tgt_text"])
     # Generate vocab and yaml
-    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}"
+    tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
+    tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
     if args.share:
         tgt_train_text.extend(src_train_text)
-        src_spm_filename_prefix = spm_filename_prefix + "_share"
-        tgt_spm_filename_prefix = src_spm_filename_prefix
+        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
+        src_spm_filename_prefix = tgt_spm_filename_prefix
     else:
-        src_spm_filename_prefix = spm_filename_prefix + "_" + src_lang
-        tgt_spm_filename_prefix = spm_filename_prefix + "_" + tgt_lang
+        src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
+        src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
+        src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
+        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
     with NamedTemporaryFile(mode="w") as f:
         for t in tgt_train_text:
...@@ -113,8 +116,8 @@ def process(args):
         gen_vocab(
             Path(f.name),
             output_root / tgt_spm_filename_prefix,
-            args.vocab_type,
-            args.vocab_size,
+            args.tgt_vocab_type,
+            args.tgt_vocab_size,
             normalization_rule_name="identity" if tgt_lang == "zh" else None
         )
...@@ -125,8 +128,8 @@ def process(args):
         gen_vocab(
             Path(f.name),
             output_root / src_spm_filename_prefix,
-            args.vocab_type,
-            args.vocab_size,
+            args.src_vocab_type,
+            args.src_vocab_size,
             normalization_rule_name="identity" if tgt_lang == "zh" else None
         )
...@@ -135,7 +138,7 @@ def process(args):
     if args.share:
         yaml_filename = f"config_share.yaml"
-    conf = {}
+    conf = dict()
     conf["src_vocab_filename"] = src_spm_filename_prefix + ".txt"
     conf["tgt_vocab_filename"] = tgt_spm_filename_prefix + ".txt"
     conf["src_bpe_tokenizer"] = {
...@@ -157,13 +160,21 @@ def main():
     parser.add_argument("--data-root", "-d", required=True, type=str)
     parser.add_argument("--output-root", "-o", default=None, type=str)
     parser.add_argument(
-        "--vocab-type",
+        "--src-vocab-type",
+        default="unigram",
+        required=True,
+        type=str,
+        choices=["bpe", "unigram", "char"],
+    )
+    parser.add_argument(
+        "--tgt-vocab-type",
         default="unigram",
         required=True,
         type=str,
         choices=["bpe", "unigram", "char"],
-    ),
-    parser.add_argument("--vocab-size", default=10000, type=int)
+    )
+    parser.add_argument("--src-vocab-size", default=10000, type=int)
+    parser.add_argument("--tgt-vocab-size", default=10000, type=int)
     parser.add_argument("--size", default=-1, type=int)
     parser.add_argument("--splits", default="train,dev,test", type=str)
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
...@@ -704,6 +704,8 @@ def load_pretrained_component_from_model(
         if key.startswith(component_type):
             # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
             component_subkey = key[len(component_type) + 1:]
+            if component_subkey.startswith(component_type):
+                component_subkey = component_subkey[len(component_type) + 1:]
             component_state_dict[component_subkey] = state["model"][key]
     mismatch_keys = []
......
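The two added lines above handle checkpoints whose keys carry the component prefix twice (e.g. an encoder saved from a model that itself wrapped an encoder), stripping the prefix a second time before the subkey is matched against the component's own state dict. A small hedged sketch of the same logic (the helper name is mine, not fairseq's):

def strip_component_prefix(key: str, component_type: str = "encoder") -> str:
    # "encoder.encoder.layers.0.weight" -> "layers.0.weight"
    subkey = key[len(component_type) + 1:] if key.startswith(component_type) else key
    if subkey.startswith(component_type):
        subkey = subkey[len(component_type) + 1:]
    return subkey

assert strip_component_prefix("encoder.encoder.layers.0.weight") == "layers.0.weight"
assert strip_component_prefix("encoder.layers.0.weight") == "layers.0.weight"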
...@@ -91,7 +91,7 @@ class Adapter(nn.Module):
             logger.info("CTC Compress Strategy: %s" % strategy)
         elif self.adapter_type == "league":
             self.distribution_cutoff = strategy
-            if self.distribution_cutoff != -1:
+            if self.distribution_cutoff is not None:
                 logger.info("Distribution cutoff: %d" % int(strategy))
     def forward(self, x, padding):
...@@ -112,7 +112,7 @@ class Adapter(nn.Module):
         elif self.adapter_type == "league":
             linear_out = self.linear_adapter(representation)
-            if self.distribution_cutoff != -1:
+            if self.distribution_cutoff is not None:
                 cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
                 threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
                 distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
......
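With this change the league adapter treats the distribution cutoff as unset only when it is None, and the forward pass keeps just the largest-probability vocabulary entries of the CTC distribution before mixing them in. A hedged stand-alone sketch of that cutoff step (shapes and names are assumptions for illustration):

import torch

def cutoff_distribution(distribution: torch.Tensor, cutoff: int) -> torch.Tensor:
    # distribution: (T, B, V) softmax output; zero everything below the cutoff-th largest value
    cutoff = min(cutoff, distribution.size(-1) - 1)
    threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff + 1]
    return torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))

probs = torch.softmax(torch.randn(4, 2, 100), dim=-1)
kept = (cutoff_distribution(probs, 10) > 0).sum(-1)
print(int(kept.max()))  # at most 10 entries per position survive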
...@@ -192,9 +192,34 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                 "rel_pos",
                 "rope",
                 "abs",
+                "transfer",
             ],
             help="transformer encoder self-attention layer type"
         )
+        # transfer
+        parser.add_argument(
+            "--relative-pos-enc",
+            action="store_true",
+            help="use relative position encoding for attention",
+        )
+        parser.add_argument(
+            "--linear-att",
+            action="store_true",
+            help="use linear attention",
+        )
+        # reduced attention
+        parser.add_argument(
+            "--attention-reduced-method",
+            type=str,
+            default="conv",
+            help="reduction method for attention",
+        )
+        parser.add_argument(
+            "--attention-reduced-q",
+            action="store_true",
+            help="use reduction for query or not"
+        )
         parser.add_argument(
             "--encoder-attention-heads",
             type=int,
...@@ -450,9 +475,9 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             help="the number of the attention heads in each stage",
         )
         parser.add_argument(
-            "--pds-attn-ds-ratio",
+            "--pds-attn-ds-ratios",
             type=str,
-            help="the ratio of the down-sampling in the self attention module",
+            help="the ratios of the down-sampling in the self attention module",
         )
         parser.add_argument(
             "--pds-ffn-ratios",
...@@ -495,7 +520,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
         )
         parser.add_argument(
             "--intermedia-distribution-cutoff",
-            default=-1,
+            default=None,
             type=int,
             help="cutoff of the distribution",
         )
...@@ -641,7 +666,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             pos_embed = None
             stage = nn.ModuleList([
-                PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
+                PDSTransformerEncoderLayer(args, embed_dim, ffn_ratio, num_head, attn_ds_ratio)
                 for _ in range(num_layers)])
             # representation fusion
...@@ -735,9 +760,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                        (("ctc" in getattr(args, "criterion", "")) and
                         (getattr(args, "ctc_weight", False) > 0))
         if self.use_ctc:
-            self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
-            self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
-            self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
+            # self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
+            # self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
+            # self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
+            self.ctc_layer = args.ctc_layer
+            self.inter_ctc = True if self.ctc_layer != 0 else False
             if self.inter_ctc:
                 logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
...@@ -1027,7 +1055,7 @@ def base_architecture(args):
     # intermedia CTC
     args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
     args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
-    args.ctc_self_distill = getattr(args, "ctc_self_distill", False)
+    args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
 def set_pds_base_8(args):
......
...@@ -131,11 +131,34 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -131,11 +131,34 @@ class S2TCTCModel(FairseqEncoderModel):
"relative", "relative",
"rel_pos", "rel_pos",
"rope", "rope",
"abs" "abs",
"transfer",
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
parser.add_argument( parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not",
)
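These four options expose the Efficient Conformer attention variants on the command line. A standalone argparse sketch with just the new flags, using the names and defaults from the diff:

import argparse

# Minimal sketch of the new attention-related flags added in this commit;
# only the option names and defaults come from the diff, the parser is standalone.
parser = argparse.ArgumentParser()
parser.add_argument("--relative-pos-enc", action="store_true",
                    help="use relative position encoding for attention")
parser.add_argument("--linear-att", action="store_true",
                    help="use linear attention")
parser.add_argument("--attention-reduced-method", type=str, default="conv",
                    help="reduction method for attention")
parser.add_argument("--attention-reduced-q", action="store_true",
                    help="use reduction for query or not")

args = parser.parse_args(["--linear-att", "--attention-reduced-method", "pool"])
print(args.linear_att, args.attention_reduced_method, args.attention_reduced_q)
# True pool False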
parser.add_argument(
"--encoder-attention-heads", "--encoder-attention-heads",
type=int, type=int,
metavar="N", metavar="N",
...@@ -412,7 +435,7 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -412,7 +435,7 @@ class S2TCTCModel(FairseqEncoderModel):
help="the number of the attention heads in each stage", help="the number of the attention heads in each stage",
) )
parser.add_argument( parser.add_argument(
"--pds-attn-ds-ratio", "--pds-attn-ds-ratios",
type=str, type=str,
help="the ratio of the down-sampling in the self attention module", help="the ratio of the down-sampling in the self attention module",
) )
...@@ -457,7 +480,7 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -457,7 +480,7 @@ class S2TCTCModel(FairseqEncoderModel):
) )
parser.add_argument( parser.add_argument(
"--intermedia-distribution-cutoff", "--intermedia-distribution-cutoff",
default=-1, default=None,
type=int, type=int,
help="cutoff of the distribution", help="cutoff of the distribution",
) )
...@@ -931,6 +954,26 @@ def base_architecture(args): ...@@ -931,6 +954,26 @@ def base_architecture(args):
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None) args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear") args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
args.pds_layers = getattr(args, "pds_layers", None)
args.pds_ratios = getattr(args, "pds_ratios", None)
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC # intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None) args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None) args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
......
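The new PDS defaults follow fairseq's usual getattr pattern so configs written before this commit keep working. A small illustration with a hypothetical old config:

from argparse import Namespace

# Illustration of the getattr-based defaulting used in base_architecture:
# an "old" config without the PDS fields simply falls back to the defaults.
args = Namespace(dropout=0.1)
args.pds_stages = getattr(args, "pds_stages", None)
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.ctc_layer = getattr(args, "ctc_layer", 0)
print(args.pds_stages, args.pds_ds_method, args.pds_dropout, args.ctc_layer)
# None conv 0.1 0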
...@@ -5,13 +5,11 @@ import torch ...@@ -5,13 +5,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import checkpoint_utils from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import ( from fairseq.models.speech_to_text import (
S2TTransformerModel, S2TTransformerModel,
S2TTransformerEncoder, S2TTransformerEncoder,
...@@ -314,12 +312,12 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -314,12 +312,12 @@ class S2TSATEEncoder(FairseqEncoder):
if args.adapter == "shrink": if args.adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg") strategy = getattr(args, "ctc_compress_strategy", "avg")
elif args.adapter == "league": elif args.adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1) strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(args.encoder_embed_dim, self.adapter = Adapter(args.encoder_embed_dim,
args.adapter, args.adapter,
task.source_dictionary, task.source_dictionary,
embed_tokens, embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy) strategy=strategy)
if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"): if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"):
......
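The adapter now receives the embedding table only when the source and target dictionaries match, otherwise it gets None. A minimal sketch of that guard; the function name is illustrative, only the condition comes from the diff:

# Sketch of the sharing guard added to S2TSATEEncoder: reuse the embedding table
# for the adapter only when both sides use the same vocabulary.
def adapter_embed_tokens(embed_tokens, source_dictionary, target_dictionary):
    return embed_tokens if source_dictionary == target_dictionary else None

print(adapter_embed_tokens("emb", ["a", "b"], ["a", "b"]))   # emb
print(adapter_embed_tokens("emb", ["a", "b"], ["x", "y"]))   # None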
...@@ -385,7 +385,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -385,7 +385,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
) )
parser.add_argument( parser.add_argument(
"--intermedia-distribution-cutoff", "--intermedia-distribution-cutoff",
default=-1, default=None,
type=int, type=int,
help="cutoff of the distribution", help="cutoff of the distribution",
) )
...@@ -581,7 +581,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -581,7 +581,7 @@ class S2TTransformerEncoder(FairseqEncoder):
if args.intermedia_adapter == "shrink": if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None) strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league": elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1) strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter, self.adapter = Adapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy) task.source_dictionary, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
......
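intermedia_drop_prob defaults to 0 here and in base_architecture. A hedged sketch of how such a probability is typically used, randomly skipping the intermediate branch for a batch during training; the control flow below is illustrative only, not the repo's implementation:

import random

# Hedged sketch: with probability intermedia_drop_prob, skip the intermediate
# CTC/adapter branch for this forward pass. Only the parameter name and its
# default of 0 come from the diff.
intermedia_drop_prob = 0.2

def maybe_run_intermedia_branch(x, training=True):
    if training and random.random() < intermedia_drop_prob:
        return x                      # branch dropped for this batch
    # ... otherwise run the adapter / intermediate CTC on x ...
    return x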
...@@ -8,6 +8,7 @@ from .squeeze_excitation import SEAttention ...@@ -8,6 +8,7 @@ from .squeeze_excitation import SEAttention
from .activations import swish, Swish from .activations import swish, Swish
from .adaptive_input import AdaptiveInput from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax from .adaptive_softmax import AdaptiveSoftmax
from .attention import MultiHeadSelfAttentionModule
from .beamable_mm import BeamableMM from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder from .character_token_embedder import CharacterTokenEmbedder
from .downsample_convolution import DownSampleConvolutionModule from .downsample_convolution import DownSampleConvolutionModule
...@@ -91,6 +92,7 @@ __all__ = [ ...@@ -91,6 +92,7 @@ __all__ = [
"LinearizedConvolution", "LinearizedConvolution",
"LocalMultiheadAttention", "LocalMultiheadAttention",
"MultiheadAttention", "MultiheadAttention",
"MultiHeadSelfAttentionModule",
"PositionalEmbedding", "PositionalEmbedding",
"PDSTransformerEncoderLayer", "PDSTransformerEncoderLayer",
"ReducedMultiheadAttention", "ReducedMultiheadAttention",
......
...@@ -43,6 +43,7 @@ def get_activation_class(activation: str, dim=None): ...@@ -43,6 +43,7 @@ def get_activation_class(activation: str, dim=None):
else: else:
raise RuntimeError("activation function {} not supported".format(activation)) raise RuntimeError("activation function {} not supported".format(activation))
def swish(x: torch.Tensor) -> torch.Tensor: def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
......
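For reference, the swish activation kept above is simply x * sigmoid(x); a quick numeric check:

import torch

# swish(x) = x * torch.sigmoid(x), as defined above
x = torch.tensor([-1.0, 0.0, 1.0])
print(x * torch.sigmoid(x))   # tensor([-0.2689, 0.0000, 0.7311])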
...@@ -281,7 +281,7 @@ class StridedMultiHeadAttention(MultiHeadAttention): ...@@ -281,7 +281,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
"""Strided Multi-Head Attention Layer """Strided Multi-Head Attention Layer
Strided multi-head attention performs global sequence downsampling by striding Strided multi-head attention performs global sequence downsampling by striding
the attention query before aplying scaled dot-product attention. This results in the attention query before applying scaled dot-product attention. This results in
strided attention maps where query positions can attend to the entire sequence strided attention maps where query positions can attend to the entire sequence
context to perform downsampling. context to perform downsampling.
...@@ -1321,7 +1321,7 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1321,7 +1321,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
Args: Args:
dim_model: model feature dimension dim_model: model feature dimension
num_heads: number of attention heads num_heads: number of attention heads
Pdrop: residual dropout probability dropout: residual dropout probability
max_pos_encoding: maximum position max_pos_encoding: maximum position
relative_pos_enc: whether to use relative position embedding relative_pos_enc: whether to use relative position embedding
causal: True for causal attention with masked future context causal: True for causal attention with masked future context
...@@ -1335,14 +1335,14 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1335,14 +1335,14 @@ class MultiHeadSelfAttentionModule(nn.Module):
def __init__(self, def __init__(self,
dim_model, dim_model,
num_heads, num_heads,
Pdrop, dropout,
max_pos_encoding, max_pos_encoding,
relative_pos_enc, relative_pos_enc=False,
causal, causal=False,
group_size, group_size=1,
kernel_size, kernel_size=None,
stride, stride=1,
linear_att): linear_att=False):
super(MultiHeadSelfAttentionModule, self).__init__() super(MultiHeadSelfAttentionModule, self).__init__()
# Assert # Assert
...@@ -1351,7 +1351,7 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1351,7 +1351,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings" assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"
# Pre Norm # Pre Norm
self.norm = nn.LayerNorm(dim_model, eps=1e-6) # self.norm = nn.LayerNorm(dim_model, eps=1e-6)
# Multi-Head Linear Attention # Multi-Head Linear Attention
if linear_att: if linear_att:
...@@ -1394,7 +1394,7 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1394,7 +1394,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
self.mhsa = MultiHeadAttention(dim_model, num_heads) self.mhsa = MultiHeadAttention(dim_model, num_heads)
# Dropout # Dropout
self.dropout = nn.Dropout(Pdrop) # self.dropout = nn.Dropout(Pdrop)
# Module Params # Module Params
self.rel_pos_enc = relative_pos_enc self.rel_pos_enc = relative_pos_enc
...@@ -1402,8 +1402,9 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1402,8 +1402,9 @@ class MultiHeadSelfAttentionModule(nn.Module):
def forward(self, x, mask=None, hidden=None): def forward(self, x, mask=None, hidden=None):
# Pre Norm x = x.transpose(0, 1)
x = self.norm(x) if mask is not None:
mask = mask.view(mask.size(0), 1, 1, mask.size(-1))
# Multi-Head Self-Attention # Multi-Head Self-Attention
if self.linear_att: if self.linear_att:
...@@ -1414,6 +1415,7 @@ class MultiHeadSelfAttentionModule(nn.Module): ...@@ -1414,6 +1415,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
x, attention = self.mhsa(x, x, x, mask) x, attention = self.mhsa(x, x, x, mask)
# Dropout # Dropout
x = self.dropout(x) # x = self.dropout(x)
return x, attention, hidden x = x.transpose(0, 1)
return x, attention
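The forward pass now converts fairseq's (seq_len, batch, dim) layout to (batch, seq_len, dim), reshapes the padding mask to a broadcastable (batch, 1, 1, seq_len), and returns only (x, attention). A shape-only sketch with the attention call itself mocked out:

import torch

# Hedged sketch of the shape handling added to MultiHeadSelfAttentionModule.forward;
# the actual self.mhsa(...) call is replaced by a placeholder here.
def forward(x, mask=None):
    x = x.transpose(0, 1)                              # (T, B, C) -> (B, T, C)
    if mask is not None:
        mask = mask.view(mask.size(0), 1, 1, mask.size(-1))
    attention = None                                   # placeholder for self.mhsa(x, x, x, mask)
    x = x.transpose(0, 1)                              # back to (T, B, C)
    return x, attention

out, attn = forward(torch.randn(5, 2, 16), torch.zeros(2, 5, dtype=torch.bool))
print(out.shape)   # torch.Size([5, 2, 16])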
...@@ -10,17 +10,18 @@ class ConvolutionModule(nn.Module): ...@@ -10,17 +10,18 @@ class ConvolutionModule(nn.Module):
def __init__( def __init__(
self, self,
embed_dim, embed_dim,
channels, expand_embed_dim,
depthwise_kernel_size, depthwise_kernel_size,
dropout, dropout,
activation_fn="swish", activation_fn="swish",
bias=False, bias=False,
stride=1,
export=False, export=False,
): ):
""" """
Args: Args:
embed_dim: Embedding dimension embed_dim: Embedding dimension
channels: Number of channels in depthwise conv layers expand_embed_dim: Number of output embedding dimension
depthwise_kernel_size: Depthwise conv layer kernel size depthwise_kernel_size: Depthwise conv layer kernel size
dropout: dropout value dropout: dropout value
activation_fn: Activation function to use after depthwise convolution kernel activation_fn: Activation function to use after depthwise convolution kernel
...@@ -33,7 +34,7 @@ class ConvolutionModule(nn.Module): ...@@ -33,7 +34,7 @@ class ConvolutionModule(nn.Module):
) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding" ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
self.pointwise_conv1 = torch.nn.Conv1d( self.pointwise_conv1 = torch.nn.Conv1d(
embed_dim, embed_dim,
2 * channels, 2 * expand_embed_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
...@@ -41,19 +42,19 @@ class ConvolutionModule(nn.Module): ...@@ -41,19 +42,19 @@ class ConvolutionModule(nn.Module):
) )
self.glu = torch.nn.GLU(dim=1) self.glu = torch.nn.GLU(dim=1)
self.depthwise_conv = torch.nn.Conv1d( self.depthwise_conv = torch.nn.Conv1d(
channels, expand_embed_dim,
channels, expand_embed_dim,
depthwise_kernel_size, depthwise_kernel_size,
stride=1, stride=stride,
padding=(depthwise_kernel_size - 1) // 2, padding=(depthwise_kernel_size - 1) // 2,
groups=channels, groups=expand_embed_dim,
bias=bias, bias=bias,
) )
self.batch_norm = nn.BatchNorm1d(channels) self.batch_norm = nn.BatchNorm1d(expand_embed_dim)
self.activation = get_activation_class(activation_fn) self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d( self.pointwise_conv2 = torch.nn.Conv1d(
channels, expand_embed_dim,
embed_dim, expand_embed_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
...@@ -72,8 +73,8 @@ class ConvolutionModule(nn.Module): ...@@ -72,8 +73,8 @@ class ConvolutionModule(nn.Module):
x = x.transpose(1, 2) x = x.transpose(1, 2)
# GLU mechanism # GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim) x = self.pointwise_conv1(x) # (batch, 2*expand_embed_dim, dim)
x = self.glu(x) # (batch, channel, dim) x = self.glu(x) # (batch, expand_embed_dim, dim)
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
...@@ -81,10 +82,13 @@ class ConvolutionModule(nn.Module): ...@@ -81,10 +82,13 @@ class ConvolutionModule(nn.Module):
x = self.activation(x) x = self.activation(x)
x = self.pointwise_conv2(x) x = self.pointwise_conv2(x)
x = x.transpose(1, 2)
x = self.dropout(x) x = self.dropout(x)
return x.transpose(1, 2)
# return x
# class ConvolutionModule(nn.Module): # class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model.""" # """ConvolutionModule in Conformer model."""
# def __init__(self, # def __init__(self,
......
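A minimal standalone version of the reworked module: 1x1 conv to 2 * expand_embed_dim, GLU, strided depthwise conv, batch norm, activation, 1x1 conv, and dropout after transposing back. Names follow the diff; nn.SiLU stands in for the repo's swish helper and the sizes are arbitrary:

import torch
import torch.nn as nn

# Standalone sketch of the reshaped ConvolutionModule (simplified, assumptions noted above).
class MiniConvModule(nn.Module):
    def __init__(self, embed_dim, expand_embed_dim, kernel_size=15, stride=1, dropout=0.1):
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * expand_embed_dim, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            expand_embed_dim, expand_embed_dim, kernel_size,
            stride=stride, padding=(kernel_size - 1) // 2, groups=expand_embed_dim)
        self.batch_norm = nn.BatchNorm1d(expand_embed_dim)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(expand_embed_dim, expand_embed_dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                       # x: (batch, time, embed_dim)
        x = x.transpose(1, 2)                   # (batch, embed_dim, time)
        x = self.glu(self.pointwise_conv1(x))   # (batch, expand_embed_dim, time)
        x = self.depthwise_conv(x)              # strided -> (batch, expand_embed_dim, time // stride)
        x = self.activation(self.batch_norm(x))
        x = self.pointwise_conv2(x)
        return self.dropout(x.transpose(1, 2))  # (batch, time // stride, expand_embed_dim)

m = MiniConvModule(embed_dim=64, expand_embed_dim=128, stride=2)
print(m(torch.randn(2, 20, 64)).shape)   # torch.Size([2, 10, 128])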
from typing import Optional from typing import Optional
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from fairseq.modules import ( from fairseq.modules import (
LayerNorm, LayerNorm,
MultiheadAttention, MultiheadAttention,
...@@ -14,10 +16,11 @@ from fairseq.modules import ( ...@@ -14,10 +16,11 @@ from fairseq.modules import (
LocalMultiheadAttention, LocalMultiheadAttention,
ReducedMultiheadAttention, ReducedMultiheadAttention,
RotaryPositionMultiHeadedAttention, RotaryPositionMultiHeadedAttention,
MultiHeadSelfAttentionModule,
) )
from fairseq.modules.s2t_transformer_layer import FeedForwardModule from fairseq.modules.s2t_transformer_layer import FeedForwardModule
from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor from .utils import Transpose, Permute3D
class PDSTransformerEncoderLayer(nn.Module): class PDSTransformerEncoderLayer(nn.Module):
...@@ -35,29 +38,48 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -35,29 +38,48 @@ class PDSTransformerEncoderLayer(nn.Module):
args (argparse.Namespace): parsed command-line arguments args (argparse.Namespace): parsed command-line arguments
""" """
def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1): def __init__(self, args,
embed_dim,
ffn_ratio,
num_head,
attn_sample_ratio=1,
attn_stride=1,
conv_stride=1,
expand_embed_dim=None):
super().__init__() super().__init__()
self.args = args self.args = args
embed_dim = embed_dim embed_dim = embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout dropout = args.dropout
self.quant_noise = getattr(args, 'quant_noise_pq', 0) if expand_embed_dim is None:
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 expand_embed_dim = embed_dim
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(args, embed_dim, num_head, att_sample_ratio)
self.self_attn_layer_norm = LayerNorm(embed_dim)
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__ dropout, module_name=self.__class__.__name__
) )
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.normalize_before = args.encoder_normalize_before self.normalize_before = args.encoder_normalize_before
activation = getattr(args, 'encoder_activation_fn', 'relu') activation = getattr(args, 'encoder_activation_fn', 'relu')
# attention
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(args, embed_dim, num_head, attn_sample_ratio)
self.self_attn_layer_norm = LayerNorm(embed_dim)
# Attention Residual
self.attn_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.MaxPool1d(kernel_size=1, stride=attn_stride),
Permute3D(2, 0, 1)
) if attn_stride > 1 else nn.Identity()
if args.macaron_style: if args.macaron_style:
self.macaron_ffn = FeedForwardModule( self.macaron_ffn = FeedForwardModule(
embed_dim, embed_dim,
ffn_dim, embed_dim * ffn_ratio,
dropout, dropout,
dropout, dropout,
activation activation
...@@ -73,24 +95,37 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -73,24 +95,37 @@ class PDSTransformerEncoderLayer(nn.Module):
self.conv_norm = LayerNorm(embed_dim) self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule( self.conv_module = ConvolutionModule(
embed_dim, embed_dim,
embed_dim, expand_embed_dim,
depthwise_kernel_size=args.cnn_module_kernel, depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout, dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish')) activation_fn=activation,
self.final_norm = LayerNorm(embed_dim) stride=conv_stride
)
self.final_norm = LayerNorm(expand_embed_dim)
# Convolution Residual
self.conv_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
Permute3D(2, 0, 1)
) if embed_dim != expand_embed_dim else nn.Sequential(
Permute3D(1, 2, 0),
nn.MaxPool1d(kernel_size=1, stride=conv_stride),
Permute3D(2, 0, 1)
) if conv_stride > 1 else nn.Identity()
else: else:
self.conv_norm = None self.conv_norm = None
self.conv_module = None self.conv_module = None
self.final_norm = None self.final_norm = None
self.ffn = FeedForwardModule( self.ffn = FeedForwardModule(
embed_dim, expand_embed_dim,
ffn_dim, expand_embed_dim * ffn_ratio,
dropout, dropout,
dropout, dropout,
activation activation
) )
self.ffn_norm = LayerNorm(embed_dim) self.ffn_norm = LayerNorm(expand_embed_dim)
def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1): def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1):
attention_heads = num_head attention_heads = num_head
...@@ -165,6 +200,17 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -165,6 +200,17 @@ class PDSTransformerEncoderLayer(nn.Module):
q_noise=self.quant_noise, q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size, qn_block_size=self.quant_noise_block_size,
sample_ratio=sample_ratio, sample_ratio=sample_ratio,
reduced_method=getattr(args, "attention_reduced_method", "conv"),
reduced_q=getattr(args, "attention_reduced_q", False)
)
elif self.attn_type == "transfer":
return MultiHeadSelfAttentionModule(
embed_dim,
attention_heads,
dropout,
max_pos_encoding=args.max_source_positions,
relative_pos_enc=getattr(args, "relative_pos_enc", False),
linear_att=getattr(args, "linear_att", False),
) )
else: else:
print("The encoder attention type %s is not supported!" % self.attn_type) print("The encoder attention type %s is not supported!" % self.attn_type)
...@@ -248,6 +294,10 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -248,6 +294,10 @@ class PDSTransformerEncoderLayer(nn.Module):
attn_mask=attn_mask, attn_mask=attn_mask,
pos_emb=pos_emb pos_emb=pos_emb
) )
elif self.attn_type == "transfer":
x, _ = self.self_attn(
x, encoder_padding_mask
)
else: else:
x, _ = self.self_attn( x, _ = self.self_attn(
query=x, query=x,
...@@ -258,7 +308,7 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -258,7 +308,7 @@ class PDSTransformerEncoderLayer(nn.Module):
attn_mask=attn_mask, attn_mask=attn_mask,
) )
x = self.dropout_module(x) x = self.dropout_module(x)
x = self.residual_connection(x, residual) x = self.residual_connection(self.attn_res(x), residual)
if not self.normalize_before: if not self.normalize_before:
x = self.self_attn_layer_norm(x) x = self.self_attn_layer_norm(x)
......
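The new attn_res and conv_res branches keep the residual path consistent with a downsampling branch: a MaxPool1d with kernel_size=1 subsamples the residual in time, and a 1x1 Conv1d with the same stride additionally widens it when expand_embed_dim differs from embed_dim (when the dims match, the diff falls back to the pooling variant, or Identity without striding). A standalone sketch with the Permute3D helper and assumed dimensions:

import torch
import torch.nn as nn

class Permute3D(nn.Module):
    # same helper as in the diff's utils
    def __init__(self, d0, d1, d2):
        super().__init__()
        self.dims = (d0, d1, d2)
    def forward(self, x):
        return x.permute(*self.dims)

embed_dim, expand_embed_dim, attn_stride, conv_stride = 64, 96, 2, 2

# Strided residual for the attention branch: keep every attn_stride-th frame.
attn_res = nn.Sequential(
    Permute3D(1, 2, 0),                                   # (T, B, C) -> (B, C, T)
    nn.MaxPool1d(kernel_size=1, stride=attn_stride),
    Permute3D(2, 0, 1),                                    # back to (T, B, C)
) if attn_stride > 1 else nn.Identity()

# Strided residual for the conv branch when the embedding width changes.
conv_res = nn.Sequential(
    Permute3D(1, 2, 0),
    nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
    Permute3D(2, 0, 1),
)

x = torch.randn(20, 2, embed_dim)                          # (T, B, C)
print(attn_res(x).shape, conv_res(x).shape)                # (10, 2, 64) and (10, 2, 96)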
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math import math
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.layer_norm import LayerNorm
from fairseq.modules.quant_noise import quant_noise from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import Parameter from torch.nn import Parameter
...@@ -38,6 +33,8 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -38,6 +33,8 @@ class ReducedMultiheadAttention(nn.Module):
q_noise=0.0, q_noise=0.0,
qn_block_size=8, qn_block_size=8,
sample_ratio=1, sample_ratio=1,
reduced_method="conv",
reduced_q=False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -85,13 +82,25 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -85,13 +82,25 @@ class ReducedMultiheadAttention(nn.Module):
self.add_zero_attn = add_zero_attn self.add_zero_attn = add_zero_attn
self.sample_ratio = sample_ratio self.sample_ratio = sample_ratio
self.reduced_method = reduced_method
self.reduced_q = reduced_q
if reduced_q:
assert self.reduced_method == 'group', "only the grouped method is supported for query reduction"
if self.sample_ratio > 1: if self.sample_ratio > 1:
self.sr = nn.Conv1d(embed_dim, embed_dim, if reduced_method == "conv":
kernel_size=sample_ratio, self.sr = nn.Conv1d(embed_dim, embed_dim,
stride=sample_ratio, kernel_size=sample_ratio,
# padding=(sample_ratio - 1) // 2 stride=sample_ratio,
) # padding=(sample_ratio - 1) // 2
self.norm = nn.LayerNorm(embed_dim) )
self.norm = LayerNorm(embed_dim)
elif reduced_method == "pool":
self.linear = nn.Linear(embed_dim, embed_dim)
self.norm = LayerNorm(embed_dim)
self.act = nn.GELU()
elif reduced_method == "group":
pass
self.reset_parameters() self.reset_parameters()
...@@ -159,41 +168,6 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -159,41 +168,6 @@ class ReducedMultiheadAttention(nn.Module):
assert embed_dim == self.embed_dim assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim] assert list(query.size()) == [tgt_len, bsz, embed_dim]
if (
self.sample_ratio == 1 and
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None: if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state) saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state: if saved_state is not None and "prev_key" in saved_state:
...@@ -205,16 +179,41 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -205,16 +179,41 @@ class ReducedMultiheadAttention(nn.Module):
else: else:
saved_state = None saved_state = None
q = self.q_proj(query) # only support self attention
if self.self_attention: if self.self_attention:
query_ = query
if self.sample_ratio > 1: if self.sample_ratio > 1:
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len: assert tgt_len % self.sample_ratio == 0, \
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim ("sample ratio %d is mismatched with length %d" % (self.sample_ratio, tgt_len))
query = self.norm(query_) if self.reduced_method == "conv":
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query_ = self.norm(query_)
elif self.reduced_method == "pool":
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len:
pool_length = int(tgt_len / self.sample_ratio)
query_ = nn.functional.adaptive_max_pool1d(query_, pool_length).permute(2, 0, 1)
query_ = self.act(self.norm(query_))
if self.reduced_q:
q = self.q_proj(query_)
tgt_len = int(tgt_len / self.sample_ratio)
else:
q = self.q_proj(query)
k = self.k_proj(query_)
v = self.v_proj(query_)
if self.sample_ratio > 1 and self.reduced_method == "group":
assert self.reduced_q is True
self.head_dim *= self.sample_ratio
q = q.transpose(0, 1).contiguous().view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
k = q.transpose(0, 1).view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
v = q.transpose(0, 1).view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention: elif self.encoder_decoder_attention:
q = self.q_proj(query)
# encoder-decoder attention # encoder-decoder attention
if key is None: if key is None:
assert value is None assert value is None
...@@ -224,10 +223,12 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -224,10 +223,12 @@ class ReducedMultiheadAttention(nn.Module):
v = self.v_proj(key) v = self.v_proj(key)
else: else:
q = self.q_proj(query)
assert key is not None and value is not None assert key is not None and value is not None
k = self.k_proj(key) k = self.k_proj(key)
v = self.v_proj(value) v = self.v_proj(value)
q *= self.scaling # q *= self.scaling
q *= (self.head_dim ** -0.5)
if self.bias_k is not None: if self.bias_k is not None:
assert self.bias_v is not None assert self.bias_v is not None
...@@ -313,13 +314,15 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -313,13 +314,15 @@ class ReducedMultiheadAttention(nn.Module):
if key_padding_mask is not None: if key_padding_mask is not None:
if self.sample_ratio > 1: if self.sample_ratio > 1:
lengths = (~key_padding_mask).sum(-1) key_padding_mask = key_padding_mask[:, ::self.sample_ratio]
lengths = (lengths / self.sample_ratio).long()
# lengths = ((lengths.float() - 1) / self.sample_ratio + 1).floor().long() # lengths = (~key_padding_mask).sum(-1)
max_length = src_len # lengths = (lengths / self.sample_ratio).long()
assert max_length >= max(lengths), (max_length, max(lengths)) # # lengths = ((lengths.float() - 1) / self.sample_ratio + 1).floor().long()
mask = torch.arange(max_length).to(lengths.device).view(1, max_length) # max_length = src_len
key_padding_mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_length) # assert max_length >= max(lengths), (max_length, max(lengths))
# mask = torch.arange(max_length).to(lengths.device).view(1, max_length)
# key_padding_mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_length)
assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len assert key_padding_mask.size(1) == src_len
...@@ -380,6 +383,10 @@ class ReducedMultiheadAttention(nn.Module): ...@@ -380,6 +383,10 @@ class ReducedMultiheadAttention(nn.Module):
assert v is not None assert v is not None
attn = torch.bmm(attn_probs, v) attn = torch.bmm(attn_probs, v)
if self.sample_ratio > 1 and self.reduced_q:
tgt_len = attn.size(1) * self.sample_ratio
self.head_dim = int(self.head_dim / self.sample_ratio)
attn = attn.view(bsz * self.num_heads, tgt_len, self.head_dim)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1: if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1) # when ONNX tracing a single decoder step (sequence length == 1)
......
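A hedged sketch of the two non-grouped reduction paths in ReducedMultiheadAttention: the key/value sequence is shortened by sample_ratio with either a strided Conv1d ("conv") or adaptive max pooling ("pool"), and the padding mask is now simply subsampled to match. Names and recipes follow the diff; the multi-head projections and attention math are omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, sample_ratio, reduced_method = 32, 2, "conv"
sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
norm = nn.LayerNorm(embed_dim)
act = nn.GELU()

def reduce_sequence(query):                       # query: (tgt_len, bsz, embed_dim)
    tgt_len = query.size(0)
    assert tgt_len % sample_ratio == 0, "sample ratio mismatched with length"
    if reduced_method == "conv":
        q = query.permute(1, 2, 0)                # (bsz, embed_dim, tgt_len)
        return norm(sr(q).permute(2, 0, 1))       # (tgt_len // ratio, bsz, embed_dim)
    elif reduced_method == "pool":
        q = query.permute(1, 2, 0)
        q = F.adaptive_max_pool1d(q, tgt_len // sample_ratio).permute(2, 0, 1)
        return act(norm(q))

query = torch.randn(8, 2, embed_dim)
key_padding_mask = torch.zeros(2, 8, dtype=torch.bool)
print(reduce_sequence(query).shape)               # torch.Size([4, 2, 32])
print(key_padding_mask[:, ::sample_ratio].shape)  # torch.Size([2, 4])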
import torch
from torch import nn as nn
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return x.transpose(self.dim0, self.dim1)
class Permute3D(nn.Module):
def __init__(self, dim0, dim1, dim2):
super(Permute3D, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x):
return x.permute(self.dim0, self.dim1, self.dim2)
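Usage sketch for the helpers above: they make tensor reshuffles composable inside nn.Sequential, which is how the new residual branches use them. The block assumes the Transpose class defined just above; the dimensions are arbitrary:

import torch
import torch.nn as nn

# Channel-first ops like Conv1d can sit inside a Sequential that otherwise
# works on (batch, time, channels) tensors.
block = nn.Sequential(
    Transpose(1, 2),                             # (B, T, C) -> (B, C, T)
    nn.Conv1d(16, 16, kernel_size=3, padding=1),
    Transpose(1, 2),                             # back to (B, T, C)
)
print(block(torch.randn(4, 10, 16)).shape)       # torch.Size([4, 10, 16])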