Commit d4255246 by xuchen

optimize the implementation of the Efficient Conformer

parent 0bd92062
File mode changed from 100644 to 100755
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -29,7 +29,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ zero_infinity: True
 post-process: sentencepiece
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
 arch: transformer
-share-all-embeddings: True
+share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
...
 arch: transformer
-share-all-embeddings: True
+share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
...
@@ -44,10 +44,10 @@ lcrm=1
 tokenizer=0
 use_specific_dict=1
-specific_prefix=st
-specific_dir=${root_dir}/data/mustc/st
-src_vocab_prefix=spm_unigram10000_st_share
-tgt_vocab_prefix=spm_unigram10000_st_share
+specific_prefix=asr5k_st10k
+specific_dir=${root_dir}/data/${dataset}/st_lcrm_asr
+src_vocab_prefix=spm_unigram5000_asr
+tgt_vocab_prefix=spm_unigram10000_st
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt
@@ -82,7 +82,6 @@ len_penalty=1.0
 if [[ ${use_specific_dict} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_${specific_prefix}
     data_dir=${data_dir}/${specific_prefix}
-    mkdir -p ${data_dir}
 else
     if [[ "${tgt_vocab_type}" == "char" ]]; then
         vocab_name=char
@@ -159,6 +158,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         cmd="$cmd
         --share"
     fi
+
     echo -e "\033[34mRun command: \n${cmd} \033[0m"
     [[ $eval -eq 1 ]] && eval ${cmd}
 else
@@ -171,13 +171,15 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     for split in ${train_subset} ${valid_subset} ${trans_subset}; do
     {
         if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            txt_dir=${org_data_dir}/data/${split}/txt
+            text_dir=${org_data_dir}/data/${split}/txt
         else
-            txt_dir=${org_data_dir}/data/${split}
+            text_dir=${org_data_dir}/data/${split}
         fi
-        cmd="cat ${txt_dir}/${split}.${src_lang}"
+        src_text=${text_dir}/${split}.${src_lang}
+        tgt_text=${text_dir}/${split}.${tgt_lang}
+        cmd="cat ${src_text}"
         if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${org_data_dir}/data/${split}.${src_lang}"
+            cmd="python local/lower_rm.py ${src_text}"
         fi
         cmd="${cmd}
         | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
@@ -190,7 +192,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         cmd="spm_encode
             --model ${data_dir}/${tgt_vocab_prefix}.model
             --output_format=piece
-            < ${txt_dir}/${split}.${tgt_lang}
+            < ${tgt_text}
             > ${data_dir}/data/${split}.${tgt_lang}"
         echo -e "\033[34mRun command: \n${cmd} \033[0m"
@@ -329,11 +331,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         mv tmp.log $log
         export CUDA_VISIBLE_DEVICES=${device}
-        cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
+        log=${model_dir}/train.log
+        cmd="nohup ${cmd} >> ${log} 2>&1 &"
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             sleep 2s
-            tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
+            tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
         fi
     fi
     wait
...
@@ -6,17 +6,17 @@ gpu_num=1
 update_freq=1
 max_tokens=8192
-extra_tag=
-extra_parameter=
-#extra_tag="${extra_tag}"
-#extra_parameter="${extra_parameter} "
 exp_tag=baseline
 config_list=(base)
 # exp full name
 exp_name=
+extra_tag=
+extra_parameter=
+#extra_tag="${extra_tag}"
+#extra_parameter="${extra_parameter} "
 train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
 cmd="./run.sh
...
File mode changed from 100644 to 100755
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -14,7 +14,7 @@ label_smoothing: 0.1
 encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
@@ -13,7 +13,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
 arch: s2t_ctc
 encoder-type: pds
-#arch: pdss2t_transformer_s_8
 #pds-ctc: 0_1_1_0
 #intermedia-adapter: league
 #intermedia-ctc-weight: 1
-#encoder-attention-type: transfer
-#relative-pos-enc: True
-encoder-attention-type: rel_pos
+#encoder-attention-type: reduced
 #pds-attn-ds-ratios: 4_2_1_1
 #attention-reduced-method: pool
 #attention-reduced-q: True
-encoder-embed-dim: 256
-pds-stages: 4
-ctc-layer: 12
-pds-layers: 3_3_3_3
-pds-ratios: 2_2_1_2
-pds-fusion: True
+encoder-embed-dim: 240
+pds-stages: 3
+#ctc-layer: 15
+pds-layers: 4_5_6
+pds-ratios: 2_2_2
+pds-fusion: False
 pds-fusion-method: all_conv
-pds-embed-dims: 256_256_256_256
+pds-embed-dims: 120_168_240
 pds-ds-method: conv
 pds-embed-norm: True
-pds-position-embed: 1_1_1_1
-pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_8_8_8
-pds-attn-heads: 4_4_4_4
-share-decoder-input-output-embed: True
+pds-position-embed: 1_1_1
+pds-kernel-sizes: 3_3_3
+pds-ffn-ratios: 4_4_4
+pds-attn-heads: 4_4_4
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0015
 adam_betas: (0.9,0.98)
 criterion: ctc
+post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
+encoder-layers: 15
+macaron-style: True
+use-cnn-module: True
+cnn-module-kernel: 15
+encoder-activation-fn: swish
+encoder-attention-type: rel_pos
 #load-pretrained-encoder-from:
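For orientation, the totals implied by the new stage layout above can be checked with a few lines of Python (a minimal sketch; values are copied from the config, nothing else is assumed):

# Cumulative temporal down-sampling of the new 3-stage layout.
# pds-ratios 2_2_2 halves the sequence once per stage, while
# pds-embed-dims 120_168_240 widens the features as frames get scarcer.
pds_ratios = [2, 2, 2]
pds_embed_dims = [120, 168, 240]

total = 1
for ratio, dim in zip(pds_ratios, pds_embed_dims):
    total *= ratio
    print(f"stage: dim={dim}, cumulative down-sampling={total}x")
# 8x overall -- the same compression as the old 2_2_1_2 schedule,
# reached with fewer, narrower early stages.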
@@ -13,7 +13,7 @@ zero_infinity: True
 post-process: sentencepiece
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
@@ -13,7 +13,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -14,7 +14,7 @@ label_smoothing: 0.1
 encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
 arch: s2t_ctc
+encoder-type: transformer
 optimizer: adam
 #clip-norm: 10.0
 lr-scheduler: inverse_sqrt
@@ -12,7 +14,7 @@ criterion: ctc
 post-process: sentencepiece
 subsampling-type: conv2d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 176
 subsampling-kernel: 3
 subsampling-stride: 2
...
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
pds-stages: 3
pds-layers: 4_6_6
pds-ratios: -1_0_0
pds-conv-strides: 2_2_1
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 180_256_360
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
subsampling-type: conv2d
subsampling-layers: 1
subsampling-filter: 180
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 360
encoder-layers: 15
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
pds-stages: 3
pds-layers: 5_5_5
pds-ratios: -1_0_0
pds-conv-strides: 2_2_1
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
subsampling-type: conv2d
subsampling-layers: 1
subsampling-filter: 120
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 240
encoder-layers: 15
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
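Both new configs combine pds-ratios: -1_0_0 with pds-conv-strides: 2_2_1. Judging from the encoder changes later in this commit, -1 delegates a stage's reduction to the standard subsampling module, 0 turns its Downsampling into a pass-through, and the remaining reduction is applied by the strided convolution in the last Conformer layer of a stage. A sketch of the resulting length budget (the frame count is illustrative):

def out_len(n: int, stride: int) -> int:
    # length after one stride-s convolution; matches
    # torch.div(n - 1, stride, rounding_mode='floor') + 1
    return (n - 1) // stride + 1

T = 1000              # input feature frames (illustrative)
T = out_len(T, 2)     # stage 1, ratio -1: conv2d subsampling, subsampling-stride 2
T = out_len(T, 2)     # stage 1, last Conformer layer: conv stride 2
T = out_len(T, 2)     # stage 2, last Conformer layer: conv stride 2
                      # stage 3: ratio 0, stride 1 -> no further reduction
print(T)              # 125, i.e. 8x compression overall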
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -11,7 +11,7 @@ criterion: ctc
 post-process: sentencepiece
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -38,14 +38,7 @@ criterion: ctc
 dropout: 0.1
 activation-fn: relu
-encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
-
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
\ No newline at end of file
@@ -38,10 +38,7 @@ post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
 macaron-style: True
 use-cnn-module: True
...
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 15
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
File mode changed from 100644 to 100755
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -12,7 +12,7 @@ zero_infinity: True
 post-process: sentencepiece
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -15,7 +15,7 @@ encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
@@ -14,7 +14,7 @@ label_smoothing: 0.1
 encoder-normalize-before: True
 decoder-normalize-before: True
 subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
 subsampling-filter: 2048
 subsampling-kernel: 5
 subsampling-stride: 2
...
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
@@ -2,21 +2,21 @@
 # training the model
-gpu_num=8
-update_freq=1
+gpu_num=4
+update_freq=4
 max_tokens=8192
-exp_tag=baseline
-config_list=(base)
-# exp full name
-exp_name=
 extra_tag=
 extra_parameter=
 #extra_tag="${extra_tag}"
 #extra_parameter="${extra_parameter} "
+exp_tag=baseline
+config_list=(deep)
+# exp full name
+exp_name=
 train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
 cmd="./run.sh
...
File mode changed from 100644 to 100755
 #! /bin/bash
-# Processing WMT16 En-De Datasets
+# Processing WMT20 En-Zh Datasets
 # Copyright 2021 Natural Language Processing Laboratory
 # Xu Chen (xuchenneu@163.com)
@@ -35,18 +35,19 @@ lang=${src_lang}-${tgt_lang}
 dataset=wmt20
 task=translation
-vocab_type=unigram
-vocab_size=32000
+src_vocab_type=unigram
+tgt_vocab_type=unigram
+src_vocab_size=32000
+tgt_vocab_size=32000
 share_dict=0
 lcrm=1
 tokenizer=1
-use_specific_dict=0
-subword=0
-specific_prefix=subword32000_share
-specific_dir=${root_dir}/data/mustc/st
-src_vocab_prefix=spm_unigram10000_st_share
-tgt_vocab_prefix=spm_unigram10000_st_share
+use_specific_dict=1
+specific_prefix=asr5k_st10k
+specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
+src_vocab_prefix=spm_unigram5000_asr
+tgt_vocab_prefix=spm_unigram10000_st
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt
@@ -81,17 +82,24 @@ len_penalty=1.0
 if [[ ${use_specific_dict} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_${specific_prefix}
     data_dir=${data_dir}/${specific_prefix}
-    mkdir -p ${data_dir}
 else
-    if [[ "${vocab_type}" == "char" ]]; then
-        vocab_name=${vocab_type}
-        exp_prefix=${exp_prefix}_${vocab_type}
+    if [[ "${tgt_vocab_type}" == "char" ]]; then
+        vocab_name=char
+        exp_prefix=${exp_prefix}_char
     else
-        vocab_name=${vocab_type}${vocab_size}
+        if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" -ne "${tgt_vocab_type}" ]]; then
+            src_vocab_name=${src_vocab_type}${src_vocab_size}
+            tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            vocab_name=${src_vocab_name}_${tgt_vocab_name}
+        else
+            vocab_name=${tgt_vocab_type}${tgt_vocab_size}
+            src_vocab_name=${vocab_name}
+            tgt_vocab_name=${vocab_name}
+        fi
     fi
     data_dir=${data_dir}/${vocab_name}
-    src_vocab_prefix=spm_${vocab_name}_${src_lang}
-    tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
+    src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
+    tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
     if [[ $share_dict -eq 1 ]]; then
         data_dir=${data_dir}_share
         src_vocab_prefix=spm_${vocab_name}_share
@@ -103,6 +111,9 @@ if [[ ${lcrm} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_lcrm
 fi
 if [[ ${tokenizer} -eq 1 ]]; then
+    train_subset=${train_subset}.tok
+    valid_subset=${valid_subset}.tok
+    trans_subset=${trans_subset}.tok
     data_dir=${data_dir}_tok
     exp_prefix=${exp_prefix}_tok
 fi
@@ -139,16 +150,14 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --splits ${train_subset},${valid_subset},${trans_subset}
         --src-lang ${src_lang}
         --tgt-lang ${tgt_lang}
-        --vocab-type ${vocab_type}
-        --vocab-size ${vocab_size}"
+        --src-vocab-type ${src_vocab_type}
+        --tgt-vocab-type ${tgt_vocab_type}
+        --src-vocab-size ${src_vocab_size}
+        --tgt-vocab-size ${tgt_vocab_size}"
     if [[ $share_dict -eq 1 ]]; then
         cmd="$cmd
         --share"
     fi
-    if [[ ${tokenizer} -eq 1 ]]; then
-        cmd="$cmd
-        --tokenizer"
-    fi
     echo -e "\033[34mRun command: \n${cmd} \033[0m"
     [[ $eval -eq 1 ]] && eval ${cmd}
@@ -168,10 +177,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         fi
         src_text=${text_dir}/${split}.${src_lang}
         tgt_text=${text_dir}/${split}.${tgt_lang}
-        if [[ ${tokenizer} -eq 1 ]]; then
-            src_text=${text_dir}/${split}.tok.${src_lang}
-            tgt_text=${text_dir}/${split}.tok.${tgt_lang}
-        fi
         cmd="cat ${src_text}"
         if [[ ${lcrm} -eq 1 ]]; then
             cmd="python local/lower_rm.py ${src_text}"
@@ -327,16 +332,14 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         export CUDA_VISIBLE_DEVICES=${device}
         log=${model_dir}/train.log
         cmd="nohup ${cmd} >> ${log} 2>&1 &"
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             sleep 2s
             tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
         fi
-        wait
-        echo -e " >> finish training \n"
     fi
+    wait
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: MT Decoding"
@@ -381,15 +384,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         --results-path ${model_dir}
         --max-tokens ${max_tokens}
         --beam ${beam_size}
-        --lenpen ${len_penalty}"
-    if [[ ${subword} -eq 1 ]]; then
-        cmd="${cmd}
-        --post-process subword_nmt"
-    else
-        cmd="${cmd}
-        --post-process sentencepiece"
-    fi
+        --lenpen ${len_penalty}
+        --post-process sentencepiece"
     if [[ ${sacrebleu} -eq 1 ]]; then
         cmd="${cmd}
...
@@ -2,8 +2,8 @@
 # training the model
-gpu_num=8
-update_freq=2
+gpu_num=4
+update_freq=4
 max_tokens=8192
 exp_tag=baseline
...
@@ -24,6 +24,9 @@ from fairseq.modules import (
     PDSTransformerEncoderLayer,
     DownSampleConvolutionModule
 )
+from fairseq.modules.speech_to_text import (
+    subsampling
+)

 logger = logging.getLogger(__name__)
@@ -65,16 +68,14 @@ class Downsampling(nn.Module):
         self.stride = stride
         self.reduced_way = reduced_way

+        if stride == 0:
+            return
         # default conv
         if self.reduced_way == "conv":
             self.conv = nn.Sequential(
                 nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
             )
-        elif self.reduced_way == "glu":
-            self.conv = nn.Sequential(
-                nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding),
-                nn.GLU(dim=1)
-            )
         elif self.reduced_way == "proj":
             self.conv = nn.Sequential(
                 nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
@@ -88,6 +89,9 @@ class Downsampling(nn.Module):
             self.norm = LayerNorm(out_channels)

     def forward(self, x, lengths):
+        if self.stride == 0:
+            return x, lengths
+
         seq_len, bsz, dim = x.size()
         assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
@@ -110,23 +114,20 @@ class Downsampling(nn.Module):
         else:
             x = x.permute(1, 2, 0)  # B * D * T
             x = self.conv(x)
-            if self.reduced_way == "glu":
-                x = self.glu(x)
             x = x.permute(2, 0, 1)  # T * B * D
         if self.embed_norm:
             x = self.norm(x)

-        padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
         # mask batch padding
         if not torch.all(lengths == x.size(-1)):
+            padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
             mask_pad = padding_mask.unsqueeze(2)
             if mask_pad is not None:
                 x = x.transpose(0, 1)
                 x.masked_fill_(mask_pad, 0.0)
                 x = x.transpose(0, 1)

-        return x, lengths, padding_mask
+        return x, lengths
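Downsampling is now a pass-through when its stride is 0, and it no longer returns the padding mask (the caller recomputes it). A minimal usage sketch of the new contract, with made-up shapes:

import torch

# stride=0 stages (pds-ratio 0) build no convolution and change nothing
ds = Downsampling("conv", True, 240, 240, kernel_sizes=3, stride=0, padding=1)

x = torch.randn(100, 8, 240)        # (T, B, D)
lengths = torch.full((8,), 100)
y, y_lengths = ds(x, lengths)       # two return values now, no mask
assert y.shape == x.shape and torch.equal(y_lengths, lengths)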
@register_model("pdss2t_transformer")
@@ -139,6 +140,44 @@ class PDSS2TTransformerModel(S2TTransformerModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
+        # subsampling
+        parser.add_argument(
+            "--subsampling-type",
+            type=str,
+            help="subsampling type, like conv1d and conv2d",
+        )
+        parser.add_argument(
+            "--subsampling-layers",
+            type=int,
+            help="subsampling layers",
+        )
+        parser.add_argument(
+            "--subsampling-filter",
+            type=int,
+            help="subsampling filter",
+        )
+        parser.add_argument(
+            "--subsampling-kernel",
+            type=int,
+            help="subsampling kernel",
+        )
+        parser.add_argument(
+            "--subsampling-stride",
+            type=int,
+            help="subsampling stride",
+        )
+        parser.add_argument(
+            "--subsampling-norm",
+            type=str,
+            default="none",
+            help="subsampling normalization type",
+        )
+        parser.add_argument(
+            "--subsampling-activation",
+            type=str,
+            default="none",
+            help="subsampling activation function type",
+        )
         # Transformer
         parser.add_argument(
             "--activation-fn",
@@ -482,7 +521,17 @@ class PDSS2TTransformerModel(S2TTransformerModel):
         parser.add_argument(
             "--pds-ffn-ratios",
             type=str,
             help="the ratio of the ffn in each stage",
+        )
+        parser.add_argument(
+            "--pds-conv-strides",
+            type=str,
+            help="the strides of the convolutional module (conformer) in each stage",
+        )
+        parser.add_argument(
+            "--pds-attn-strides",
+            type=str,
+            help="the strides of the attention module (conformer) in each stage",
         )
         parser.add_argument(
             "--pds-fusion",
@@ -565,8 +614,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         )
         self.pds_stages = getattr(args, "pds_stages", 4)
         self.pds_layers = [int(n) for n in args.pds_layers.split("_")]
+        self.layers = sum(self.pds_layers)
         self.pds_ratios = [int(n) for n in args.pds_ratios.split("_")]

         # down-sampling module
@@ -582,6 +631,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         else:
             self.pds_attn_ds_ratios = None

+        self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
+        self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
+
+        # fusion
         self.pds_fusion = args.pds_fusion
         self.pds_fusion_method = args.pds_fusion_method
@@ -619,15 +672,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             use_pos_embed = self.pds_position_embed[i]
             use_ctc = self.pds_ctc[i]

+            ffn_ratio = self.pds_ffn_ratios[i]
             num_head = self.pds_attn_heads[i]
             attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
-            ffn_ratio = self.pds_ffn_ratios[i]
+            conv_stride = self.pds_conv_strides[i]
+            attn_stride = self.pds_attn_strides[i]
+            if conv_stride != 1 or attn_stride != 1:
+                expand_embed_dim = embed_dim if i == self.pds_stages - 1 else self.pds_embed_dims[i + 1]
+            else:
+                expand_embed_dim = None

             logger.info("The stage {}: layer {}, down-sample ratio {}, embed dim {}, "
                         "kernel size {}, position embed {}, ffn ratio {}, num head {}, "
+                        "attn down-sample ratio {}, conv stride {}, attn stride {}, "
                         "fusion {}, fusion method {}, fusion transformer {}.".
                         format(i, num_layers, ds_ratio, embed_dim,
                                kernel_size, use_pos_embed, ffn_ratio, num_head,
+                               attn_ds_ratio, conv_stride, attn_stride,
                                self.pds_fusion, self.pds_fusion_method, self.pds_fusion_transform))

             if i == 0:
@@ -636,38 +697,50 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 self.embed_scale = 1.0

             # down-sampling
-            downsampling = Downsampling(
-                self.pds_ds_method,
-                self.pds_embed_norm,
-                args.input_feat_per_channel * args.input_channels if i == 0 else self.pds_embed_dims[i - 1],
-                embed_dim,
-                kernel_sizes=kernel_size,
-                stride=ds_ratio,
-                padding=(kernel_size - 1) // 2,
-            )
+            if ds_ratio == -1:
+                downsampling = subsampling(args, embed_dim)
+            else:
+                downsampling = Downsampling(
+                    self.pds_ds_method,
+                    self.pds_embed_norm,
+                    args.input_feat_per_channel * args.input_channels if i == 0 else self.pds_embed_dims[i - 1],
+                    embed_dim,
+                    kernel_sizes=kernel_size,
+                    stride=ds_ratio,
+                    padding=(kernel_size - 1) // 2,
+                )

             # position encoding
             if use_pos_embed:
                 if self.attn_type == "rel_pos":
                     pos_embed = RelPositionalEncoding(
-                        args.max_source_positions, args.encoder_embed_dim
+                        args.max_source_positions, embed_dim
                     )
                 elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
                     pos_embed = LegacyRelPositionalEncoding(
-                        args.encoder_embed_dim, args.dropout, args.max_source_positions
+                        embed_dim, args.dropout, args.max_source_positions
                     )
                 elif self.attn_type == "rope":
-                    self.embed_positions = None
+                    pos_embed = None
                 else:  # Use absolute positional embedding
                     pos_embed = PositionalEmbedding(
-                        args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+                        args.max_source_positions, embed_dim, self.padding_idx
                     )
             else:
                 pos_embed = None

             stage = nn.ModuleList([
-                PDSTransformerEncoderLayer(args, embed_dim, ffn_ratio, num_head, attn_ds_ratio)
-                for _ in range(num_layers)])
+                PDSTransformerEncoderLayer(
+                    args,
+                    embed_dim,
+                    ffn_ratio,
+                    num_head,
+                    attn_ds_ratio,
+                    conv_stride=conv_stride if layer_idx == num_layers - 1 else 1,
+                    attn_stride=attn_stride if layer_idx == num_layers - 1 else 1,
+                    expand_embed_dim=expand_embed_dim if layer_idx == num_layers - 1 else None,
+                )
+                for layer_idx in range(num_layers)])

             # representation fusion
             fusion_pre_layer_norm = None
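The conv_stride/attn_stride/expand_embed_dim keywords are passed only when layer_idx == num_layers - 1, so each stage down-samples and widens in its final layer alone. The assignment pattern in isolation (stage sizes taken from the new configs):

pds_layers = [4, 5, 6]          # pds-layers: 4_5_6
pds_conv_strides = [2, 2, 1]    # pds-conv-strides: 2_2_1

for stage, (num_layers, stride) in enumerate(zip(pds_layers, pds_conv_strides)):
    per_layer = [stride if layer_idx == num_layers - 1 else 1
                 for layer_idx in range(num_layers)]
    print(f"stage {stage}: conv strides {per_layer}")
# stage 0: [1, 1, 1, 2]         <- only the stage's last layer strides
# stage 1: [1, 1, 1, 1, 2]
# stage 2: [1, 1, 1, 1, 1, 1]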
@@ -760,9 +833,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             (("ctc" in getattr(args, "criterion", "")) and
              (getattr(args, "ctc_weight", False) > 0))
         if self.use_ctc:
-            # self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
-            # self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
-            # self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
+            # self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
+            # self.ctc_layer = self.layers if self.ctc_layer == 0 else self.ctc_layer
+            # self.inter_ctc = True if self.ctc_layer != self.layers or self.fusion_stages_num != 0 else False
             self.ctc_layer = args.ctc_layer
             self.inter_ctc = True if self.ctc_layer != 0 else False
@@ -824,9 +897,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         # padding to the multiply of 2
         max_len = x.size(0)
-        length = reduce(lambda a, b: a * b, self.pds_ratios)
+        length = reduce(lambda a, b: max(1, a) * max(1, b), self.pds_ratios)
         padding_to_len = (length - max_len % length)
-        if padding_to_len > 0:
+        if length > 1 and padding_to_len > 0:
             padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
             x = torch.cat([x, padding_for_pds], dim=0)
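The max(1, ...) wrapper matters because stage ratios may now be -1 (delegate to the subsampling module) or 0 (pass-through); a worked example of why the old product breaks:

from functools import reduce

ratios = [-1, 0, 0]   # pds-ratios of the new Efficient Conformer configs

old = reduce(lambda a, b: a * b, ratios)                  # 0 -> max_len % 0 raises
new = reduce(lambda a, b: max(1, a) * max(1, b), ratios)  # 1 -> nothing to pad
print(old, new)       # 0 1
# the added "length > 1" guard then skips the padding branch entirely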
@@ -848,7 +921,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             ctc = getattr(self, f"ctc{i + 1}")
             adapter = getattr(self, f"adapter{i + 1}")

-            x, input_lengths, encoder_padding_mask = downsampling(x, input_lengths)
+            x, input_lengths = downsampling(x, input_lengths)
+            encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))

             # gather cosine similarity
             cos_sim_idx += 10
@@ -881,6 +955,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1

+                if layer.conv_stride > 1:
+                    # Stride Mask (B, 1, T // S, T // S)
+                    if encoder_padding_mask is not None:
+                        encoder_padding_mask = encoder_padding_mask[:, ::layer.conv_stride]
+
+                    # Update Seq Lengths
+                    if input_lengths is not None:
+                        input_lengths = torch.div(input_lengths - 1, layer.conv_stride, rounding_mode='floor') + 1
+
                 # gather cosine similarity
                 if self.gather_cos_sim:
                     cos_sim_idx += 1
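The mask slice and the length formula above stay consistent with each other: a stride-s convolution maps a length n to floor((n - 1) / s) + 1, and sampling every s-th mask column yields exactly that many positions. A quick check:

import torch

stride = 2
lengths = torch.tensor([100, 97, 64])
new_lengths = torch.div(lengths - 1, stride, rounding_mode='floor') + 1
print(new_lengths)                      # tensor([50, 49, 32])

padding_mask = torch.zeros(3, 100, dtype=torch.bool)   # (B, T)
print(padding_mask[:, ::stride].shape)  # torch.Size([3, 50]), matches max(new_lengths)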
@@ -983,12 +1067,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
 @register_model_architecture(model_name="pdss2t_transformer", arch_name="pdss2t_transformer")
 def base_architecture(args):
     # Convolutional subsampler
-    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
-    args.conv_channels = getattr(args, "conv_channels", 1024)
+    args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
+    args.subsampling_layers = getattr(args, "subsampling_layers", 2)
+    args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
+    args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
+    args.subsampling_stride = getattr(args, "subsampling_stride", 2)
+    args.subsampling_norm = getattr(args, "subsampling_norm", "none")
+    args.subsampling_activation = getattr(args, "subsampling_activation", "glu")

     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
-    args.encoder_layers = getattr(args, "encoder_layers", 12)
     args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
@@ -1046,6 +1134,9 @@ def base_architecture(args):
     args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)

+    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
+
     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...
 import logging
-import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from fairseq import checkpoint_utils, utils
-from fairseq.data.data_utils import lengths_to_padding_mask
 from fairseq.models import (
     FairseqEncoder,
     FairseqEncoderModel,
     register_model,
     register_model_architecture,
 )
-from fairseq.models.speech_to_text.modules import Adapter, CTC
-from fairseq.modules import (
-    FairseqDropout,
-    LayerNorm,
-    PositionalEmbedding,
-    LegacyRelPositionalEncoding,
-    RelPositionalEncoding,
-    S2TTransformerEncoderLayer,
-    DynamicLinearCombination,
-)
-from fairseq.modules.speech_to_text import (
-    subsampling
-)
 from torch import Tensor
@@ -445,6 +430,16 @@ class S2TCTCModel(FairseqEncoderModel):
             help="the ratio of the ffn in each stage",
         )
         parser.add_argument(
+            "--pds-conv-strides",
+            type=str,
+            help="the strides of the convolutional module (conformer) in each stage",
+        )
+        parser.add_argument(
+            "--pds-attn-strides",
+            type=str,
+            help="the strides of the attention module (conformer) in each stage",
+        )
+        parser.add_argument(
             "--pds-fusion",
             action="store_true",
             help="use the representation fusion method",
@@ -573,236 +568,15 @@ class S2TCTCEncoder(FairseqEncoder):
             logger.error("Unsupported architecture: %s." % encoder_type)
             return

-        # dim = args.encoder_embed_dim
-        # self.dropout_module = FairseqDropout(
-        #     p=args.dropout, module_name=self.__class__.__name__
-        # )
-        # self.embed_scale = math.sqrt(dim)
-        # if args.no_scale_embedding:
-        #     self.embed_scale = 1.0
-        # self.padding_idx = 1
-        #
-        # self.subsample = subsampling(args)
-        #
-        # self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        #
-        # if self.attn_type == "rel_pos":
-        #     self.embed_positions = RelPositionalEncoding(
-        #         args.max_source_positions, args.encoder_embed_dim
-        #     )
-        # elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
-        #     self.embed_positions = LegacyRelPositionalEncoding(
-        #         args.encoder_embed_dim, args.dropout, args.max_source_positions
-        #     )
-        # elif self.attn_type == "rope":
-        #     self.embed_positions = None
-        # else:  # Use absolute positional embedding
-        #     self.embed_positions = PositionalEmbedding(
-        #         args.max_source_positions, args.encoder_embed_dim, self.padding_idx
-        #     )
-        #
-        # self.layers = nn.ModuleList(
-        #     [S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
-        # )
-        #
-        # if args.encoder_normalize_before:
-        #     self.layer_norm = LayerNorm(dim)
-        # else:
-        #     self.layer_norm = None
-        #
-        # if args.use_enc_dlcl:
-        #     self.history = DynamicLinearCombination(args, is_encoder=True)
-        # else:
-        #     self.history = None
-        #
-        # self.ctc = CTC(dim,
-        #                dictionary_size=len(task.source_dictionary),
-        #                dropout=args.dropout,
-        #                )
-        #
-        # # gather cosine similarity of the representation
-        # self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
-        # # self.gather_cos_sim = True
-        # self.dis = 2
-        # self.cos_sim = dict()
-        #
-        # self.intermedia_ctc_layers = []
-        #
-        # if args.intermedia_ctc_layers is not None:
-        #     intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
-        #     for layer_idx in intermedia_ctc_layers:
-        #         layer_idx = int(layer_idx)
-        #         if layer_idx <= 0:
-        #             layer_idx += args.encoder_layers
-        #         self.intermedia_ctc_layers.append(layer_idx)
-        #
-        #         logger.info("Intermedia CTC loss in layer %d" % layer_idx)
-        #
-        #     strategy = None
-        #     if args.intermedia_adapter == "shrink":
-        #         strategy = getattr(args, "ctc_compress_strategy", "avg")
-        #     elif args.intermedia_adapter == "league":
-        #         strategy = getattr(args, "intermedia_distribution_cutoff", -1)
-        #     self.adapter = Adapter(dim, args.intermedia_adapter,
-        #                            task.source_dictionary, strategy=strategy)
-        #     self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
-
-    def add_to_dict(self, x, dis, idx):
-        sim = 0
-        seq_len = x.size(0)
-        cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
-        for i in range(dis, seq_len - dis):
-            a = x[i, :, :]
-            for j in range(-dis, dis + 1):
-                if j == 0:
-                    continue
-                b = x[i + j, :, :]
-                sim_j = cos(a, b).mean()
-                sim += sim_j
-        sim = sim / 2 / dis / (seq_len - 2 * dis)
-
-        if idx not in self.cos_sim:
-            self.cos_sim[idx] = []
-        self.cos_sim[idx].append(float(sim))

     def forward(self, src_tokens, src_lengths, **kwargs):
         return self.encoder(src_tokens, src_lengths, **kwargs)
-        #
-        # if self.history is not None:
-        #     self.history.clean()
-        #
-        # # gather cosine similarity
-        # cos_sim_idx = -1
-        # dis = self.dis
-        # if self.gather_cos_sim:
-        #     self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
-        #
-        # # down-sampling
-        # x, input_lengths = self.subsample(src_tokens, src_lengths)
-        # # (B, T, D) -> (T, B, D)
-        # x = x.transpose(0, 1)
-        #
-        # # embedding scaling
-        # x = self.embed_scale * x
-        #
-        # # padding and position embedding
-        # encoder_padding_mask = lengths_to_padding_mask(input_lengths)
-        #
-        # if self.attn_type in ["rel_selfattn", "rel_pos", "rel_pos_legacy"]:
-        #     positions = self.embed_positions(x)
-        #
-        # elif self.attn_type == "rope":
-        #     positions = None
-        #
-        # else:
-        #     positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
-        #     x += positions
-        #     positions = None
-        #
-        # x = self.dropout_module(x)
-        #
-        # # add emb into history
-        # if self.history is not None:
-        #     self.history.push(x)
-        #
-        # # gather cosine similarity
-        # cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
-        # if self.gather_cos_sim:
-        #     cos_sim_idx += 1
-        #     self.add_to_dict(x, dis, cos_sim_idx)
-        #
-        # layer_idx = 0
-        # intermedia_ctc_logits = []
-        # for layer in self.layers:
-        #     layer_idx += 1
-        #
-        #     if self.history is not None:
-        #         x = self.history.pop()
-        #
-        #     # encoder layer
-        #     x = layer(x, encoder_padding_mask, pos_emb=positions)
-        #
-        #     # interleave CTC
-        #     if layer_idx in self.intermedia_ctc_layers:
-        #         if self.intermedia_drop_prob > 0:
-        #             p = torch.rand(1).uniform_()
-        #             if p < self.intermedia_drop_prob:
-        #                 break
-        #
-        #         norm_x = self.layer_norm(x)
-        #         logit = self.ctc(norm_x)
-        #         intermedia_ctc_logits.append(logit)
-        #
-        #         prob = F.softmax(logit, dim=-1, dtype=torch.float32)
-        #         x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
-        #
-        #     # gather cosine similarity
-        #     if self.gather_cos_sim:
-        #         cos_sim_idx += 1
-        #         self.add_to_dict(x, dis, cos_sim_idx)
-        #
-        #     if self.history is not None:
-        #         self.history.push(x)
-        #
-        # if self.history is not None:
-        #     x = self.history.pop()
-        #
-        # if self.layer_norm is not None:
-        #     x = self.layer_norm(x)
-        #
-        # ctc_logit = self.ctc(x)
-        #
-        # return {
-        #     "encoder_out": [x],  # T x B x C
-        #     "ctc_logit": [ctc_logit],  # B x T x C
-        #     "intermedia_ctc_logits": intermedia_ctc_logits,  # B x T x C
-        #     "encoder_padding_mask": [encoder_padding_mask],  # B x T
-        #     "encoder_embedding": [],  # B x T x C
-        #     "encoder_states": [],  # List[T x B x C]
-        #     "src_tokens": [],
-        #     "src_lengths": [],
-        # }

     def reorder_encoder_out(self, encoder_out, new_order):
         self.encoder.reorder_encoder_out(encoder_out, new_order)
         return
-        new_encoder_out = (
-            [] if len(encoder_out["encoder_out"]) == 0
-            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
-        )
-        new_ctc_logit = (
-            [] if len(encoder_out["ctc_logit"]) == 0
-            else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"] if x is not None]
-        )
-        new_encoder_padding_mask = (
-            [] if len(encoder_out["encoder_padding_mask"]) == 0
-            else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
-        )
-        new_encoder_embedding = (
-            [] if len(encoder_out["encoder_embedding"]) == 0
-            else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
-        )
-        encoder_states = encoder_out["encoder_states"]
-        if len(encoder_states) > 0:
-            for idx, state in enumerate(encoder_states):
-                encoder_states[idx] = state.index_select(1, new_order)
-        return {
-            "encoder_out": new_encoder_out,  # T x B x C
-            "ctc_logit": new_ctc_logit,  # T x B x C
-            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
-            "encoder_embedding": new_encoder_embedding,  # B x T x C
-            "encoder_states": encoder_states,  # List[T x B x C]
-            "src_tokens": [],  # B x T
-            "src_lengths": [],  # B x 1
-        }

 class CTCDecoder(object):
@@ -968,6 +742,9 @@ def base_architecture(args):
     args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)

+    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
+
     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...
@@ -620,9 +620,9 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)

         # down-sampling
-        x, input_lengths = self.subsample(src_tokens, src_lengths)
         # (B, T, D) -> (T, B, D)
-        x = x.transpose(0, 1)
+        x = src_tokens.transpose(0, 1)
+        x, input_lengths = self.subsample(x, src_lengths)

         # embedding scaling
         x = self.embed_scale * x
...
@@ -205,7 +205,7 @@ class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
     Args:
         n_head (int): The number of heads.
         n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
+        dropout (float): Dropout rate.
         zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
     """

     def __init__(self, n_feat, n_head, dropout, zero_triu=False):
...
@@ -104,6 +104,7 @@ class PDSTransformerEncoderLayer(nn.Module):
             self.final_norm = LayerNorm(expand_embed_dim)

         # Convolution Residual
+        self.conv_stride = conv_stride
         self.conv_res = nn.Sequential(
             Permute3D(1, 2, 0),
             nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
@@ -322,7 +323,7 @@ class PDSTransformerEncoderLayer(nn.Module):
         x = self.conv_module(x)
         x = x.transpose(0, 1)
-        x = residual + x
+        x = self.conv_res(residual) + x

         if not self.normalize_before:
             x = self.conv_norm(x)
...
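With a strided convolution inside the layer, the plain residual no longer matches the output in length or width, which is why the residual is routed through conv_res (a kernel-1 Conv1d with the same stride). A shape-only sketch with illustrative dimensions:

import torch
import torch.nn as nn

T, B, D, D_out, stride = 100, 8, 240, 360, 2
residual = torch.randn(T, B, D)

proj = nn.Conv1d(D, D_out, kernel_size=1, stride=stride)
out = proj(residual.permute(1, 2, 0)).permute(2, 0, 1)  # (T, B, D) -> (T//stride, B, D_out)
print(out.shape)   # torch.Size([50, 8, 360]), now addable to the strided conv output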
@@ -144,8 +144,8 @@ class Conv1dSubsampling(nn.Module):
     def forward(self, x, x_len):
-        # (B, T, D) -> (B, D, T)
-        x = x.transpose(1, 2)
+        # (T, B, D) -> (B, D, T)
+        x = x.permute(1, 2, 0)
         # Layers
         for layer in self.layers:
             x = layer(x)
@@ -153,7 +153,9 @@ class Conv1dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
             x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
-        x = x.transpose(1, 2)
+
+        # (B, D, T) -> (T, B, D)
+        x = x.permute(2, 0, 1)

         return x, x_len
@@ -168,8 +170,8 @@ class Conv2dSubsampling(nn.Module):
         act: activation function

     Shape:
-        Input: (batch_size, in_length, in_dim)
-        Output: (batch_size, out_length, out_dim)
+        Input: (in_length, batch_size, in_dim)
+        Output: (out_length, batch_size, out_dim)
     """
@@ -199,8 +201,8 @@ class Conv2dSubsampling(nn.Module):
     def forward(self, x, x_len):
-        # (B, T, D) -> (B, D, T) -> (B, 1, D, T)
-        x = x.transpose(1, 2).unsqueeze(dim=1)
+        # (T, B, D) -> (B, D, T) -> (B, 1, D, T)
+        x = x.permute(1, 2, 0).unsqueeze(dim=1)
         # Layers
         for layer in self.layers:
@@ -212,17 +214,17 @@ class Conv2dSubsampling(nn.Module):
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
-        x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).transpose(1, 2)
+        x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
         x = self.linear(x)

         return x, x_len

-def subsampling(args):
+def subsampling(args, out_dim=None):
     subsampling_type = getattr(args, "subsampling_type", "conv1d")
     layers = getattr(args, "subsampling_layers", 2)
     in_dim = args.input_feat_per_channel * args.input_channels
-    filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim]
+    filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim]
     kernel_size = getattr(args, "subsampling_kernel", 5)
     stride = getattr(args, "subsampling_stride", 2)
     norm = getattr(args, "subsampling_norm", "none")
...
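The factory now produces modules that take and return (T, B, D), and the new out_dim argument lets a PDS stage request its own width instead of args.encoder_embed_dim. A hypothetical call, every value illustrative:

from argparse import Namespace
import torch

args = Namespace(
    subsampling_type="conv1d", subsampling_layers=2,
    subsampling_filter=1024, subsampling_kernel=5, subsampling_stride=2,
    subsampling_norm="none", subsampling_activation="glu",
    input_feat_per_channel=80, input_channels=1, encoder_embed_dim=256,
)

sub = subsampling(args, out_dim=120)   # e.g. the first stage of pds-embed-dims 120_168_240
x = torch.randn(400, 8, 80)            # (T, B, D_in)
x, x_len = sub(x, torch.full((8,), 400))
print(x.shape, x_len[0])               # two stride-2 layers: T 400 -> ~100, D -> 120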