Commit d4255246 by xuchen

optimize the implementation of the Efficient Conformer

parent 0bd92062
File mode changed from 100644 to 100755
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
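The recurring fix in these configs corrects the misspelled subsmapling-layers key to subsampling-layers. As a rough illustration (not repo code) of what the four settings build, a two-layer conv1d front-end with kernel 5 and stride 2 reduces the frame count about 4x:

import torch
import torch.nn as nn

# Hypothetical stand-in for the configured front-end: 2 layers, filter 2048,
# kernel 5, stride 2, over 80-dim filterbank features.
feat, filt, kernel, stride = 80, 2048, 5, 2
convs = nn.Sequential(
    nn.Conv1d(feat, filt, kernel, stride=stride, padding=kernel // 2),
    nn.Conv1d(filt, filt, kernel, stride=stride, padding=kernel // 2),
)
x = torch.randn(1, feat, 1000)     # (B, D, T): 1000 frames in
print(convs(x).shape)              # torch.Size([1, 2048, 250]): ~4x fewer frames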
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -29,7 +29,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -12,7 +12,7 @@ zero_infinity: True
post-process: sentencepiece
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
arch: transformer
share-all-embeddings: True
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
......
arch: transformer
share-all-embeddings: True
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
......
......@@ -44,10 +44,10 @@ lcrm=1
tokenizer=0
use_specific_dict=1
-specific_prefix=st
-specific_dir=${root_dir}/data/mustc/st
-src_vocab_prefix=spm_unigram10000_st_share
-tgt_vocab_prefix=spm_unigram10000_st_share
+specific_prefix=asr5k_st10k
+specific_dir=${root_dir}/data/${dataset}/st_lcrm_asr
+src_vocab_prefix=spm_unigram5000_asr
+tgt_vocab_prefix=spm_unigram10000_st
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt
......@@ -82,7 +82,6 @@ len_penalty=1.0
if [[ ${use_specific_dict} -eq 1 ]]; then
exp_prefix=${exp_prefix}_${specific_prefix}
data_dir=${data_dir}/${specific_prefix}
-mkdir -p ${data_dir}
else
if [[ "${tgt_vocab_type}" == "char" ]]; then
vocab_name=char
......@@ -159,6 +158,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd
--share"
fi
+echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
else
......@@ -171,13 +171,15 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{
if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-txt_dir=${org_data_dir}/data/${split}/txt
+text_dir=${org_data_dir}/data/${split}/txt
else
-txt_dir=${org_data_dir}/data/${split}
+text_dir=${org_data_dir}/data/${split}
fi
-cmd="cat ${txt_dir}/${split}.${src_lang}"
+src_text=${text_dir}/${split}.${src_lang}
+tgt_text=${text_dir}/${split}.${tgt_lang}
+cmd="cat ${src_text}"
if [[ ${lcrm} -eq 1 ]]; then
-cmd="python local/lower_rm.py ${org_data_dir}/data/${split}.${src_lang}"
+cmd="python local/lower_rm.py ${src_text}"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
......@@ -190,7 +192,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
-< ${txt_dir}/${split}.${tgt_lang}
+< ${tgt_text}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
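The pipeline now feeds ${src_text} to local/lower_rm.py. The script itself is not part of this diff; lcrm conventionally means lowercase-and-remove-punctuation on the source side, so a plausible sketch (the real script may differ) is:

import string
import sys

# Hedged reconstruction of local/lower_rm.py: lowercase each source line and
# strip punctuation before SentencePiece encoding.
with open(sys.argv[1], encoding="utf-8") as f:
    for line in f:
        line = line.strip().lower()
        print(line.translate(str.maketrans("", "", string.punctuation)))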
......@@ -329,11 +331,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
-tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
+tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi
fi
wait
......
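The tail invocation above starts following ${log} from just past its current last line, so only output from the freshly launched nohup job is streamed. A rough Python equivalent of that follow-from-end behavior, for illustration only:

import time

def follow(path):
    # Seek to the current end of the log, then stream lines as the
    # backgrounded training process appends them.
    with open(path) as f:
        f.seek(0, 2)
        while True:
            line = f.readline()
            if line:
                print(line, end="")
            else:
                time.sleep(0.5)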
......@@ -6,17 +6,17 @@ gpu_num=1
update_freq=1
max_tokens=8192
+extra_tag=
+extra_parameter=
+#extra_tag="${extra_tag}"
+#extra_parameter="${extra_parameter} "
exp_tag=baseline
config_list=(base)
# exp full name
exp_name=
-extra_tag=
-extra_parameter=
-#extra_tag="${extra_tag}"
-#extra_parameter="${extra_parameter} "
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh
......
File mode changed from 100644 to 100755
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -14,7 +14,7 @@ label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -13,7 +13,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
arch: s2t_ctc
encoder-type: pds
#arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: transfer
#relative-pos-enc: True
encoder-attention-type: rel_pos
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
-encoder-embed-dim: 256
-pds-stages: 4
-ctc-layer: 12
-pds-layers: 3_3_3_3
-pds-ratios: 2_2_1_2
-pds-fusion: True
+encoder-embed-dim: 240
+pds-stages: 3
+#ctc-layer: 15
+pds-layers: 4_5_6
+pds-ratios: 2_2_2
+pds-fusion: False
pds-fusion-method: all_conv
-pds-embed-dims: 256_256_256_256
+pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
-pds-position-embed: 1_1_1_1
-pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_8_8_8
-pds-attn-heads: 4_4_4_4
+pds-position-embed: 1_1_1
+pds-kernel-sizes: 3_3_3
+pds-ffn-ratios: 4_4_4
+pds-attn-heads: 4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
-lr: 2e-3
+lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
+encoder-layers: 15
+macaron-style: True
+use-cnn-module: True
+cnn-module-kernel: 15
+encoder-activation-fn: swish
+encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
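The rewritten config trades the 4-stage, 256-dim encoder for a 3-stage Efficient-Conformer-style schedule: 15 layers split 4_5_6, widths growing 120 to 168 to 240, and a downsampling ratio of 2 per stage. A quick illustrative check of the temporal reduction:

from functools import reduce

ratios = [2, 2, 2]              # pds-ratios: 2_2_2
dims = [120, 168, 240]          # pds-embed-dims: 120_168_240
frames = 1000
for r, d in zip(ratios, dims):
    frames = (frames - 1) // r + 1   # length after a stride-r downsampling
    print(f"stage width {d}: {frames} frames")   # 500, 250, 125
print(reduce(lambda a, b: a * b, ratios), "x total reduction")   # 8x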
......@@ -13,7 +13,7 @@ zero_infinity: True
post-process: sentencepiece
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
......@@ -13,7 +13,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -14,7 +14,7 @@ label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
+arch: s2t_ctc
+encoder-type: transformer
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
......@@ -12,7 +14,7 @@ criterion: ctc
post-process: sentencepiece
subsampling-type: conv2d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
pds-stages: 3
pds-layers: 4_6_6
pds-ratios: -1_0_0
pds-conv-strides: 2_2_1
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 180_256_360
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
subsampling-type: conv2d
subsampling-layers: 1
subsampling-filter: 180
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 360
encoder-layers: 15
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
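In this conv2d variant the stages no longer carry a separate Downsampling module. Reading the keys above: pds-ratios: -1_0_0 appears to delegate stage 1's reduction to the shared conv2d subsampling front-end (ratio -1) and disable it elsewhere (ratio 0), while pds-conv-strides: 2_2_1 puts a stride-2 convolution in the last conformer layer of each of the first two stages. Lengths then shrink by the usual strided-conv rule, sketched here:

def strided_len(length: int, stride: int) -> int:
    # Output length of a stride-s layer: floor((L - 1) / s) + 1,
    # the same rule the encoder applies after a strided conformer layer.
    return (length - 1) // stride + 1

frames = 1000
frames = strided_len(frames, 2)   # conv2d subsampling, stride 2 -> 500
frames = strided_len(frames, 2)   # stage 1 strided conv module  -> 250
frames = strided_len(frames, 2)   # stage 2 strided conv module  -> 125
print(frames)                     # 8x total reduction overall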
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
pds-stages: 3
pds-layers: 5_5_5
pds-ratios: -1_0_0
pds-conv-strides: 2_2_1
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
subsampling-type: conv2d
subsampling-layers: 1
subsampling-filter: 120
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 240
encoder-layers: 15
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -11,7 +11,7 @@ criterion: ctc
post-process: sentencepiece
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -38,14 +38,7 @@ criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
#load-pretrained-encoder-from:
-#load-pretrained-decoder-from:
\ No newline at end of file
......@@ -38,10 +38,7 @@ post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-decoder-layers: 6
-encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 15
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
File mode changed from 100644 to 100755
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -12,7 +12,7 @@ zero_infinity: True
post-process: sentencepiece
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
......@@ -12,7 +12,7 @@ criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -15,7 +15,7 @@ encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
......@@ -14,7 +14,7 @@ label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
-subsmapling-layers: 2
+subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -2,21 +2,21 @@
# training the model
-gpu_num=8
-update_freq=1
+gpu_num=4
+update_freq=4
max_tokens=8192
-exp_tag=baseline
-config_list=(base)
-# exp full name
-exp_name=
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
+exp_tag=baseline
+config_list=(deep)
+# exp full name
+exp_name=
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh
......
File mode changed from 100644 to 100755
#! /bin/bash
-# Processing WMT16 En-De Datasets
+# Processing WMT20 En-Zh Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
......@@ -35,18 +35,19 @@ lang=${src_lang}-${tgt_lang}
dataset=wmt20
task=translation
-vocab_type=unigram
-vocab_size=32000
+src_vocab_type=unigram
+tgt_vocab_type=unigram
+src_vocab_size=32000
+tgt_vocab_size=32000
share_dict=0
lcrm=1
tokenizer=1
-use_specific_dict=0
-subword=0
-specific_prefix=subword32000_share
-specific_dir=${root_dir}/data/mustc/st
-src_vocab_prefix=spm_unigram10000_st_share
-tgt_vocab_prefix=spm_unigram10000_st_share
+use_specific_dict=1
+specific_prefix=asr5k_st10k
+specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
+src_vocab_prefix=spm_unigram5000_asr
+tgt_vocab_prefix=spm_unigram10000_st
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt
......@@ -81,17 +82,24 @@ len_penalty=1.0
if [[ ${use_specific_dict} -eq 1 ]]; then
exp_prefix=${exp_prefix}_${specific_prefix}
data_dir=${data_dir}/${specific_prefix}
-mkdir -p ${data_dir}
else
if [[ "${vocab_type}" == "char" ]]; then
vocab_name=${vocab_type}
exp_prefix=${exp_prefix}_${vocab_type}
if [[ "${tgt_vocab_type}" == "char" ]]; then
vocab_name=char
exp_prefix=${exp_prefix}_char
else
if [[ ${src_vocab_size} -ne ${tgt_vocab_size} || "${src_vocab_type}" -ne "${tgt_vocab_type}" ]]; then
src_vocab_name=${src_vocab_type}${src_vocab_size}
tgt_vocab_name=${tgt_vocab_type}${tgt_vocab_size}
vocab_name=${src_vocab_name}_${tgt_vocab_name}
else
vocab_name=${vocab_type}${vocab_size}
vocab_name=${tgt_vocab_type}${tgt_vocab_size}
src_vocab_name=${vocab_name}
tgt_vocab_name=${vocab_name}
fi
fi
data_dir=${data_dir}/${vocab_name}
src_vocab_prefix=spm_${vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${vocab_name}_${tgt_lang}
src_vocab_prefix=spm_${src_vocab_name}_${src_lang}
tgt_vocab_prefix=spm_${tgt_vocab_name}_${tgt_lang}
if [[ $share_dict -eq 1 ]]; then
data_dir=${data_dir}_share
src_vocab_prefix=spm_${vocab_name}_share
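Worked through (illustratively), the new naming yields the same prefixes as before when source and target settings match, and distinct ones otherwise. Note the shell test uses -ne, a numeric operator, on the type strings; != would be the string operator.

def vocab_names(src_type, src_size, tgt_type, tgt_size):
    # Mirrors the shell logic above: split names only when the two
    # sides disagree in type or size.
    if src_size != tgt_size or src_type != tgt_type:
        return f"{src_type}{src_size}", f"{tgt_type}{tgt_size}"
    name = f"{tgt_type}{tgt_size}"
    return name, name

print(vocab_names("unigram", 32000, "unigram", 32000))  # ('unigram32000', 'unigram32000')
print(vocab_names("unigram", 5000, "unigram", 10000))   # ('unigram5000', 'unigram10000')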
......@@ -103,6 +111,9 @@ if [[ ${lcrm} -eq 1 ]]; then
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${tokenizer} -eq 1 ]]; then
+train_subset=${train_subset}.tok
+valid_subset=${valid_subset}.tok
+trans_subset=${trans_subset}.tok
data_dir=${data_dir}_tok
exp_prefix=${exp_prefix}_tok
fi
......@@ -139,16 +150,14 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
---vocab-type ${vocab_type}
---vocab-size ${vocab_size}"
+--src-vocab-type ${src_vocab_type}
+--tgt-vocab-type ${tgt_vocab_type}
+--src-vocab-size ${src_vocab_size}
+--tgt-vocab-size ${tgt_vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
-if [[ ${tokenizer} -eq 1 ]]; then
-cmd="$cmd
---tokenizer"
-fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
......@@ -168,10 +177,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
src_text=${text_dir}/${split}.${src_lang}
tgt_text=${text_dir}/${split}.${tgt_lang}
-if [[ ${tokenizer} -eq 1 ]]; then
-src_text=${text_dir}/${split}.tok.${src_lang}
-tgt_text=${text_dir}/${split}.tok.${tgt_lang}
-fi
cmd="cat ${src_text}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_text}"
......@@ -327,16 +332,14 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
export CUDA_VISIBLE_DEVICES=${device}
log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi
wait
echo -e " >> finish training \n"
fi
wait
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: MT Decoding"
......@@ -381,15 +384,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--results-path ${model_dir}
--max-tokens ${max_tokens}
--beam ${beam_size}
---lenpen ${len_penalty}"
-if [[ ${subword} -eq 1 ]]; then
-cmd="${cmd}
---post-process subword_nmt"
-else
-cmd="${cmd}
+--lenpen ${len_penalty}
--post-process sentencepiece"
-fi
if [[ ${sacrebleu} -eq 1 ]]; then
cmd="${cmd}
......
......@@ -2,8 +2,8 @@
# training the model
-gpu_num=8
-update_freq=2
+gpu_num=4
+update_freq=4
max_tokens=8192
exp_tag=baseline
......
......@@ -24,6 +24,9 @@ from fairseq.modules import (
PDSTransformerEncoderLayer,
DownSampleConvolutionModule
)
+from fairseq.modules.speech_to_text import (
+subsampling
+)
logger = logging.getLogger(__name__)
......@@ -65,16 +68,14 @@ class Downsampling(nn.Module):
self.stride = stride
self.reduced_way = reduced_way
+if stride == 0:
+return
# default conv
if self.reduced_way == "conv":
self.conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
)
-elif self.reduced_way == "glu":
-self.conv = nn.Sequential(
-nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding),
-nn.GLU(dim=1)
-)
elif self.reduced_way == "proj":
self.conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
......@@ -88,6 +89,9 @@ class Downsampling(nn.Module):
self.norm = LayerNorm(out_channels)
def forward(self, x, lengths):
+if self.stride == 0:
+return x, lengths
seq_len, bsz, dim = x.size()
assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
......@@ -110,23 +114,20 @@ class Downsampling(nn.Module):
else:
x = x.permute(1, 2, 0) # B * D * T
x = self.conv(x)
if self.reduced_way == "glu":
x = self.glu(x)
x = x.permute(2, 0, 1) # T * B * D
if self.embed_norm:
x = self.norm(x)
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
# mask batch padding
if not torch.all(lengths == x.size(-1)):
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
mask_pad = padding_mask.unsqueeze(2)
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
return x, lengths, padding_mask
return x, lengths
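Downsampling.forward now returns only (x, lengths); callers rebuild the padding mask themselves. For reference, lengths_to_padding_mask_with_maxlen behaves like this sketch (a hedged reconstruction; the fairseq utility may differ in details):

import torch

def lengths_to_padding_mask_with_maxlen_sketch(lengths, max_len):
    # True marks padded positions beyond each sequence's true length.
    bsz = lengths.size(0)
    pos = torch.arange(max_len, device=lengths.device).expand(bsz, max_len)
    return pos >= lengths.unsqueeze(1)

print(lengths_to_padding_mask_with_maxlen_sketch(torch.tensor([3, 5]), 5))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])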
@register_model("pdss2t_transformer")
......@@ -139,6 +140,44 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
+# subsampling
+parser.add_argument(
+"--subsampling-type",
+type=str,
+help="subsampling type, like conv1d and conv2d",
+)
+parser.add_argument(
+"--subsampling-layers",
+type=int,
+help="subsampling layers",
+)
+parser.add_argument(
+"--subsampling-filter",
+type=int,
+help="subsampling filter",
+)
+parser.add_argument(
+"--subsampling-kernel",
+type=int,
+help="subsampling kernel",
+)
+parser.add_argument(
+"--subsampling-stride",
+type=int,
+help="subsampling stride",
+)
+parser.add_argument(
+"--subsampling-norm",
+type=str,
+default="none",
+help="subsampling normalization type",
+)
+parser.add_argument(
+"--subsampling-activation",
+type=str,
+default="none",
+help="subsampling activation function type",
+)
# Transformer
parser.add_argument(
"--activation-fn",
......@@ -485,6 +524,16 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help="the ratio of the ffn in each stage",
)
+parser.add_argument(
+"--pds-conv-strides",
+type=str,
+help="the strides of the convolutional module (conformer) in each stage",
+)
+parser.add_argument(
+"--pds-attn-strides",
+type=str,
+help="the strides of the attention module (conformer) in each stage",
+)
parser.add_argument(
"--pds-fusion",
action="store_true",
help="use the representation fusion method",
......@@ -565,8 +614,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
)
self.pds_stages = getattr(args, "pds_stages", 4)
self.pds_layers = [int(n) for n in args.pds_layers.split("_")]
+self.layers = sum(self.pds_layers)
self.pds_ratios = [int(n) for n in args.pds_ratios.split("_")]
# down-sampling module
......@@ -582,6 +631,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
self.pds_attn_ds_ratios = None
+self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
+self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
# fusion
self.pds_fusion = args.pds_fusion
self.pds_fusion_method = args.pds_fusion_method
......@@ -619,15 +672,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i]
-ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
+ffn_ratio = self.pds_ffn_ratios[i]
+conv_stride = self.pds_conv_strides[i]
+attn_stride = self.pds_attn_strides[i]
+if conv_stride != 1 or attn_stride != 1:
+expand_embed_dim = embed_dim if i == self.pds_stages - 1 else self.pds_embed_dims[i + 1]
+else:
+expand_embed_dim = None
logger.info("The stage {}: layer {}, down-sample ratio {}, embed dim {}, "
"kernel size {}, position embed {}, ffn ratio {}, num head {}, "
"attn down-sample ratio {}, conv stride {}, attn stride {}, "
"fusion {}, fusion method {}, fusion transformer {}.".
format(i, num_layers, ds_ratio, embed_dim,
kernel_size, use_pos_embed, ffn_ratio, num_head,
attn_ds_ratio, conv_stride, attn_stride,
self.pds_fusion, self.pds_fusion_method, self.pds_fusion_transform))
if i == 0:
......@@ -636,6 +697,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.embed_scale = 1.0
# down-sampling
+if ds_ratio == -1:
+downsampling = subsampling(args, embed_dim)
+else:
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
......@@ -650,24 +714,33 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if use_pos_embed:
if self.attn_type == "rel_pos":
pos_embed = RelPositionalEncoding(
-args.max_source_positions, args.encoder_embed_dim
+args.max_source_positions, embed_dim
)
elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
pos_embed = LegacyRelPositionalEncoding(
-args.encoder_embed_dim, args.dropout, args.max_source_positions
+embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
-self.embed_positions = None
+pos_embed = None
else: # Use absolute positional embedding
pos_embed = PositionalEmbedding(
-args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+args.max_source_positions, embed_dim, self.padding_idx
)
else:
pos_embed = None
stage = nn.ModuleList([
-PDSTransformerEncoderLayer(args, embed_dim, ffn_ratio, num_head, attn_ds_ratio)
-for _ in range(num_layers)])
+PDSTransformerEncoderLayer(
+args,
+embed_dim,
+ffn_ratio,
+num_head,
+attn_ds_ratio,
+conv_stride=conv_stride if layer_idx == num_layers - 1 else 1,
+attn_stride=attn_stride if layer_idx == num_layers - 1 else 1,
+expand_embed_dim=expand_embed_dim if layer_idx == num_layers - 1 else None,
+)
+for layer_idx in range(num_layers)])
# representation fusion
fusion_pre_layer_norm = None
......@@ -760,9 +833,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
(("ctc" in getattr(args, "criterion", "")) and
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
-# self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
-# self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
-# self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
+# self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
+# self.ctc_layer = self.layers if self.ctc_layer == 0 else self.ctc_layer
+# self.inter_ctc = True if self.ctc_layer != self.layers or self.fusion_stages_num != 0 else False
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 else False
......@@ -824,9 +897,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# padding to the multiply of 2
max_len = x.size(0)
-length = reduce(lambda a, b: a * b, self.pds_ratios)
+length = reduce(lambda a, b: max(1, a) * max(1, b), self.pds_ratios)
padding_to_len = (length - max_len % length)
-if padding_to_len > 0:
+if length > 1 and padding_to_len > 0:
padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
x = torch.cat([x, padding_for_pds], dim=0)
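The guarded product makes ratios of -1 and 0 contribute a factor of 1, so configs that delegate downsampling to the front-end no longer force sequence-length padding. A quick illustrative check:

from functools import reduce

for ratios in ([2, 2, 1, 2], [-1, 0, 0]):
    length = reduce(lambda a, b: max(1, a) * max(1, b), ratios)
    print(length)   # 8 for the classic 4-stage config, 1 (no padding) for -1_0_0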
......@@ -848,7 +921,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ctc = getattr(self, f"ctc{i + 1}")
adapter = getattr(self, f"adapter{i + 1}")
-x, input_lengths, encoder_padding_mask = downsampling(x, input_lengths)
+x, input_lengths = downsampling(x, input_lengths)
+encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, x.size(0))
# gather cosine similarity
cos_sim_idx += 10
......@@ -881,6 +955,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
+if layer.conv_stride > 1:
+# Stride Mask (B, 1, T // S, T // S)
+if encoder_padding_mask is not None:
+encoder_padding_mask = encoder_padding_mask[:, ::layer.conv_stride]
+# Update Seq Lengths
+if input_lengths is not None:
+input_lengths = torch.div(input_lengths - 1, layer.conv_stride, rounding_mode='floor') + 1
# gather cosine similarity
if self.gather_cos_sim:
cos_sim_idx += 1
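The mask update above keeps every conv_stride-th frame, which matches the floor((L - 1) / s) + 1 length rule. Illustration (not repo code):

import torch

mask = torch.tensor([[False] * 7 + [True] * 3])      # 10 frames, last 3 padded
stride = 2
print(mask[:, ::stride])                             # 5 frames survive the slice
print(torch.div(torch.tensor([7]) - 1, stride, rounding_mode='floor') + 1)   # tensor([4]) valid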
......@@ -983,12 +1067,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
@register_model_architecture(model_name="pdss2t_transformer", arch_name="pdss2t_transformer")
def base_architecture(args):
# Convolutional subsampler
-args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
-args.conv_channels = getattr(args, "conv_channels", 1024)
+args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
+args.subsampling_layers = getattr(args, "subsampling_layers", 2)
+args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
+args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
+args.subsampling_stride = getattr(args, "subsampling_stride", 2)
+args.subsampling_norm = getattr(args, "subsampling_norm", "none")
+args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
......@@ -1046,6 +1134,9 @@ def base_architecture(args):
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
import logging
import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text.modules import Adapter, CTC
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
)
from fairseq.modules.speech_to_text import (
subsampling
)
from torch import Tensor
......@@ -445,6 +430,16 @@ class S2TCTCModel(FairseqEncoderModel):
help="the ratio of the ffn in each stage",
)
+parser.add_argument(
+"--pds-conv-strides",
+type=str,
+help="the strides of the convolutional module (conformer) in each stage",
+)
+parser.add_argument(
+"--pds-attn-strides",
+type=str,
+help="the strides of the attention module (conformer) in each stage",
+)
parser.add_argument(
"--pds-fusion",
action="store_true",
help="use the representation fusion method",
......@@ -573,236 +568,15 @@ class S2TCTCEncoder(FairseqEncoder):
logger.error("Unsupported architecture: %s." % encoder_type)
return
# dim = args.encoder_embed_dim
# self.dropout_module = FairseqDropout(
# p=args.dropout, module_name=self.__class__.__name__
# )
# self.embed_scale = math.sqrt(dim)
# if args.no_scale_embedding:
# self.embed_scale = 1.0
# self.padding_idx = 1
#
# self.subsample = subsampling(args)
#
# self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
#
# if self.attn_type == "rel_pos":
# self.embed_positions = RelPositionalEncoding(
# args.max_source_positions, args.encoder_embed_dim
# )
# elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
# self.embed_positions = LegacyRelPositionalEncoding(
# args.encoder_embed_dim, args.dropout, args.max_source_positions
# )
# elif self.attn_type == "rope":
# self.embed_positions = None
# else: # Use absolute positional embedding
# self.embed_positions = PositionalEmbedding(
# args.max_source_positions, args.encoder_embed_dim, self.padding_idx
# )
#
# self.layers = nn.ModuleList(
# [S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
# )
#
# if args.encoder_normalize_before:
# self.layer_norm = LayerNorm(dim)
# else:
# self.layer_norm = None
#
# if args.use_enc_dlcl:
# self.history = DynamicLinearCombination(args, is_encoder=True)
# else:
# self.history = None
#
# self.ctc = CTC(dim,
# dictionary_size=len(task.source_dictionary),
# dropout=args.dropout,
# )
#
# # gather cosine similarity of the representation
# self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
# # self.gather_cos_sim = True
# self.dis = 2
# self.cos_sim = dict()
#
# self.intermedia_ctc_layers = []
#
# if args.intermedia_ctc_layers is not None:
# intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
# for layer_idx in intermedia_ctc_layers:
# layer_idx = int(layer_idx)
# if layer_idx <= 0:
# layer_idx += args.encoder_layers
# self.intermedia_ctc_layers.append(layer_idx)
#
# logger.info("Intermedia CTC loss in layer %d" % layer_idx)
#
# strategy = None
# if args.intermedia_adapter == "shrink":
# strategy = getattr(args, "ctc_compress_strategy", "avg")
# elif args.intermedia_adapter == "league":
# strategy = getattr(args, "intermedia_distribution_cutoff", -1)
# self.adapter = Adapter(dim, args.intermedia_adapter,
# task.source_dictionary, strategy=strategy)
# self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
def add_to_dict(self, x, dis, idx):
sim = 0
seq_len = x.size(0)
cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
for i in range(dis, seq_len - dis):
a = x[i, :, :]
for j in range(-dis, dis + 1):
if j == 0:
continue
b = x[i + j, :, :]
sim_j = cos(a, b).mean()
sim += sim_j
sim = sim / 2 / dis / (seq_len - 2 * dis)
if idx not in self.cos_sim:
self.cos_sim[idx] = []
self.cos_sim[idx].append(float(sim))
def forward(self, src_tokens, src_lengths, **kwargs):
return self.encoder(src_tokens, src_lengths, **kwargs)
#
# if self.history is not None:
# self.history.clean()
#
# # gather cosine similarity
# cos_sim_idx = -1
# dis = self.dis
# if self.gather_cos_sim:
# self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
#
# # down-sampling
# x, input_lengths = self.subsample(src_tokens, src_lengths)
# # (B, T, D) -> (T, B, D)
# x = x.transpose(0, 1)
#
# # embedding scaling
# x = self.embed_scale * x
#
# # padding and position embedding
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
#
# if self.attn_type in ["rel_selfattn", "rel_pos", "rel_pos_legacy"]:
# positions = self.embed_positions(x)
#
# elif self.attn_type == "rope":
# positions = None
#
# else:
# positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
# x += positions
# positions = None
#
# x = self.dropout_module(x)
#
# # add emb into history
# if self.history is not None:
# self.history.push(x)
#
# # gather cosine similarity
# cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
# if self.gather_cos_sim:
# cos_sim_idx += 1
# self.add_to_dict(x, dis, cos_sim_idx)
#
# layer_idx = 0
# intermedia_ctc_logits = []
# for layer in self.layers:
# layer_idx += 1
#
# if self.history is not None:
# x = self.history.pop()
#
# # encoder layer
# x = layer(x, encoder_padding_mask, pos_emb=positions)
#
# # interleave CTC
# if layer_idx in self.intermedia_ctc_layers:
# if self.intermedia_drop_prob > 0:
# p = torch.rand(1).uniform_()
# if p < self.intermedia_drop_prob:
# break
#
# norm_x = self.layer_norm(x)
# logit = self.ctc(norm_x)
# intermedia_ctc_logits.append(logit)
#
# prob = F.softmax(logit, dim=-1, dtype=torch.float32)
# x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
#
# # gather cosine similarity
# if self.gather_cos_sim:
# cos_sim_idx += 1
# self.add_to_dict(x, dis, cos_sim_idx)
#
# if self.history is not None:
# self.history.push(x)
#
# if self.history is not None:
# x = self.history.pop()
#
# if self.layer_norm is not None:
# x = self.layer_norm(x)
#
# ctc_logit = self.ctc(x)
#
# return {
# "encoder_out": [x], # T x B x C
# "ctc_logit": [ctc_logit], # B x T x C
# "intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
# "encoder_padding_mask": [encoder_padding_mask], # B x T
# "encoder_embedding": [], # B x T x C
# "encoder_states": [], # List[T x B x C]
# "src_tokens": [],
# "src_lengths": [],
# }
def reorder_encoder_out(self, encoder_out, new_order):
self.encoder.reorder_encoder_out(encoder_out, new_order)
return
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_ctc_logit = (
[] if len(encoder_out["ctc_logit"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"] if x is not None]
)
new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"ctc_logit": new_ctc_logit, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], # B x T
"src_lengths": [], # B x 1
}
class CTCDecoder(object):
......@@ -968,6 +742,9 @@ def base_architecture(args):
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
......@@ -620,9 +620,9 @@ class S2TTransformerEncoder(FairseqEncoder):
self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
# down-sampling
-x, input_lengths = self.subsample(src_tokens, src_lengths)
-# (B, T, D) -> (T, B, D)
-x = x.transpose(0, 1)
+x = src_tokens.transpose(0, 1)
+x, input_lengths = self.subsample(x, src_lengths)
# embedding scaling
x = self.embed_scale * x
......
......@@ -205,7 +205,7 @@ class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
-dropout_rate (float): Dropout rate.
+dropout (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
......
......@@ -104,6 +104,7 @@ class PDSTransformerEncoderLayer(nn.Module):
self.final_norm = LayerNorm(expand_embed_dim)
# Convolution Residual
+self.conv_stride = conv_stride
self.conv_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
......@@ -322,7 +323,7 @@ class PDSTransformerEncoderLayer(nn.Module):
x = self.conv_module(x)
x = x.transpose(0, 1)
-x = residual + x
+x = self.conv_res(residual) + x
if not self.normalize_before:
x = self.conv_norm(x)
......
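The point of conv_res becomes clear in the strided case: once the convolution module downsamples time (and possibly widens the embedding), the identity residual no longer matches, so a kernel-1 convolution with the same stride projects it first. A self-contained illustration with hypothetical sizes:

import torch
import torch.nn as nn

embed_dim, expand_embed_dim, conv_stride = 240, 360, 2
# Stand-in for the repo's Permute3D(1, 2, 0) + Conv1d residual path,
# here starting directly from a (B, D, T) tensor.
conv_res = nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride)
residual = torch.randn(8, embed_dim, 100)
print(conv_res(residual).shape)   # torch.Size([8, 360, 50]) matches the strided output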
......@@ -144,8 +144,8 @@ class Conv1dSubsampling(nn.Module):
def forward(self, x, x_len):
-# (B, T, D) -> (B, D, T)
-x = x.transpose(1, 2)
+# (T, B, D) -> (B, D, T)
+x = x.permute(1, 2, 0)
# Layers
for layer in self.layers:
x = layer(x)
......@@ -153,7 +153,9 @@ class Conv1dSubsampling(nn.Module):
# Update Sequence Lengths
if x_len is not None:
x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
-x = x.transpose(1, 2)
+# (B, D, T) -> (T, B, D)
+x = x.permute(2, 0, 1)
return x, x_len
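After this change the conv1d subsampler consumes and produces time-major tensors, matching the encoder change further up where src_tokens is transposed before self.subsample is called. The permutes reduce to:

import torch

x = torch.randn(1000, 8, 80)   # (T, B, D): time-major input, as in the new code
x = x.permute(1, 2, 0)         # (B, D, T) so Conv1d can run over time
# ... strided conv layers would shrink T here ...
x = x.permute(2, 0, 1)         # back to (T', B, D') for the encoder stack
print(x.shape)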
......@@ -168,8 +170,8 @@ class Conv2dSubsampling(nn.Module):
act: activation function
Shape:
-Input: (batch_size, in_length, in_dim)
-Output: (batch_size, out_length, out_dim)
+Input: (in_length, batch_size, in_dim)
+Output: (out_length, batch_size, out_dim)
"""
......@@ -199,8 +201,8 @@ class Conv2dSubsampling(nn.Module):
def forward(self, x, x_len):
-# (B, T, D) -> (B, D, T) -> (B, 1, D, T)
-x = x.transpose(1, 2).unsqueeze(dim=1)
+# (T, B, D) -> (B, D, T) -> (B, 1, D, T)
+x = x.permute(1, 2, 0).unsqueeze(dim=1)
# Layers
for layer in self.layers:
......@@ -212,17 +214,17 @@ class Conv2dSubsampling(nn.Module):
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
-x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).transpose(1, 2)
+x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
x = self.linear(x)
return x, x_len
-def subsampling(args):
+def subsampling(args, out_dim=None):
subsampling_type = getattr(args, "subsampling_type", "conv1d")
layers = getattr(args, "subsampling_layers", 2)
in_dim = args.input_feat_per_channel * args.input_channels
-filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim]
+filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim]
kernel_size = getattr(args, "subsampling_kernel", 5)
stride = getattr(args, "subsampling_stride", 2)
norm = getattr(args, "subsampling_norm", "none")
......
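A hedged usage sketch of the extended factory (argument names follow the diff; running it requires this repo's fairseq): passing out_dim lets the PDS encoder build a front-end whose output width matches the first stage's embedding size rather than args.encoder_embed_dim:

from argparse import Namespace
from fairseq.modules.speech_to_text import subsampling

args = Namespace(subsampling_type="conv2d", subsampling_layers=1,
                 subsampling_filter=120, subsampling_kernel=3,
                 subsampling_stride=2, subsampling_norm="batch2d",
                 subsampling_activation="swish",
                 input_feat_per_channel=80, input_channels=1,
                 encoder_embed_dim=240)
front_end = subsampling(args, out_dim=120)   # 120 = first stage width, not 240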