Commit 0bd92062 by xuchen

Optimize the shell scripts for IWSLT2022 En-Zh and implement the Efficient Conformer method

parent 55702466
......@@ -129,8 +129,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
......@@ -23,10 +23,10 @@ asr_vocab_prefix=spm_unigram10000_st_share
src_lang=en
tgt_lang=zh
subsets=(train_covost)
subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON)
mkdir -p $data_dir
splits=$(echo ${subsets[*]} | sed 's/ /_/g')
splits=$(echo ${subsets[*]} | sed 's/ /,/g')
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
......
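After this change the subsets are joined with commas (sed 's/ /,/g') instead of underscores, matching the comma-separated --splits argument that prep_audio_data.py already expects (its default is "train,dev,test"). A minimal Python sketch of how such a value is consumed, for reference only:

    # hypothetical sketch: split a comma-separated --splits value into subset names
    splits_arg = "train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,dev,tst-COMMON"
    splits = [s for s in splits_arg.split(",") if s]
    for split in splits:
        print("preparing subset:", split)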
train-subset: train
#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh
train-subset: train_mustc_enzh
valid-subset: dev
max-epoch: 100
max-update: 100000
max-update: 1000000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
......
......@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......
......@@ -30,10 +30,10 @@ pwd_dir=$PWD
# dataset
src_lang=en
tgt_lang=de
tgt_lang=zh
lang=${src_lang}-${tgt_lang}
dataset=mustc
dataset=iwslt2022
task=speech_to_text
vocab_type=unigram
vocab_size=5000
......@@ -42,7 +42,7 @@ lcrm=0
tokenizer=0
use_raw_audio=0
use_specific_dict=1
use_specific_dict=0
specific_prefix=st
specific_dir=${root_dir}/data/mustc/st
asr_vocab_prefix=spm_unigram10000_st_share
......@@ -125,8 +125,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
......@@ -4,7 +4,7 @@
gpu_num=8
update_freq=1
max_tokens=40000
max_tokens=80000
extra_tag=
extra_parameter=
......@@ -13,11 +13,11 @@ extra_parameter=
exp_tag=
config_list=(base ctc)
config_list=(purectc)
#config_list=(base ctc)
#config_list=(purectc)
#config_list=(base conformer)
#config_list=(pds_base_16)
config_list=(pds_base_8 ctc)
#config_list=(pds_base_16 conformer rpr)
# exp full name
......
......@@ -30,15 +30,17 @@ pwd_dir=$PWD
# dataset
src_lang=en
tgt_lang=de
tgt_lang=zh
lang=${src_lang}-${tgt_lang}
dataset=mustc
dataset=iwslt2022
task=translation
vocab_type=unigram
vocab_size=10000
share_dict=1
lcrm=0
src_vocab_type=unigram
tgt_vocab_type=unigram
src_vocab_size=32000
tgt_vocab_size=32000
share_dict=0
lcrm=1
tokenizer=0
use_specific_dict=1
......@@ -49,7 +51,7 @@ tgt_vocab_prefix=spm_unigram10000_st_share
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt
train_subset=train
train_subset=train_mustc_enzh
valid_subset=dev
trans_subset=tst-COMMON
test_subset=test
......@@ -82,15 +84,23 @@ if [[ ${use_specific_dict} -eq 1 ]]; then
data_dir=${data_dir}/${specific_prefix}
mkdir -p ${data_dir}
else
if [[ "${vocab_type}" == "char" ]]; then
vocab_name=${vocab_type}
exp_prefix=${exp_prefix}_${vocab_type}
if [[ "${tgt_vocab_type}" == "char" ]]; then
vocab_name=char
exp_prefix=${exp_prefix}_char
else
vocab_name=${vocab_type}${vocab_size}
if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" != "${tgt_vocab_type}" ]]; then
src_vocab_name=${src_vocab_type}${src_vocab_size}
tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
vocab_name=${src_vocab_name}_${tgt_vocab_name}
else
vocab_name=${tgt_vocab_type}${tgt_vocab_size}
src_vocab_name=${vocab_name}
tgt_vocab_name=${vocab_name}
fi
fi
data_dir=${data_dir}/${vocab_name}
src_vocab_prefix=spm_${vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
if [[ $share_dict -eq 1 ]]; then
data_dir=${data_dir}_share
src_vocab_prefix=spm_${vocab_name}_share
......@@ -141,8 +151,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
--src-vocab-type ${src_vocab_type}
--tgt-vocab-type ${tgt_vocab_type}
--src-vocab-size ${src_vocab_size}
--tgt-vocab-size ${tgt_vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
......
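The MT preparation now keeps separate source and target vocabulary settings. A minimal sketch (names illustrative, mirroring the intent of the shell logic above) of how the data directory and SentencePiece prefixes are derived:

    def vocab_names(src_type, src_size, tgt_type, tgt_size, src_lang="en", tgt_lang="zh"):
        """Mirror of the shell naming: distinct prefixes only when src/tgt settings differ."""
        if tgt_type == "char":
            vocab_name = src_name = tgt_name = "char"
        elif src_size != tgt_size or src_type != tgt_type:
            src_name = f"{src_type}{src_size}"
            tgt_name = f"{tgt_type}{tgt_size}"
            vocab_name = f"{src_name}_{tgt_name}"
        else:
            vocab_name = src_name = tgt_name = f"{tgt_type}{tgt_size}"
        return vocab_name, f"spm_{src_name}_{src_lang}", f"spm_{tgt_name}_{tgt_lang}"

    # with the defaults above (unigram/32000 on both sides):
    # ('unigram32000', 'spm_unigram32000_en', 'spm_unigram32000_zh')
    print(vocab_names("unigram", 32000, "unigram", 32000))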
train-subset: train
#train-subset: train_mustc_enzh,train_covost
train-subset: train_mustc_enzh
valid-subset: dev
max-epoch: 100
......
......@@ -29,7 +29,7 @@ acoustic-encoder: pds
adapter: league
encoder-embed-dim: 256
ctc-layer: 12
#ctc-layer: 12
pds-stages: 4
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
......
......@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then
exp_name=$1
fi
sacrebleu=1
n_average=10
sacrebleu=0
n_average=1
beam_size=5
len_penalty=1.0
max_tokens=80000
......
......@@ -30,29 +30,29 @@ pwd_dir=$PWD
# dataset
src_lang=en
tgt_lang=de
tgt_lang=zh
lang=${src_lang}-${tgt_lang}
dataset=mustc
dataset=iwslt2022
task=speech_to_text
vocab_type=unigram
asr_vocab_size=5000
vocab_size=10000
share_dict=1
share_dict=0
speed_perturb=0
lcrm=0
lcrm=1
tokenizer=0
use_raw_audio=0
use_specific_dict=0
specific_prefix=valid
specific_dir=${root_dir}/data/mustc/st
asr_vocab_prefix=spm_unigram10000_st_share
st_vocab_prefix=spm_unigram10000_st_share
use_specific_dict=1
specific_prefix=asr
specific_dir=${root_dir}/data/${dataset}/asr
asr_vocab_prefix=spm_unigram5000_asr
st_vocab_prefix=
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/st
train_split=train
train_split=train_mustc_enzh
valid_split=dev
test_split=tst-COMMON
test_subset=tst-COMMON
......@@ -133,8 +133,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
# create ASR vocabulary if necessary
......@@ -147,8 +147,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${valid_split},${test_split},${train_split}
--vocab-type ${vocab_type}
--vocab-size ${asr_vocab_size}"
[[ $eval -eq 1 && ${share_dict} -ne 1 && ${use_specific_dict} -ne 1 ]] && (echo -e "\033[34mRun command: \n${cmd} \033[0m" && eval $cmd)
asr_prefix=spm_${vocab_type}${asr_vocab_size}_asr
if [[ $eval -eq 1 && ${share_dict} -ne 1 && ${use_specific_dict} -ne 1 ]]; then
echo -e "\033[34mRun command: \n${cmd} \033[0m"
eval $cmd
asr_vocab_prefix=spm_${vocab_type}${asr_vocab_size}_asr
cp ${data_dir}/asr4st/${asr_vocab_prefix}* ${data_dir}
fi
echo "stage 0: ST Data Preparation"
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......@@ -167,25 +171,21 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd
--raw"
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}
cp -r ${specific_dir}/${st_vocab_prefix}.* ${data_dir}
if [[ $share_dict -eq 1 ]]; then
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share
--st-spm-prefix ${st_vocab_prefix}"
else
--share"
else
cmd="$cmd
--st-spm-prefix ${st_vocab_prefix}
--asr-prefix ${asr_vocab_prefix}"
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
if [[ ${share_dict} -eq 0 && -n ${asr_vocab_prefix} ]]; then
cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}
fi
else
if [[ $share_dict -eq 1 ]]; then
if [[ -n ${st_vocab_prefix} ]]; then
cp -r ${specific_dir}/${st_vocab_prefix}.* ${data_dir}
cmd="$cmd
--share"
else
cmd="$cmd
--asr-prefix ${asr_prefix}"
--st-spm-prefix ${st_vocab_prefix}"
fi
fi
if [[ ${speed_perturb} -eq 1 ]]; then
......
......@@ -14,13 +14,13 @@ extra_parameter=
exp_tag=
#config_list=(base)
config_list=(ctc)
#config_list=(sate_ctc)
#config_list=(sate ctc)
#config_list=(ctc conformer rpr)
#config_list=(base sate)
#config_list=(pds_base)
#config_list=(pds_base conformer)
config_list=(sate_pds ctc)
#config_list=(pds_base_8)
#config_list=(pds_base_8 conformer)
# exp full name
exp_name=
......
......@@ -2,6 +2,17 @@ arch: s2t_ctc
encoder-type: pds
#arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: transfer
#relative-pos-enc: True
encoder-attention-type: rel_pos
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
......
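encoder-attention-type: rel_pos selects relative positional self-attention. Assuming it follows the Transformer-XL style scoring used in Conformer (an assumption; the implementation lives in the attention modules, not in this diff), the attention logit between query position i and key position j is

    \mathrm{score}(i,j) = \frac{(q_i + u)^\top k_j + (q_i + v)^\top r_{i-j}}{\sqrt{d_k}}

where r_{i-j} is a sinusoidal embedding of the relative offset and u, v are learned bias vectors shared across positions.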
#! /bin/bash
# Processing MuST-C Datasets
# Processing LibriSpeech En-Fr Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
......@@ -124,8 +124,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
......@@ -6,7 +6,6 @@ gpu_num=2
update_freq=1
max_tokens=40000
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
......@@ -15,10 +14,9 @@ extra_parameter=
exp_tag=
#config_list=(base)
#config_list=(ctc)
#config_list=(base conformer)
#config_list=(pds_base_16)
#config_list=(pds_base_8)
config_list=(pds_base_8 conformer rpr)
# exp full name
......
#encoder-attention-type: rel_selfattn
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 20
max-decoder-relative-length: 20
max-encoder-relative-length: 8
max-decoder-relative-length: 8
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 8_4_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
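The schedules above pair lr-scheduler: inverse_sqrt with warmup-init-lr: 1e-7, warmup-updates: 10000 and lr: 2e-3. For reference, a small sketch of that schedule as fairseq applies it (linear warm-up to the peak, then decay with the inverse square root of the update number):

    def inverse_sqrt_lr(step, peak_lr=2e-3, warmup_init_lr=1e-7, warmup_updates=10000):
        """Linear warm-up to peak_lr, then peak_lr * sqrt(warmup_updates / step)."""
        if step < warmup_updates:
            return warmup_init_lr + (peak_lr - warmup_init_lr) * step / warmup_updates
        return peak_lr * (warmup_updates / step) ** 0.5

    # learning rate after 1k, 10k and 100k updates
    print([round(inverse_sqrt_lr(s), 6) for s in (1000, 10000, 100000)])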
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 176
pds-stages: 4
ctc-layer: 16
pds-layers: 4_4_4_4
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 176_176_176_176
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 4_4_4_4
pds-attn-heads: 4_4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
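This configuration (macaron-style, a depthwise CNN module with kernel 31, swish activations, rel_pos attention) is the Conformer-style encoder layer used as the baseline for the Efficient Conformer changes in this commit. A minimal PyTorch sketch of the block ordering these options imply (half-weighted feed-forward on both sides of attention and convolution); illustrative only, standing in for the repo's FeedForwardModule/ConvolutionModule and assuming a recent PyTorch:

    import torch
    import torch.nn as nn

    class ConformerBlockSketch(nn.Module):
        """Illustrative ordering: x + 1/2 FFN -> + MHSA -> + Conv -> + 1/2 FFN -> LayerNorm."""
        def __init__(self, dim=176, ffn_ratio=4, heads=4, kernel=31):
            super().__init__()
            def ffn():
                return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, ffn_ratio * dim),
                                     nn.SiLU(), nn.Linear(ffn_ratio * dim, dim))
            self.ffn1, self.ffn2 = ffn(), ffn()
            self.norm_att = nn.LayerNorm(dim)
            self.att = nn.MultiheadAttention(dim, heads, batch_first=True)
            self.norm_conv = nn.LayerNorm(dim)
            # depthwise conv stands in for the full GLU convolution module
            self.conv = nn.Conv1d(dim, dim, kernel, padding=(kernel - 1) // 2, groups=dim)
            self.norm_out = nn.LayerNorm(dim)

        def forward(self, x):                      # x: (batch, time, dim)
            x = x + 0.5 * self.ffn1(x)
            a = self.norm_att(x)
            x = x + self.att(a, a, a, need_weights=False)[0]
            c = self.norm_conv(x).transpose(1, 2)  # (batch, dim, time) for Conv1d
            x = x + self.conv(c).transpose(1, 2)
            x = x + 0.5 * self.ffn2(x)
            return self.norm_out(x)

    print(ConformerBlockSketch()(torch.randn(2, 50, 176)).shape)  # torch.Size([2, 50, 176])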
......@@ -11,12 +11,12 @@ extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
#exp_tag=
exp_tag=
#config_list=(base)
#config_list=(ctc)
#config_list=(ctc conformer rpr)
config_list=(base conformer rpr)
#config_list=(base conformer)
#config_list=(ConformerCTCSmall)
config_list=(purectc_pds_base_16)
#config_list=(pds_base)
#config_list=(pds_big)
#config_list=(pds_deep)
......
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 8_4_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......@@ -26,17 +35,12 @@ lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
encoder-attention-type: rel_selfattn
encoder-attention-type: rel_pos
#encoder-attention-type: relative
#max-encoder-relative-length: 100
......@@ -125,8 +125,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
......
......@@ -17,8 +17,8 @@ config_list=(base ctc)
config_list=(purectc)
#config_list=(base conformer)
#config_list=(pds_base_16)
#config_list=(pds_base_16 conformer rpr)
config_list=(pds_base_8)
config_list=(purectc_pds_base_8)
# exp full name
exp_name=
......
......@@ -7,7 +7,7 @@ update_freq=1
max_tokens=8192
exp_tag=baseline
config_list=(base)
config_list=(small)
# exp full name
exp_name=
......
......@@ -2,5 +2,6 @@ ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
#intermedia-drop-prob: 0.2
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
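With ctc-weight: 0.2, intermedia-ctc-layers: 6,9 and intermedia-ctc-weight: 0.1, the criterion mixes the decoder loss with a top-layer CTC loss and CTC losses computed at encoder layers 6 and 9. The exact combination is defined by the repo's criterion; a hedged sketch of one plausible weighting, for intuition only:

    def combined_loss(ce_loss, ctc_loss, inter_ctc_losses,
                      ctc_weight=0.2, inter_ctc_weight=0.1):
        """Hypothetical mix: label-smoothed CE + weighted CTC + weighted mean of intermediate CTCs."""
        inter = sum(inter_ctc_losses) / max(len(inter_ctc_losses), 1)
        return (1.0 - ctc_weight) * ce_loss + ctc_weight * ctc_loss + inter_ctc_weight * inter

    print(combined_loss(ce_loss=2.0, ctc_loss=3.0, inter_ctc_losses=[3.5, 3.2]))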
arch: pdss2t_transformer_s_8
pds-ctc: 1_1_1_1
intermedia-adapter: league
intermedia-ctc-weight: 0.15
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
#pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
......
encoder-attention-type: rel_selfattn
encoder-attention-type: rel_pos
#encoder-attention-type: rel_pos_legacy
#encoder-attention-type: rel_selfattn
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
......
......@@ -133,8 +133,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../feature_zip ]]; then
ln -s ${data_dir}/../feature_zip ${data_dir}
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
# create ASR vocabulary if necessary
......
......@@ -14,13 +14,12 @@ extra_parameter=
exp_tag=
#config_list=(base)
config_list=(ctc)
#config_list=(sate_ctc)
#config_list=(ctc conformer rpr)
#config_list=(base sate)
#config_list=(base ctc conformer)
#config_list=(sate ctc)
#config_list=(pds_base)
#config_list=(pds_base_8)
#config_list=(pds_base conformer)
#config_list=(sate_pds ctc)
# exp full name
exp_name=
......
......@@ -327,14 +327,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
export CUDA_VISIBLE_DEVICES=${device}
log=${model_dir}/train.log
if [[ -e ${log} ]]; then
for i in `seq 1 100`; do
if [ ! -e ${log}.${i} ]; then
log=${log}.${i}
break
fi
done
fi
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
......
set -e
eval=1
lcrm=0
src_lang=en
tgt_lang=zh
tokenize=1
splits=(tst-COMMON test11)
dataset=wmt20
root_dir=~/st/Fairseq-S2T
data_dir=/home/xuchen/st/data/$dataset/data
vocab_dir=/home/xuchen/st/data/$dataset/mt/unigram32000_tok
dest_dir=$vocab_dir
src_vocab_prefix=spm_unigram32000_en
tgt_vocab_prefix=spm_unigram32000_zh
for split in ${splits[@]}; do
src_file=${data_dir}/${split}/${split}.${src_lang}
tgt_file=${data_dir}/${split}/${split}.${tgt_lang}
if [[ ${tokenize} -eq 1 ]]; then
src_tok_file=${data_dir}/${split}.tok/${split}.tok.${src_lang}
tgt_tok_file=${data_dir}/${split}.tok/${split}.tok.${tgt_lang}
if [[ ! -f ${src_tok_file} ]]; then
cmd="tokenizer.perl -l ${src_lang} --threads 8 -no-escape < ${src_file} > ${src_tok_file}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
fi
if [[ ! -f ${tgt_tok_file} ]]; then
cmd="tokenizer.perl -l ${tgt_lang} --threads 8 -no-escape < ${tgt_file} > ${tgt_tok_file}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
fi
src_file=${src_tok_file}
tgt_file=${tgt_tok_file}
fi
cmd="cat ${src_file}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_file}"
fi
cmd="${cmd}
| spm_encode --model ${vocab_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${src_file}.spm"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${vocab_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_file}
> ${tgt_file}.spm"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
src_file=${src_file}.spm
tgt_file=${tgt_file}.spm
mkdir -p ${dest_dir}/final
cmd="cp ${src_file} ${dest_dir}/final/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="cp ${tgt_file} ${dest_dir}/final/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
done
n_set=${#splits[*]}
for ((i=0;i<$n_set;i++)); do
dataset[$i]=${dest_dir}/final/${splits[$i]}
done
pref=`echo ${dataset[*]} | sed 's/ /,/g'`
cmd="python ${root_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang}
--target-lang ${tgt_lang}
--testpref ${pref}
--destdir ${dest_dir}/data-bin
--srcdict ${vocab_dir}/${src_vocab_prefix}.txt
--tgtdict ${vocab_dir}/${tgt_vocab_prefix}.txt
--workers 64"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 2e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 4000
lr: 7e-4
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: False
decoder-normalize-before: False
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
train-subset: train
valid-subset: valid
max-epoch: 20
max-update: 100000
patience: 5
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 5
num-workers: 8
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
max-source-positions: 512
arch: transformer_wmt_en_de_big_t2t
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 7e-4
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 1024
encoder-ffn-embed-dim: 4096
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 16
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer_wmt_en_de_big
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 4000
lr: 5e-4
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: False
decoder-normalize-before: False
encoder-embed-dim: 1024
encoder-ffn-embed-dim: 4096
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 16
decoder-embed-dim: 1024
decoder-ffn-embed-dim: 4096
decoder-attention-heads: 16
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 16000
lr: 2e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
use-enc-dlcl: True
use-dec-dlcl: True
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 8
max-decoder-relative-length: 8
#! /bin/bash
gpu_num=1
data_dir=
test_subset=(test)
exp_name=
if [ "$#" -eq 1 ]; then
exp_name=$1
fi
sacrebleu=0
n_average=5
beam_size=4
len_penalty=0.6
max_tokens=80000
dec_model=checkpoint_best.pt
cmd="./run.sh
--stage 2
--stop_stage 2
--gpu_num ${gpu_num}
--exp_name ${exp_name}
--sacrebleu ${sacrebleu}
--n_average ${n_average}
--beam_size ${beam_size}
--len_penalty ${len_penalty}
--max_tokens ${max_tokens}
--dec_model ${dec_model}
"
if [[ -n ${data_dir} ]]; then
cmd="$cmd --data_dir ${data_dir}"
fi
if [[ -n ${test_subset} ]]; then
test_subset=`echo ${test_subset[*]} | sed 's/ /,/g'`
cmd="$cmd --test_subset ${test_subset}"
fi
echo $cmd
eval $cmd
import sys
import string
in_file = sys.argv[1]
with open(in_file, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip().lower()
for w in string.punctuation:
line = line.replace(w, "")
line = line.replace("  ", " ")
print(line)
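local/lower_rm.py implements the lcrm option used by the MT scripts (lowercase the source and strip punctuation before SentencePiece). A small self-contained demonstration of the same transformation:

    import string

    def lower_rm(line: str) -> str:
        """Lowercase, strip punctuation, and collapse the double spaces this can create."""
        line = line.strip().lower()
        for w in string.punctuation:
            line = line.replace(w, "")
        return line.replace("  ", " ")

    print(lower_rm("Hello, World! This is IWSLT 2022."))  # hello world this is iwslt 2022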
gpu_num=4
cmd="sh train.sh"
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
sleep 60s
else
echo "Run $cmd"
eval $cmd
sleep 10s
exit
fi
done
#!/usr/bin/env perl
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
# $Id$
use warnings;
use strict;
my $lowercase = 0;
if ($ARGV[0] eq "-lc") {
$lowercase = 1;
shift;
}
my $stem = $ARGV[0];
if (!defined $stem) {
print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
print STDERR "Reads the references from reference or reference0, reference1, ...\n";
exit(1);
}
$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
my @REF;
my $ref=0;
while(-e "$stem$ref") {
&add_to_ref("$stem$ref",\@REF);
$ref++;
}
&add_to_ref($stem,\@REF) if -e $stem;
die("ERROR: could not find reference file $stem") unless scalar @REF;
# add additional references explicitly specified on the command line
shift;
foreach my $stem (@ARGV) {
&add_to_ref($stem,\@REF) if -e $stem;
}
sub add_to_ref {
my ($file,$REF) = @_;
my $s=0;
if ($file =~ /.gz$/) {
open(REF,"gzip -dc $file|") or die "Can't read $file";
} else {
open(REF,$file) or die "Can't read $file";
}
while(<REF>) {
chop;
push @{$$REF[$s++]}, $_;
}
close(REF);
}
my(@CORRECT,@TOTAL,$length_translation,$length_reference);
my $s=0;
while(<STDIN>) {
chop;
$_ = lc if $lowercase;
my @WORD = split;
my %REF_NGRAM = ();
my $length_translation_this_sentence = scalar(@WORD);
my ($closest_diff,$closest_length) = (9999,9999);
foreach my $reference (@{$REF[$s]}) {
# print "$s $_ <=> $reference\n";
$reference = lc($reference) if $lowercase;
my @WORD = split(' ',$reference);
my $length = scalar(@WORD);
my $diff = abs($length_translation_this_sentence-$length);
if ($diff < $closest_diff) {
$closest_diff = $diff;
$closest_length = $length;
# print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
} elsif ($diff == $closest_diff) {
$closest_length = $length if $length < $closest_length;
# from two references with the same closeness to me
# take the *shorter* into account, not the "first" one.
}
for(my $n=1;$n<=4;$n++) {
my %REF_NGRAM_N = ();
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
my $ngram = "$n";
for(my $w=0;$w<$n;$w++) {
$ngram .= " ".$WORD[$start+$w];
}
$REF_NGRAM_N{$ngram}++;
}
foreach my $ngram (keys %REF_NGRAM_N) {
if (!defined($REF_NGRAM{$ngram}) ||
$REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
$REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
# print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
}
}
}
}
$length_translation += $length_translation_this_sentence;
$length_reference += $closest_length;
for(my $n=1;$n<=4;$n++) {
my %T_NGRAM = ();
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
my $ngram = "$n";
for(my $w=0;$w<$n;$w++) {
$ngram .= " ".$WORD[$start+$w];
}
$T_NGRAM{$ngram}++;
}
foreach my $ngram (keys %T_NGRAM) {
$ngram =~ /^(\d+) /;
my $n = $1;
# my $corr = 0;
# print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
$TOTAL[$n] += $T_NGRAM{$ngram};
if (defined($REF_NGRAM{$ngram})) {
if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
$CORRECT[$n] += $T_NGRAM{$ngram};
# $corr = $T_NGRAM{$ngram};
# print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
}
else {
$CORRECT[$n] += $REF_NGRAM{$ngram};
# $corr = $REF_NGRAM{$ngram};
# print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
}
}
# $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
# print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
}
}
$s++;
}
my $brevity_penalty = 1;
my $bleu = 0;
my @bleu=();
for(my $n=1;$n<=4;$n++) {
if (defined ($TOTAL[$n])){
$bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
# print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
}else{
$bleu[$n]=0;
}
}
if ($length_reference==0){
printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
exit(1);
}
if ($length_translation<$length_reference) {
$brevity_penalty = exp(1-$length_reference/$length_translation);
}
$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
my_log( $bleu[2] ) +
my_log( $bleu[3] ) +
my_log( $bleu[4] ) ) / 4) ;
printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
100*$bleu,
100*$bleu[1],
100*$bleu[2],
100*$bleu[3],
100*$bleu[4],
$brevity_penalty,
$length_translation / $length_reference,
$length_translation,
$length_reference;
sub my_log {
return -9999999999 unless $_[0];
return log($_[0]);
}
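For reference, the number printed by this script is the standard corpus BLEU over the modified n-gram precisions p_n accumulated above (in LaTeX):

    \mathrm{BLEU} = \mathrm{BP}\cdot\exp\Big(\tfrac{1}{4}\sum_{n=1}^{4}\log p_n\Big),\qquad
    \mathrm{BP} = \begin{cases} 1 & c > r \\ e^{\,1-r/c} & c \le r \end{cases}

where c is the total hypothesis length ($length_translation), r the effective reference length ($length_reference), and my_log maps a zero precision to a large negative value so the whole score collapses to 0.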
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
#!/usr/bin/env perl
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
use warnings;
use strict;
#binmode(STDIN, ":utf8");
#binmode(STDOUT, ":utf8");
while(<STDIN>) {
s/，/,/g;
s/。 */. /g;
s/、/,/g;
s/”/"/g;
s/“/"/g;
s/∶/:/g;
s/：/:/g;
s/？/\?/g;
s/《/"/g;
s/》/"/g;
s/）/\)/g;
s/！/\!/g;
s/（/\(/g;
s/；/;/g;
s/１/"/g;
s/」/"/g;
s/「/"/g;
s/０/0/g;
s/３/3/g;
s/２/2/g;
s/５/5/g;
s/６/6/g;
s/９/9/g;
s/７/7/g;
s/８/8/g;
s/４/4/g;
s/． */. /g;
s/～/\~/g;
s/’/\'/g;
s/…/\.\.\./g;
s/━/\-/g;
s/〈/\</g;
s/〉/\>/g;
s/【/\[/g;
s/】/\]/g;
s/％/\%/g;
print $_;
}
get_devices(){
gpu_num=$1
use_cpu=$2
device=()
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
if [[ $use_cpu -eq 1 ]]; then
device=(-1)
else
sleep 60s
fi
else
break
fi
done
echo ${device[*]} | sed 's/ /,/g'
return $?
}
#! /bin/bash
# calculate wmt14 en-de multi-bleu score
if [ $# -ne 1 ]; then
echo "usage: $0 GENERATE_PY_OUTPUT"
exit 1
fi
echo -e "\n RUN >> "$0
requirement_scripts=(detokenizer.perl replace-unicode-punctuation.perl tokenizer.perl multi-bleu.perl)
for script in ${requirement_scripts[@]}; do
if ! which ${script} > /dev/null; then
echo "Error: it seems that moses is not installed or exported int the environment variables." >&2
return 1
fi
done
detokenizer=detokenizer.perl
replace_unicode_punctuation=replace-unicode-punctuation.perl
tokenizer=tokenizer.perl
multi_bleu=multi-bleu.perl
GEN=$1
SYS=$GEN.sys
REF=$GEN.ref
cat $GEN | cut -f 3 > $REF
cat $GEN | cut -f 4 > $SYS
# detokenize the decoded output so it can be re-tokenized in a consistent manner
$detokenizer -l de < $SYS > $SYS.dtk
$detokenizer -l de < $REF > $REF.dtk
#replace unicode
$replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
$replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
# tokenize the detokenized output with the Moses tokenizer.perl
$tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
$tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
#"rich-text format" --> rich ##AT##-##AT## text format.
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $SYS.dtk.punc.tok > $SYS.dtk.punc.tok.atat
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $REF.dtk.punc.tok > $REF.dtk.punc.tok.atat
$multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
rm -f $SYS.dtk $SYS.dtk.punc $SYS.dtk.punc.tok $REF.dtk $REF.dtk.punc $REF.dtk.punc.tok
\ No newline at end of file
#! /bin/bash
# training the model
gpu_num=8
update_freq=2
max_tokens=8192
exp_tag=baseline
#config_list=(base)
config_list=(deep)
# exp full name
exp_name=
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh
--stage 1
--stop_stage 1
--gpu_num ${gpu_num}
--update_freq ${update_freq}
--train_config ${train_config}
--max_tokens ${max_tokens}
"
if [[ -n ${exp_name} ]]; then
cmd="$cmd --exp_name ${exp_name}"
fi
if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}"
fi
if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo ${cmd}
eval ${cmd}
......@@ -112,7 +112,7 @@ class AudioDataset(Dataset):
if self.mode == "easy":
real_idx = 0
for idx, v in segments.items():
audio_name = v["audio"]
audio_name = f"{split}_{v['audio']}"
v["audio"] = (wav_root / v["audio"].strip()).as_posix() + ".wav"
if self.speed_perturb is not None:
for perturb in self.speed_perturb:
......@@ -137,8 +137,8 @@ class AudioDataset(Dataset):
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
# _id = f"{split}_{wav_path.stem}_{i}"
_id = f"{wav_path.stem}_{i}"
_id = f"{split}_{wav_path.stem}_{i}"
# _id = f"{wav_path.stem}_{i}"
item = dict()
item["audio"] = wav_path.as_posix()
......@@ -237,7 +237,7 @@ def process(args):
if not Path.exists(zip_path) or args.overwrite:
gen_feature_flag = True
if True and gen_feature_flag:
if gen_feature_flag:
if args.speed_perturb:
feature_root = output_root / "fbank80_sp"
else:
......@@ -264,12 +264,8 @@ def process(args):
utt_id = item['id']
features_path = (feature_root / f"{utt_id}.npy").as_posix()
tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
if os.path.exists(tag_features_path):
continue
if os.path.exists(features_path) and not os.path.exists(tag_features_path):
shutil.move(features_path, tag_features_path)
if os.path.exists(features_path):
continue
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
......
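Prefixing the utterance id and feature file with the split name (f"{split}_{wav_path.stem}_{i}") avoids collisions when several training subsets share one fbank80 directory/zip. A minimal sketch of the accompanying migration of already-extracted features, mirroring the hunk above (function name is illustrative):

    import os, shutil

    def migrate_feature(feature_root: str, split: str, utt_id: str) -> bool:
        """Move an un-prefixed .npy feature to its split-prefixed name; return True if it still needs extraction."""
        old_path = os.path.join(feature_root, f"{utt_id}.npy")
        new_path = os.path.join(feature_root, f"{split}_{utt_id}.npy")
        if os.path.exists(new_path):
            return False                      # already migrated
        if os.path.exists(old_path):
            shutil.move(old_path, new_path)   # reuse the existing feature under the new name
            return False
        return True                           # missing: extract fbank features for this utterance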
......@@ -96,16 +96,19 @@ def process(args):
tgt_train_text.extend(manifest["tgt_text"])
# Generate vocab and yaml
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}"
tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
if args.share:
tgt_train_text.extend(src_train_text)
src_spm_filename_prefix = spm_filename_prefix + "_share"
tgt_spm_filename_prefix = src_spm_filename_prefix
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
src_spm_filename_prefix = tgt_spm_filename_prefix
else:
src_spm_filename_prefix = spm_filename_prefix + "_" + src_lang
tgt_spm_filename_prefix = spm_filename_prefix + "_" + tgt_lang
src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
with NamedTemporaryFile(mode="w") as f:
for t in tgt_train_text:
......@@ -113,8 +116,8 @@ def process(args):
gen_vocab(
Path(f.name),
output_root / tgt_spm_filename_prefix,
args.vocab_type,
args.vocab_size,
args.tgt_vocab_type,
args.tgt_vocab_size,
normalization_rule_name="identity" if tgt_lang == "zh" else None
)
......@@ -125,8 +128,8 @@ def process(args):
gen_vocab(
Path(f.name),
output_root / src_spm_filename_prefix,
args.vocab_type,
args.vocab_size,
args.src_vocab_type,
args.src_vocab_size,
normalization_rule_name="identity" if tgt_lang == "zh" else None
)
......@@ -135,7 +138,7 @@ def process(args):
if args.share:
yaml_filename = f"config_share.yaml"
conf = {}
conf = dict()
conf["src_vocab_filename"] = src_spm_filename_prefix + ".txt"
conf["tgt_vocab_filename"] = tgt_spm_filename_prefix + ".txt"
conf["src_bpe_tokenizer"] = {
......@@ -157,13 +160,21 @@ def main():
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--output-root", "-o", default=None, type=str)
parser.add_argument(
"--vocab-type",
"--src-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument(
"--tgt-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=10000, type=int)
)
parser.add_argument("--src-vocab-size", default=10000, type=int)
parser.add_argument("--tgt-vocab-size", default=10000, type=int)
parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--splits", default="train,dev,test", type=str)
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
......@@ -704,6 +704,8 @@ def load_pretrained_component_from_model(
if key.startswith(component_type):
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
component_subkey = key[len(component_type) + 1:]
if component_subkey.startswith(component_type):
component_subkey = component_subkey[len(component_type) + 1:]
component_state_dict[component_subkey] = state["model"][key]
mismatch_keys = []
......
......@@ -91,7 +91,7 @@ class Adapter(nn.Module):
logger.info("CTC Compress Strategy: %s" % strategy)
elif self.adapter_type == "league":
self.distribution_cutoff = strategy
if self.distribution_cutoff != -1:
if self.distribution_cutoff is not None:
logger.info("Distribution cutoff: %d" % int(strategy))
def forward(self, x, padding):
......@@ -112,7 +112,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
if self.distribution_cutoff != -1:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
......
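Switching the sentinel from -1 to None means the cutoff is only applied when --intermedia-distribution-cutoff is explicitly given. For reference, a small sketch of what the cutoff does in the league adapter: keep only the cutoff-th largest CTC probabilities per frame before they are combined with the token embeddings (this mirrors the torch.where logic shown above):

    import torch

    def cutoff_distribution(distribution: torch.Tensor, cutoff: int) -> torch.Tensor:
        """Zero everything below the cutoff-th largest probability along the last dimension."""
        cutoff = min(cutoff, distribution.size(-1) - 1)
        threshold = distribution.sort(dim=-1, descending=True)[0][..., cutoff:cutoff + 1]
        return torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))

    p = torch.tensor([[0.50, 0.30, 0.10, 0.06, 0.04]])
    print(cutoff_distribution(p, 2))  # keeps 0.50 and 0.30, zeroes the tail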
......@@ -192,9 +192,34 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"rel_pos",
"rope",
"abs",
"transfer",
],
help="transformer encoder self-attention layer type"
)
# transfer
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
# reduced attention
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
......@@ -450,9 +475,9 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help="the number of the attention heads in each stage",
)
parser.add_argument(
"--pds-attn-ds-ratio",
"--pds-attn-ds-ratios",
type=str,
help="the ratio of the down-sampling in the self attention module",
help="the ratios of the down-sampling in the self attention module",
)
parser.add_argument(
"--pds-ffn-ratios",
......@@ -495,7 +520,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
default=None,
type=int,
help="cutoff of the distribution",
)
......@@ -641,7 +666,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
pos_embed = None
stage = nn.ModuleList([
PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
PDSTransformerEncoderLayer(args, embed_dim, ffn_ratio, num_head, attn_ds_ratio)
for _ in range(num_layers)])
# representation fusion
......@@ -735,9 +760,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
(("ctc" in getattr(args, "criterion", "")) and
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
# self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
# self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
# self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
......@@ -1027,7 +1055,7 @@ def base_architecture(args):
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.ctc_self_distill = getattr(args, "ctc_self_distill", False)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
def set_pds_base_8(args):
......
......@@ -131,11 +131,34 @@ class S2TCTCModel(FairseqEncoderModel):
"relative",
"rel_pos",
"rope",
"abs"
"abs",
"transfer",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
......@@ -412,7 +435,7 @@ class S2TCTCModel(FairseqEncoderModel):
help="the number of the attention heads in each stage",
)
parser.add_argument(
"--pds-attn-ds-ratio",
"--pds-attn-ds-ratios",
type=str,
help="the ratio of the down-sampling in the self attention module",
)
......@@ -457,7 +480,7 @@ class S2TCTCModel(FairseqEncoderModel):
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
default=None,
type=int,
help="cutoff of the distribution",
)
......@@ -931,6 +954,26 @@ def base_architecture(args):
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
args.pds_layers = getattr(args, "pds_layers", None)
args.pds_ratios = getattr(args, "pds_ratios", None)
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
......
......@@ -5,13 +5,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
......@@ -314,12 +312,12 @@ class S2TSATEEncoder(FairseqEncoder):
if args.adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
elif args.adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1)
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
task.source_dictionary,
embed_tokens,
embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy)
if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"):
......
......@@ -385,7 +385,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
default=None,
type=int,
help="cutoff of the distribution",
)
......@@ -581,7 +581,7 @@ class S2TTransformerEncoder(FairseqEncoder):
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1)
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
......
......@@ -8,6 +8,7 @@ from .squeeze_excitation import SEAttention
from .activations import swish, Swish
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .attention import MultiHeadSelfAttentionModule
from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
from .downsample_convolution import DownSampleConvolutionModule
......@@ -91,6 +92,7 @@ __all__ = [
"LinearizedConvolution",
"LocalMultiheadAttention",
"MultiheadAttention",
"MultiHeadSelfAttentionModule",
"PositionalEmbedding",
"PDSTransformerEncoderLayer",
"ReducedMultiheadAttention",
......
......@@ -43,6 +43,7 @@ def get_activation_class(activation: str, dim=None):
else:
raise RuntimeError("activation function {} not supported".format(activation))
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
......
......@@ -281,7 +281,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
"""Strided Multi-Head Attention Layer
Strided multi-head attention performs global sequence downsampling by striding
the attention query before aplying scaled dot-product attention. This results in
the attention query before applying scaled dot-product attention. This results in
strided attention maps where query positions can attend to the entire sequence
context to perform downsampling.
......@@ -1321,7 +1321,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
Args:
dim_model: model feature dimension
num_heads: number of attention heads
Pdrop: residual dropout probability
dropout: residual dropout probability
max_pos_encoding: maximum position
relative_pos_enc: whether to use relative position embedding
causal: True for causal attention with masked future context
......@@ -1335,14 +1335,14 @@ class MultiHeadSelfAttentionModule(nn.Module):
def __init__(self,
dim_model,
num_heads,
Pdrop,
dropout,
max_pos_encoding,
relative_pos_enc,
causal,
group_size,
kernel_size,
stride,
linear_att):
relative_pos_enc=False,
causal=False,
group_size=1,
kernel_size=None,
stride=1,
linear_att=False):
super(MultiHeadSelfAttentionModule, self).__init__()
# Assert
......@@ -1351,7 +1351,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"
# Pre Norm
self.norm = nn.LayerNorm(dim_model, eps=1e-6)
# self.norm = nn.LayerNorm(dim_model, eps=1e-6)
# Multi-Head Linear Attention
if linear_att:
......@@ -1394,7 +1394,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
self.mhsa = MultiHeadAttention(dim_model, num_heads)
# Dropout
self.dropout = nn.Dropout(Pdrop)
# self.dropout = nn.Dropout(Pdrop)
# Module Params
self.rel_pos_enc = relative_pos_enc
......@@ -1402,8 +1402,9 @@ class MultiHeadSelfAttentionModule(nn.Module):
def forward(self, x, mask=None, hidden=None):
# Pre Norm
x = self.norm(x)
x = x.transpose(0, 1)
if mask is not None:
mask = mask.view(mask.size(0), 1, 1, mask.size(-1))
# Multi-Head Self-Attention
if self.linear_att:
......@@ -1414,6 +1415,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
x, attention = self.mhsa(x, x, x, mask)
# Dropout
x = self.dropout(x)
# x = self.dropout(x)
return x, attention, hidden
x = x.transpose(0, 1)
return x, attention
......@@ -10,17 +10,18 @@ class ConvolutionModule(nn.Module):
def __init__(
self,
embed_dim,
channels,
expand_embed_dim,
depthwise_kernel_size,
dropout,
activation_fn="swish",
bias=False,
stride=1,
export=False,
):
"""
Args:
embed_dim: Embedding dimension
channels: Number of channels in depthwise conv layers
expand_embed_dim: Number of output embedding dimension
depthwise_kernel_size: Depthwise conv layer kernel size
dropout: dropout value
activation_fn: Activation function to use after depthwise convolution kernel
......@@ -33,7 +34,7 @@ class ConvolutionModule(nn.Module):
) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
self.pointwise_conv1 = torch.nn.Conv1d(
embed_dim,
2 * channels,
2 * expand_embed_dim,
kernel_size=1,
stride=1,
padding=0,
......@@ -41,19 +42,19 @@ class ConvolutionModule(nn.Module):
)
self.glu = torch.nn.GLU(dim=1)
self.depthwise_conv = torch.nn.Conv1d(
channels,
channels,
expand_embed_dim,
expand_embed_dim,
depthwise_kernel_size,
stride=1,
stride=stride,
padding=(depthwise_kernel_size - 1) // 2,
groups=channels,
groups=expand_embed_dim,
bias=bias,
)
self.batch_norm = nn.BatchNorm1d(channels)
self.batch_norm = nn.BatchNorm1d(expand_embed_dim)
self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d(
channels,
embed_dim,
expand_embed_dim,
expand_embed_dim,
kernel_size=1,
stride=1,
padding=0,
......@@ -72,8 +73,8 @@ class ConvolutionModule(nn.Module):
x = x.transpose(1, 2)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = self.glu(x) # (batch, channel, dim)
x = self.pointwise_conv1(x) # (batch, 2*expand_embed_dim, dim)
x = self.glu(x) # (batch, expand_embed_dim, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
......@@ -81,10 +82,13 @@ class ConvolutionModule(nn.Module):
x = self.activation(x)
x = self.pointwise_conv2(x)
x = x.transpose(1, 2)
x = self.dropout(x)
return x.transpose(1, 2)
#
return x
# class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model."""
# def __init__(self,
......
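The convolution module now takes a separate expand_embed_dim and an optional stride, so a stage can change both the model width and the temporal resolution inside the Conformer convolution branch. A shape-level PyTorch sketch of the pipeline (pointwise expand, GLU, strided depthwise conv, batch norm + activation, pointwise projection), assuming (batch, time, embed_dim) input:

    import torch
    import torch.nn as nn

    class ConvModuleSketch(nn.Module):
        """Shape sketch of the strided Conformer convolution module."""
        def __init__(self, embed_dim, expand_embed_dim, kernel_size=31, stride=2):
            super().__init__()
            self.pw1 = nn.Conv1d(embed_dim, 2 * expand_embed_dim, 1)
            self.glu = nn.GLU(dim=1)
            self.dw = nn.Conv1d(expand_embed_dim, expand_embed_dim, kernel_size,
                                stride=stride, padding=(kernel_size - 1) // 2,
                                groups=expand_embed_dim)
            self.bn = nn.BatchNorm1d(expand_embed_dim)
            self.act = nn.SiLU()
            self.pw2 = nn.Conv1d(expand_embed_dim, expand_embed_dim, 1)

        def forward(self, x):                  # (batch, time, embed_dim)
            x = x.transpose(1, 2)              # (batch, embed_dim, time)
            x = self.glu(self.pw1(x))          # (batch, expand_embed_dim, time)
            x = self.act(self.bn(self.dw(x)))  # (batch, expand_embed_dim, time / stride)
            return self.pw2(x).transpose(1, 2)

    print(ConvModuleSketch(176, 256)(torch.randn(2, 100, 176)).shape)  # torch.Size([2, 50, 256])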
from typing import Optional
import torch
from torch import Tensor
import torch.nn as nn
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
......@@ -14,10 +16,11 @@ from fairseq.modules import (
LocalMultiheadAttention,
ReducedMultiheadAttention,
RotaryPositionMultiHeadedAttention,
MultiHeadSelfAttentionModule,
)
from fairseq.modules.s2t_transformer_layer import FeedForwardModule
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor
from .utils import Transpose, Permute3D
class PDSTransformerEncoderLayer(nn.Module):
......@@ -35,29 +38,48 @@ class PDSTransformerEncoderLayer(nn.Module):
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1):
def __init__(self, args,
embed_dim,
ffn_ratio,
num_head,
attn_sample_ratio=1,
attn_stride=1,
conv_stride=1,
expand_embed_dim=None):
super().__init__()
self.args = args
embed_dim = embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(args, embed_dim, num_head, att_sample_ratio)
self.self_attn_layer_norm = LayerNorm(embed_dim)
if expand_embed_dim is None:
expand_embed_dim = embed_dim
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.normalize_before = args.encoder_normalize_before
activation = getattr(args, 'encoder_activation_fn', 'relu')
# attention
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(args, embed_dim, num_head, attn_sample_ratio)
self.self_attn_layer_norm = LayerNorm(embed_dim)
# Attention Residual
self.attn_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.MaxPool1d(kernel_size=1, stride=attn_stride),
Permute3D(2, 0, 1)
) if attn_stride > 1 else nn.Identity()
if args.macaron_style:
self.macaron_ffn = FeedForwardModule(
embed_dim,
ffn_dim,
embed_dim * ffn_ratio,
dropout,
dropout,
activation
......@@ -73,24 +95,37 @@ class PDSTransformerEncoderLayer(nn.Module):
self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule(
embed_dim,
embed_dim,
expand_embed_dim,
depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish'))
self.final_norm = LayerNorm(embed_dim)
activation_fn=activation,
stride=conv_stride
)
self.final_norm = LayerNorm(expand_embed_dim)
# Convolution Residual
self.conv_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
Permute3D(2, 0, 1)
) if embed_dim != expand_embed_dim else nn.Sequential(
Permute3D(1, 2, 0),
nn.MaxPool1d(kernel_size=1, stride=conv_stride),
Permute3D(2, 0, 1)
) if conv_stride > 1 else nn.Identity()
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
expand_embed_dim,
expand_embed_dim * ffn_ratio,
dropout,
dropout,
activation
)
self.ffn_norm = LayerNorm(embed_dim)
self.ffn_norm = LayerNorm(expand_embed_dim)
def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1):
attention_heads = num_head
......@@ -165,6 +200,17 @@ class PDSTransformerEncoderLayer(nn.Module):
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
sample_ratio=sample_ratio,
reduced_method=getattr(args, "attention_reduced_method", "conv"),
reduced_q=getattr(args, "attention_reduced_q", False)
)
elif self.attn_type == "transfer":
return MultiHeadSelfAttentionModule(
embed_dim,
attention_heads,
dropout,
max_pos_encoding=args.max_source_positions,
relative_pos_enc=getattr(args, "relative_pos_enc", False),
linear_att=getattr(args, "linear_att", False),
)
else:
print("The encoder attention type %s is not supported!" % self.attn_type)
......@@ -248,6 +294,10 @@ class PDSTransformerEncoderLayer(nn.Module):
attn_mask=attn_mask,
pos_emb=pos_emb
)
elif self.attn_type == "transfer":
x, _ = self.self_attn(
x, encoder_padding_mask
)
else:
x, _ = self.self_attn(
query=x,
......@@ -258,7 +308,7 @@ class PDSTransformerEncoderLayer(nn.Module):
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
x = self.residual_connection(self.attn_res(x), residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
......
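The attn_res and conv_res additions above keep the residual branch shape-compatible with a strided or width-expanded main branch: max-pooling drops frames when only the stride changes, and a pointwise Conv1d both projects channels and strides time when the embedding width grows. A hedged standalone sketch follows; make_residual and the local Permute3D copy are illustrative names, not the repository's.

import torch
import torch.nn as nn


class Permute3D(nn.Module):
    """Same role as the Permute3D helper shown further below."""

    def __init__(self, dim0, dim1, dim2):
        super().__init__()
        self.dims = (dim0, dim1, dim2)

    def forward(self, x):
        return x.permute(*self.dims)


def make_residual(embed_dim, expand_dim, stride):
    """Map (T, B, embed_dim) -> (ceil(T / stride), B, expand_dim) for the residual path."""
    if embed_dim != expand_dim:
        # Width changes: a 1x1 Conv1d projects channels and strides the time axis.
        op = nn.Conv1d(embed_dim, expand_dim, kernel_size=1, stride=stride)
    elif stride > 1:
        # Width unchanged: max-pooling with kernel_size=1 keeps every stride-th frame.
        op = nn.MaxPool1d(kernel_size=1, stride=stride)
    else:
        return nn.Identity()
    return nn.Sequential(
        Permute3D(1, 2, 0),  # (T, B, C) -> (B, C, T) for channel-first 1d ops
        op,
        Permute3D(2, 0, 1),  # back to (T', B, C')
    )


res = make_residual(256, 512, stride=2)
print(res(torch.randn(100, 8, 256)).shape)  # torch.Size([50, 8, 512])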
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.layer_norm import LayerNorm
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter
......@@ -38,6 +33,8 @@ class ReducedMultiheadAttention(nn.Module):
q_noise=0.0,
qn_block_size=8,
sample_ratio=1,
reduced_method="conv",
reduced_q=False,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -85,13 +82,25 @@ class ReducedMultiheadAttention(nn.Module):
self.add_zero_attn = add_zero_attn
self.sample_ratio = sample_ratio
self.reduced_method = reduced_method
self.reduced_q = reduced_q
if reduced_q:
assert self.reduced_method == 'group', "query reduction only supports the grouped method"
if self.sample_ratio > 1:
self.sr = nn.Conv1d(embed_dim, embed_dim,
kernel_size=sample_ratio,
stride=sample_ratio,
# padding=(sample_ratio - 1) // 2
)
self.norm = nn.LayerNorm(embed_dim)
if reduced_method == "conv":
self.sr = nn.Conv1d(embed_dim, embed_dim,
kernel_size=sample_ratio,
stride=sample_ratio,
# padding=(sample_ratio - 1) // 2
)
self.norm = LayerNorm(embed_dim)
elif reduced_method == "pool":
self.linear = nn.Linear(embed_dim, embed_dim)
self.norm = LayerNorm(embed_dim)
self.act = nn.GELU()
elif reduced_method == "group":
pass
self.reset_parameters()
......@@ -159,41 +168,6 @@ class ReducedMultiheadAttention(nn.Module):
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if (
self.sample_ratio == 1 and
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
......@@ -205,16 +179,41 @@ class ReducedMultiheadAttention(nn.Module):
else:
saved_state = None
q = self.q_proj(query)
# only support self attention
if self.self_attention:
query_ = query
if self.sample_ratio > 1:
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len:
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query = self.norm(query_)
assert tgt_len % self.sample_ratio == 0, \
("sample ratio %d is mismatched with length %d" % (self.sample_ratio, tgt_len))
if self.reduced_method == "conv":
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query_ = self.norm(query_)
elif self.reduced_method == "pool":
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
pool_length = int(tgt_len / self.sample_ratio)
query_ = nn.functional.adaptive_max_pool1d(query_, pool_length).permute(2, 0, 1)
query_ = self.act(self.norm(query_))
if self.reduced_q:
q = self.q_proj(query_)
tgt_len = int(tgt_len / self.sample_ratio)
else:
q = self.q_proj(query)
k = self.k_proj(query_)
v = self.v_proj(query_)
if self.sample_ratio > 1 and self.reduced_method == "group":
assert self.reduced_q is True
# Group sample_ratio adjacent frames into the channel dimension so attention
# runs over a sequence that is sample_ratio times shorter.
self.head_dim *= self.sample_ratio
q = q.transpose(0, 1).contiguous().view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
k = k.transpose(0, 1).contiguous().view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
v = v.transpose(0, 1).contiguous().view(bsz, -1, self.embed_dim * self.sample_ratio).transpose(0, 1)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
q = self.q_proj(query)
# encoder-decoder attention
if key is None:
assert value is None
......@@ -224,10 +223,12 @@ class ReducedMultiheadAttention(nn.Module):
v = self.v_proj(key)
else:
q = self.q_proj(query)
assert key is not None and value is not None
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
# q *= self.scaling
q *= (self.head_dim ** -0.5)
if self.bias_k is not None:
assert self.bias_v is not None
......@@ -313,13 +314,15 @@ class ReducedMultiheadAttention(nn.Module):
if key_padding_mask is not None:
if self.sample_ratio > 1:
lengths = (~key_padding_mask).sum(-1)
lengths = (lengths / self.sample_ratio).long()
# lengths = ((lengths.float() - 1) / self.sample_ratio + 1).floor().long()
max_length = src_len
assert max_length >= max(lengths), (max_length, max(lengths))
mask = torch.arange(max_length).to(lengths.device).view(1, max_length)
key_padding_mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_length)
key_padding_mask = key_padding_mask[:, ::self.sample_ratio]
# lengths = (~key_padding_mask).sum(-1)
# lengths = (lengths / self.sample_ratio).long()
# # lengths = ((lengths.float() - 1) / self.sample_ratio + 1).floor().long()
# max_length = src_len
# assert max_length >= max(lengths), (max_length, max(lengths))
# mask = torch.arange(max_length).to(lengths.device).view(1, max_length)
# key_padding_mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_length)
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
......@@ -380,6 +383,10 @@ class ReducedMultiheadAttention(nn.Module):
assert v is not None
attn = torch.bmm(attn_probs, v)
if self.sample_ratio > 1 and self.reduced_q:
tgt_len = attn.size(1) * self.sample_ratio
self.head_dim = int(self.head_dim / self.sample_ratio)
attn = attn.view(bsz * self.num_heads, tgt_len, self.head_dim)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
......
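For orientation, a minimal sketch of the "conv" reduction path implemented above: keys and values are computed from a sequence that a strided Conv1d has shortened by the sample ratio, so the attention map shrinks from T x T to T x T/r while the query length is preserved. ReducedSelfAttentionSketch is an illustrative name and uses torch.nn.MultiheadAttention instead of the fairseq module.

import torch
import torch.nn as nn


class ReducedSelfAttentionSketch(nn.Module):
    def __init__(self, embed_dim, num_heads, sample_ratio=2):
        super().__init__()
        self.sample_ratio = sample_ratio
        self.sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x, key_padding_mask=None):
        # x: (T, B, C); T is assumed divisible by sample_ratio, as asserted in the real code.
        reduced = self.sr(x.permute(1, 2, 0)).permute(2, 0, 1)  # (T / r, B, C)
        reduced = self.norm(reduced)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, ::self.sample_ratio]  # (B, T / r)
        out, _ = self.attn(x, reduced, reduced, key_padding_mask=key_padding_mask)
        return out  # (T, B, C): query length preserved, key/value length reduced


x = torch.randn(100, 8, 256)                        # (T, B, C)
print(ReducedSelfAttentionSketch(256, 4)(x).shape)  # torch.Size([100, 8, 256])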
import torch
from torch import nn as nn
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return x.transpose(self.dim0, self.dim1)
class Permute3D(nn.Module):
def __init__(self, dim0, dim1, dim2):
super(Permute3D, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x):
return x.permute(self.dim0, self.dim1, self.dim2)
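A brief illustrative usage of the two helpers above: they let channel-first 1d operators such as BatchNorm1d, MaxPool1d, or Conv1d be dropped into an nn.Sequential that otherwise passes (batch, time, channels) or (time, batch, channels) tensors. The snippet relies on the imports and class definitions directly above.

norm_over_channels = nn.Sequential(
    Transpose(1, 2),      # (B, T, C) -> (B, C, T)
    nn.BatchNorm1d(256),  # BatchNorm1d expects channels in dim 1
    Transpose(1, 2),      # back to (B, T, C)
)
print(norm_over_channels(torch.randn(8, 100, 256)).shape)  # torch.Size([8, 100, 256])

downsample_time = nn.Sequential(
    Permute3D(1, 2, 0),                     # (T, B, C) -> (B, C, T)
    nn.MaxPool1d(kernel_size=1, stride=2),  # keep every other frame
    Permute3D(2, 0, 1),                     # -> (T/2, B, C)
)
print(downsample_time(torch.randn(100, 8, 256)).shape)  # torch.Size([50, 8, 256])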