Commit 1d60b3a6 by xuchen

big update! I integrated the latest updates of the shell scripts, optimized the implementation of SAE, and fixed some bugs.
parent 1288e535
@@ -13,7 +13,7 @@ label_smoothing: 0.1
 subsampling-type: conv1d
 subsampling-layers: 2
-subsampling-filter: 2048
+subsampling-filter: 1024
 subsampling-kernel: 5
 subsampling-stride: 2
 subsampling-norm: none
......
arch: s2t_transformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 512
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: relu
dropout: 0.15
activation-fn: relu
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
 macaron-style: True
 use-cnn-module: True
-cnn-module-kernel: 31
+cnn-module-kernel: 15
 encoder-attention-type: rel_pos
 encoder-activation-fn: swish
\ No newline at end of file
@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
......
 #! /bin/bash
-# Processing ASR Datasets
+# Processing aishell ASR Datasets
 # Copyright 2021 Natural Language Processing Laboratory
 # Xu Chen (xuchenneu@163.com)
@@ -323,7 +323,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
     export CUDA_VISIBLE_DEVICES=${device}
-    result_file=${model_dir}/decode_result
+    suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
+    if [[ -n ${cer} && ${cer} -eq 1 ]]; then
+        suffix=${suffix}_cer
+    else
+        suffix=${suffix}_wer
+    fi
+    if [[ ${n_average} -ne 1 ]]; then
+        suffix=${suffix}_${n_average}
+    fi
+    result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
     test_subset=${test_subset//,/ }
@@ -352,6 +361,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
+            mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
+            mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
         fi
     done
     cat ${result_file}
......
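The decode stage now writes results under a configuration-dependent suffix instead of a fixed decode_result file, so runs with different beam sizes, length penalties, token budgets, scoring metrics, or checkpoint-averaging settings no longer overwrite each other. A minimal Python sketch of the naming scheme (the function name and defaults are illustrative, not from the repository):

    def decode_suffix(beam_size, len_penalty, max_tokens, cer=False, n_average=1):
        # Mirrors the shell logic above: the decoding hyperparameters always
        # appear; the scoring metric is appended, and the number of averaged
        # checkpoints only when it differs from 1.
        suffix = "beam{}_alpha{}_tokens{}".format(beam_size, len_penalty, max_tokens)
        suffix += "_cer" if cer else "_wer"
        if n_average != 1:
            suffix += "_{}".format(n_average)
        return suffix

    # decode_suffix(5, 1.0, 80000, cer=False, n_average=10)
    # -> 'beam5_alpha1.0_tokens80000_wer_10'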
@@ -13,12 +13,12 @@ extra_parameter=
 exp_tag=
-#config_list=(base)
-#config_list=(ctc)
-#config_list=(base conformer)
+#config_list=(base ctc)
+config_list=(base ctc conformer)
+config_list=(big ctc conformer)
 #config_list=(pds_base_16)
-config_list=(pds_base_16 conformer rpr)
+config_list=(pds_base_16 conformer)
 # exp full name
 exp_name=
......
@@ -14,7 +14,7 @@ sacrebleu=0
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=80000
+max_tokens=20000
 dec_model=checkpoint_best.pt
 cmd="./run.sh
......
@@ -25,6 +25,7 @@ pds-kernel-sizes: 5_5_5_5
 pds-ffn-ratios: 8_4_4_4
 pds-attn-heads: 4_6_6_8
+fp16-scale-tolerance: 0.25
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
......
@@ -18,7 +18,7 @@ exp_tag=
 #config_list=(base conformer)
 config_list=(pds_base_8 ctc)
-#config_list=(pds_base_16 conformer rpr)
+#config_list=(pds_base_16 conformer)
 # exp full name
 exp_name=
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
@@ -3,15 +3,15 @@
 gpu_num=1
 data_dir=
-test_subset=(test)
+test_subset=(valid test)
 exp_name=
 if [ "$#" -eq 1 ]; then
     exp_name=$1
 fi
-sacrebleu=1
-n_average=10
+sacrebleu=0
+n_average=5
 beam_size=5
 len_penalty=1.0
 max_tokens=80000
......
@@ -3,14 +3,16 @@ arch: s2t_dual
 asr-encoder: pds
 mt-encoder-layers: 30
+encoder-drop-net: True
+encoder-drop-net-prob: 0.8
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 1000
-lr: 5e-4
+lr: 1e-3
+#lr: 1e-5
 adam_betas: (0.9,0.98)
 criterion: join_speech_and_text_loss
@@ -56,9 +58,9 @@ pds-kernel-sizes: 5_5_5_5
 pds-ffn-ratios: 8_4_4_4
 pds-attn-heads: 4_6_6_8
-#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
-#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
-load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
-load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
-load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
+#load-pretrained-asr-encoder-from:
+#load-pretrained-mt-encoder-from:
+#load-pretrained-decoder-from:
@@ -6,7 +6,6 @@ lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 1000
 lr: 5e-4
-#lr: 1e-5
 adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
@@ -52,9 +51,7 @@ pds-kernel-sizes: 5_5_5_5
 pds-ffn-ratios: 8_4_4_4
 pds-attn-heads: 4_6_6_8
-#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
-#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
-load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
-load-pretrained-text-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
-load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
+#load-pretrained-encoder-from:
+#load-pretrained-acoustic-encoder-from:
+#load-pretrained-text-encoder-from:
+#load-pretrained-decoder-from:
\ No newline at end of file
@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then
     exp_name=$1
 fi
-sacrebleu=0
-n_average=1
+sacrebleu=1
+n_average=10
 beam_size=5
 len_penalty=1.0
 max_tokens=80000
......
@@ -14,13 +14,11 @@ extra_parameter=
 exp_tag=
 #config_list=(base)
-#config_list=(sate ctc)
-#config_list=(ctc conformer rpr)
-#config_list=(base sate)
-#config_list=(sate ctc)
+#config_list=(base ctc)
+#config_list=(pds_base_8 conformer)
 config_list=(sate_pds ctc)
-#config_list=(pds_base_8)
-#config_list=(pds_base_8 conformer)
 # exp full name
 exp_name=
......
-arch: s2t_ctc
-encoder-type: transformer
+arch: s2t_sate
+share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 0.0015
+weight-decay: 1e-6
+lr: 2e-3
 adam_betas: (0.9,0.98)
-criterion: ctc
-ctc-weight: 1.0
-subsampling-type: conv2d
+criterion: label_smoothed_cross_entropy_with_ctc
+label_smoothing: 0.1
+subsampling-type: conv1d
 subsampling-layers: 2
-subsampling-filter: 176
-subsampling-kernel: 3
+subsampling-filter: 1024
+subsampling-kernel: 5
 subsampling-stride: 2
-subsampling-norm: batch2d
-subsampling-activation: swish
+subsampling-norm: none
+subsampling-activation: glu
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 176
-encoder-ffn-embed-dim: 704
-encoder-layers: 16
+encoder-embed-dim: 256
+encoder-ffn-embed-dim: 2048
+encoder-layers: 12
+decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-encoder-activation-fn: swish
-encoder-attention-type: rel_pos
\ No newline at end of file
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
+attention-dropout: 0.1
+activation-dropout: 0.1
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
+#inter_mixup: True
+#inter_mixup_layer: -1
+#inter_mixup_ratio: 0.2
+ctc-weight: 0.2
+interleaved-ctc-weight: 0.1
+interleaved-ctc-layers: 6,9
+interleaved-temperature: 2
+#target-ctc-weight: 0.3
+#target-ctc-layer: 6
+target-interleaved-ctc-weight: 0.1
+target-interleaved-ctc-layers: 2,4
+sae-adapter: league
+share-ctc-and-sae: False
+sae-drop-prob: 0.2
+interleaved-ctc-drop-prob: 0.2
+sae-distribution-cutoff: 10
+ctc-self-distill-weight: 0
+post-process: sentencepiece
\ No newline at end of file
 #! /bin/bash
-# Processing LibriSpeech En-Fr Datasets
+# Processing LibriSpeech En-Fr ST Datasets
 # Copyright 2021 Natural Language Processing Laboratory
 # Xu Chen (xuchenneu@163.com)
......
@@ -13,11 +13,11 @@ extra_parameter=
 exp_tag=
-#config_list=(base)
-#config_list=(base conformer)
-#config_list=(pds_base_8)
-config_list=(pds_base_8 conformer rpr)
+#config_list=(base ctc)
+#config_list=(base ctc conformer)
+#config_list=(pds_base_8 ctc)
+config_list=(pds_base_8 conformer)
 # exp full name
 exp_name=
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
@@ -8,6 +8,7 @@ best-checkpoint-metric: loss
 maximize-best-checkpoint-metric: False
 post-process: sentencepiece
+validate-interval: 1
 no-epoch-checkpoints: True
 #keep-last-epochs: 10
 keep-best-checkpoints: 10
......
@@ -5,7 +5,7 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0014
 adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
......
@@ -38,11 +38,11 @@ post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-encoder-layers: 16
+encoder-layers: 12
 macaron-style: True
 use-cnn-module: True
-cnn-module-kernel: 15
+cnn-module-kernel: 31
 encoder-activation-fn: swish
 encoder-attention-type: rel_pos
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 5_5_5
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 15
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_32
 encoder-embed-dim: 256
 pds-stages: 5
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 2_2_3_3_2
 pds-ratios: 2_2_2_2_2
 pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_1
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
@@ -21,7 +21,7 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0014
 adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
......
@@ -2,7 +2,6 @@ arch: pdss2t_transformer_m_32
 encoder-embed-dim: 512
 pds-stages: 5
-#pds-dropout: 0
 pds-layers: 2_2_3_3_2
 pds-ratios: 2_2_2_2_2
 pds-fusion: True
@@ -21,7 +20,7 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0014
 adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
......
@@ -20,7 +20,7 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0014
 adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
@@ -12,7 +12,7 @@ encoder-type: pds
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 2_2_6_2
 pds-ratios: 2_2_2_2
 pds-fusion: True
@@ -41,4 +41,4 @@ activation-fn: relu
 encoder-layers: 12
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 320
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
@@ -12,7 +12,7 @@ encoder-type: pds
 encoder-embed-dim: 256
 pds-stages: 4
-ctc-layer: 12
+#ctc-layer: 12
 pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
@@ -48,4 +48,4 @@ decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
@@ -3,7 +3,7 @@
 gpu_num=1
 data_dir=
-test_subset=(dev-clean dev-other test-clean test-other)
+test_subset=(dev-clean dev-other test-clean test-other all)
 exp_name=
 if [ "$#" -eq 1 ]; then
@@ -13,7 +13,7 @@ fi
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=80000
+max_tokens=100000
 dec_model=checkpoint_best.pt
 cmd="./run.sh
......
@@ -55,7 +55,7 @@ exp_tag=baseline
 exp_name=
 # config
-train_config=ctc
+train_config=base
 data_config=config.yaml
 # training setting
@@ -190,32 +190,19 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
                 --fp16"
     fi
     if [[ $step_valid -eq 1 ]]; then
-        validate_interval=1
         save_interval=1
-        keep_last_epochs=10
         no_epoch_checkpoints=0
        save_interval_updates=500
        keep_interval_updates=10
-    else
-        validate_interval=1
-        keep_last_epochs=10
     fi
     if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
         cmd="$cmd
                 --no-epoch-checkpoints"
     fi
-    if [[ -n $validate_interval ]]; then
-        cmd="${cmd}
-                --validate-interval $validate_interval "
-    fi
     if [[ -n $save_interval ]]; then
         cmd="${cmd}
                 --save-interval $save_interval "
     fi
-    if [[ -n $keep_last_epochs ]]; then
-        cmd="${cmd}
-                --keep-last-epochs $keep_last_epochs "
-    fi
     if [[ -n $save_interval_updates ]]; then
         cmd="${cmd}
                 --save-interval-updates $save_interval_updates"
@@ -271,10 +258,19 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
     export CUDA_VISIBLE_DEVICES=${device}
-    result_file=${model_dir}/decode_result
+    suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
+    if [[ -n ${cer} && ${cer} -eq 1 ]]; then
+        suffix=${suffix}_cer
+    else
+        suffix=${suffix}_wer
+    fi
+    if [[ ${n_average} -ne 1 ]]; then
+        suffix=${suffix}_${n_average}
+    fi
+    result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
-    test_subset=(${test_subset//,/ })
+    test_subset=${test_subset//,/ }
     for subset in ${test_subset[@]}; do
         subset=${subset}
         cmd="python ${code_dir}/fairseq_cli/generate.py
@@ -293,6 +289,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
+            mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
+            mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
         fi
     done
     cat ${result_file}
......
@@ -12,14 +12,45 @@ extra_parameter=
 #extra_parameter="${extra_parameter} "
 exp_tag=
+# Transformer
 #config_list=(base)
+#config_list=(pds_base_16)
+#config_list=(pds_base_8)
+# CTC
+#config_list=(purectc_base)
+#config_list=(purectc_pds_base_8)
+#config_list=(purectc_pds_base_8_growth)
+#config_list=(purectc_pds_base_8_growth_fusion256)
+#config_list=(purectc_pds_base_16)
+#config_list=(purectc_pds_base_16_growth)
+#config_list=(purectc_pds_base_16_growth_fusion256)
+#config_list=(purectc_pds_base_16_growth_fusion320)
+# conformer
 #config_list=(base conformer)
-#config_list=(ConformerCTCSmall)
+#config_list=(big conformer)
+#config_list=(pds_base_4 conformer)
+#config_list=(pds_base_16 conformer)
+config_list=(pds_base_32 conformer)
+#config_list=(pds_big_8 conformer)
+#config_list=(pds_big_16 conformer)
+#config_list=(pds_big_32 conformer)
+#config_list=(pds_base_8_growth_fusion256 conformer)
+# growth validation
+#config_list=(pds_base_8_growth)
+#config_list=(pds_base_8_growth_fusion256)
+#config_list=(pds_base_16_growth_fusion256)
+#config_list=(pds_base_16_growth)
-config_list=(purectc_pds_base_16)
-#config_list=(pds_base)
-#config_list=(pds_big)
-#config_list=(pds_deep)
+# compare with Effective
+#config_list=(purectc_base_compare)
+#config_list=(purectc_pds_base_8_compare)
+#config_list=(purectc_pds_base_8_compare2)
+#config_list=(EffecientConformerCTCSmall)
+#config_list=(purectc_pds_base_16)
 # exp full name
 exp_name=
......
@@ -16,4 +16,5 @@ no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
-skip-invalid-size-inputs-valid-test: True
\ No newline at end of file
+skip-invalid-size-inputs-valid-test: True
+post-process: sentencepiece
\ No newline at end of file
 ctc-weight: 0.2
-intermedia-ctc-layers: 6,9
-intermedia-adapter: league
-intermedia-ctc-weight: 0.1
+interleaved-ctc-weight: 0.1
+interleaved-ctc-layers: 6,9
+interleaved-ctc-temperature: 1.0
+interleaved-ctc-drop-prob: 0
+sae-adapter: league
+sae-drop-prob: 0.2
+sae-distribution-cutoff: 10
+share-ctc-and-sae: False
 ctc-self-distill-weight: 0
-post-process: sentencepiece
\ No newline at end of file
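This config renames the old intermedia-* options to interleaved-* and exposes the sae-* adapter knobs directly. As a rough illustration of how such weights typically combine (a sketch only, under the assumption that each enabled CTC term contributes additively with its configured weight, with the cross entropy scaled by (1 - ctc-weight) as in the criterion changed later in this commit; this is not the repository's exact code):

    def combined_loss(ce_loss, ctc_loss, interleaved_ctc_losses,
                      ctc_weight=0.2, interleaved_ctc_weight=0.1):
        # ce_loss: label-smoothed cross entropy from the decoder
        # ctc_loss: CTC loss at the top of the encoder
        # interleaved_ctc_losses: one CTC loss per layer listed in
        #   interleaved-ctc-layers (layers 6 and 9 here)
        loss = (1 - ctc_weight) * ce_loss + ctc_weight * ctc_loss
        for layer_loss in interleaved_ctc_losses:
            loss += interleaved_ctc_weight * layer_loss
        return loss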
 inter_mixup: True
-inter_mixup_layer: 0
+inter_mixup_layer: -1
+inter_mixup_prob: 1.0
 inter_mixup_ratio: 0.2
\ No newline at end of file
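inter_mixup_prob is new here, and layer -1 moves the mixup to the input features. A hedged sketch of inter-sequence mixup as these options suggest (the real implementation lives in the encoder; the function below is illustrative, not the repository's code):

    import torch

    def inter_mixup(x, ratio=0.2, prob=1.0):
        # x: (batch, time, dim) hidden states at the configured layer.
        # With probability `prob`, mix each example with another example
        # from the same batch using a fixed interpolation ratio.
        if torch.rand(1).item() >= prob:
            return x
        perm = torch.randperm(x.size(0))
        return (1 - ratio) * x + ratio * x[perm]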
@@ -240,13 +240,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     if [[ $step_valid -eq 1 ]]; then
         validate_interval=1
         save_interval=1
-        keep_last_epochs=10
         no_epoch_checkpoints=0
         save_interval_updates=500
         keep_interval_updates=10
-    else
-        validate_interval=1
-        keep_last_epochs=10
     fi
     if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
         cmd="$cmd
@@ -260,10 +256,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         cmd="${cmd}
                 --save-interval $save_interval "
     fi
-    if [[ -n $keep_last_epochs ]]; then
-        cmd="${cmd}
-                --keep-last-epochs $keep_last_epochs "
-    fi
     if [[ -n $save_interval_updates ]]; then
         cmd="${cmd}
                 --save-interval-updates $save_interval_updates"
@@ -282,11 +274,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         mv tmp.log $log
     export CUDA_VISIBLE_DEVICES=${device}
-    cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
+    log=${model_dir}/train.log
+    cmd="nohup ${cmd} >> ${log} 2>&1 &"
     if [[ $eval -eq 1 ]]; then
         eval $cmd
         sleep 2s
-        tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
+        tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
     fi
 fi
 wait
@@ -319,7 +312,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
     export CUDA_VISIBLE_DEVICES=${device}
-    result_file=${model_dir}/decode_result
+    suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
+    if [[ -n ${cer} && ${cer} -eq 1 ]]; then
+        suffix=${suffix}_cer
+    else
+        suffix=${suffix}_wer
+    fi
+    if [[ ${n_average} -ne 1 ]]; then
+        suffix=${suffix}_${n_average}
+    fi
+    result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
     test_subset=${test_subset//,/ }
@@ -351,8 +353,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
+            mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
+            mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
         fi
     done
     cat ${result_file}
 fi
 train-subset: train
-valid-subset: dev
+valid-subset: valid
 max-epoch: 50
 max-update: 100000
......
@@ -351,10 +351,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
     export CUDA_VISIBLE_DEVICES=${device}
-    result_file=${model_dir}/decode_result
+    suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
+    if [[ ${n_average} -ne 1 ]]; then
+        suffix=${suffix}_${n_average}
+    fi
+    result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
-    test_subset=(${test_subset//,/ })
+    test_subset=${test_subset//,/ }
     for subset in ${test_subset[@]}; do
         cmd="python ${code_dir}/fairseq_cli/generate.py
         ${data_dir}
@@ -385,6 +389,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         if [[ $eval -eq 1 ]]; then
            eval $cmd
            tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
+           mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
+           mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
         fi
     done
     cat ${result_file}
......
@@ -16,4 +16,5 @@ no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
-skip-invalid-size-inputs-valid-test: True
\ No newline at end of file
+skip-invalid-size-inputs-valid-test: True
+post-process: sentencepiece
\ No newline at end of file
@@ -45,6 +45,6 @@ decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
 #load-pretrained-encoder-from:
-#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
-#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
-#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
+#load-pretrained-asr-encoder-from:
+#load-pretrained-mt-encoder-from:
+#load-pretrained-decoder-from:
\ No newline at end of file
 ctc-weight: 0.2
-intermedia-ctc-weight: 0.1
-intermedia-ctc-layers: 6,9
+interleaved-ctc-weight: 0.1
+interleaved-ctc-layers: 6,9
+interleaved-ctc-temperature: 1.0
+interleaved-ctc-drop-prob: 0
 #target-ctc-weight: 0.3
 #target-ctc-layer: 6
-#target-intermedia-ctc-weight: 0.1
-#target-intermedia-ctc-layers: 2,4
-intermedia-adapter: league
-#intermedia-drop-prob: 0.2
-#intermedia-temperature: 5
+target-interleaved-ctc-weight: 0.1
+target-interleaved-ctc-layers: 2,4
+sae-adapter: league
+sae-drop-prob: 0.0
+#sae-distribution-cutoff: 10
+share-ctc-and-sae: False
+share-target-ctc-and-sae: False
 ctc-self-distill-weight: 0
-post-process: sentencepiece
\ No newline at end of file
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
@@ -369,7 +369,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
     export CUDA_VISIBLE_DEVICES=${device}
-    result_file=${model_dir}/decode_result
+    suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
+    if [[ ${n_average} -ne 1 ]]; then
+        suffix=${suffix}_${n_average}
+    fi
+    if [[ ${sacrebleu} -eq 1 ]]; then
+        suffix=${suffix}_sacrebleu
+    else
+        suffix=${suffix}_multibleu
+    fi
+    result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
     test_subset=${test_subset//,/ }
@@ -402,6 +411,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         if [[ $eval -eq 1 ]]; then
             eval $cmd
             tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
+            mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
+            mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
         fi
     done
     cat ${result_file}
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
@@ -14,7 +14,7 @@ sacrebleu=0
 n_average=5
 beam_size=4
 len_penalty=0.6
-max_tokens=80000
+max_tokens=4000
 dec_model=checkpoint_best.pt
 cmd="./run.sh
......
@@ -12,6 +12,7 @@ extra_parameter=
 #extra_parameter="${extra_parameter} "
 exp_tag=baseline
+config_list=(base)
 config_list=(deep)
 # exp full name
......
@@ -484,7 +484,7 @@ def main():
         default="unigram",
         required=True,
         type=str,
-        choices=["bpe", "unigram", "char"],
+        choices=["word", "bpe", "unigram", "char"],
     ),
     parser.add_argument("--vocab-size", default=8000, type=int)
     parser.add_argument("--share", action="store_true",
......
@@ -296,7 +296,8 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
     if arg_overrides is not None:
         overwrite_args_by_name(state["cfg"], arg_overrides)
-    state = _upgrade_state_dict(state)
+    if len(state.keys()) != 1:
+        state = _upgrade_state_dict(state)
     return state
......
@@ -89,6 +89,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         if self.ctc_criterion.all_ctc_weight > 0:
             ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
             loss = (1 - self.ctc_weight) * loss + ctc_loss
+        # if hasattr(model.encoder, "get_loss"):
+        #     encoder_loss = model.encoder.get_loss()
+        #     if encoder_loss != 0:
+        #         loss += encoder_loss * sample_size
+        #         logging_output["encoder_loss"] = utils.item(encoder_loss.data)
         logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
         return loss, sample_size, logging_output
@@ -103,6 +109,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         nll_loss_sum = utils.item(
             sum(log.get("nll_loss", 0) for log in logging_outputs)
         )
+        enc_loss_sum = utils.item(
+            sum(log.get("encoder_loss", 0) for log in logging_outputs)
+        )
         ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -121,6 +130,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         metrics.log_derived(
             "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
         )
+        if enc_loss_sum != 0:
+            metrics.log_scalar("enc_loss", enc_loss_sum, sample_size, round=3)
         if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]:
             CtcCriterion.reduce_metrics(logging_outputs)
......
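For reference, the interpolation this criterion performs reduces to the following (ctc_loss here is the already-weighted sum returned by compute_ctc_loss; the commented-out block above would add an optional encoder-internal loss on top):

    def combine(ce_loss, ctc_loss, ctc_weight):
        # matches: loss = (1 - self.ctc_weight) * loss + ctc_loss
        return (1 - ctc_weight) * ce_loss + ctc_loss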
@@ -286,9 +286,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
             help="freeze the module of the decoder",
         )
-        parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
-                            help='interleaved dropout probability')
         parser.add_argument(
             "--squeeze-excitation",
             default=False,
......
@@ -2,21 +2,23 @@ import torch
 from torch import nn
 from fairseq.modules.activations import get_activation_class
+from fairseq.modules.layer_norm import LayerNorm
 class ConvolutionModule(nn.Module):
     """Convolution block used in the conformer block"""
     def __init__(
         self,
         embed_dim,
         expand_embed_dim,
         depthwise_kernel_size,
         dropout,
         activation_fn="swish",
         bias=False,
         stride=1,
         export=False,
+        norm_type="batch_norm"
     ):
         """
         Args:
@@ -30,8 +32,8 @@ class ConvolutionModule(nn.Module):
         """
         super(ConvolutionModule, self).__init__()
         assert (
             depthwise_kernel_size - 1
         ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
         self.pointwise_conv1 = torch.nn.Conv1d(
             embed_dim,
             2 * expand_embed_dim,
@@ -50,7 +52,13 @@ class ConvolutionModule(nn.Module):
             groups=expand_embed_dim,
             bias=bias,
         )
-        self.batch_norm = nn.BatchNorm1d(expand_embed_dim)
+        self.norm_type = norm_type
+        if norm_type == "batch_norm":
+            self.norm = nn.BatchNorm1d(expand_embed_dim)
+        elif norm_type == "layer_norm":
+            self.norm = LayerNorm(expand_embed_dim)
+        else:
+            assert False, "Unsupported normalization type in convolution module"
         self.activation = get_activation_class(activation_fn)
         self.pointwise_conv2 = torch.nn.Conv1d(
             expand_embed_dim,
@@ -62,7 +70,7 @@ class ConvolutionModule(nn.Module):
         )
         self.dropout = torch.nn.Dropout(dropout)
-    def forward(self, x):
+    def forward(self, x, mask_pad=None):
         """
         Args:
             x: Input of shape B X T X C
@@ -72,23 +80,36 @@ class ConvolutionModule(nn.Module):
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)
+        # zero_mask_pad = mask_pad.unsqueeze(1)
+        # # mask batch padding
+        # if mask_pad is not None:
+        #     x.masked_fill_(zero_mask_pad, 0.0)
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*expand_embed_dim, dim)
         x = self.glu(x)  # (batch, expand_embed_dim, dim)
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.batch_norm(x)
+        if self.norm_type == "layer_norm":
+            x = x.transpose(1, 2)
+        x = self.norm(x)
         x = self.activation(x)
+        if self.norm_type == "layer_norm":
+            x = x.transpose(1, 2)
         x = self.pointwise_conv2(x)
+        # # mask batch padding
+        # if zero_mask_pad is not None:
+        #     x.masked_fill_(zero_mask_pad, 0.0)
         x = x.transpose(1, 2)
         x = self.dropout(x)
         return x
 # class ConvolutionModule(nn.Module):
 #     """ConvolutionModule in Conformer model."""
 #     def __init__(self,
......
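The norm choice matters because BatchNorm1d expects (batch, channels, time) while LayerNorm normalizes the trailing dimension, hence the transposes around the layer-norm path. A standalone sketch of the pattern in plain torch, independent of fairseq:

    import torch
    from torch import nn

    class ConvNorm(nn.Module):
        # Minimal reproduction of the norm selection added above.
        def __init__(self, channels, norm_type="batch_norm"):
            super().__init__()
            self.norm_type = norm_type
            if norm_type == "batch_norm":
                self.norm = nn.BatchNorm1d(channels)
            elif norm_type == "layer_norm":
                self.norm = nn.LayerNorm(channels)
            else:
                raise ValueError("Unsupported normalization type")

        def forward(self, x):  # x: (batch, channels, time)
            if self.norm_type == "layer_norm":
                return self.norm(x.transpose(1, 2)).transpose(1, 2)
            return self.norm(x)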
@@ -332,7 +332,7 @@ class PDSTransformerEncoderLayer(nn.Module):
             if self.normalize_before:
                 x = self.conv_norm(x)
-            x = self.conv_module(x)
+            x = self.conv_module(x, encoder_padding_mask)
             x = x.transpose(0, 1)
             x = self.conv_res(residual) + x
......
@@ -122,7 +122,9 @@ class S2TTransformerEncoderLayer(nn.Module):
                 self.embed_dim,
                 depthwise_kernel_size=args.cnn_module_kernel,
                 dropout=args.dropout,
-                activation_fn=getattr(args, 'activation_fn', 'swish'))
+                activation_fn=getattr(args, 'activation_fn', 'swish'),
+                norm_type=args.cnn_module_norm
+            )
             self.final_norm = LayerNorm(embed_dim)
         else:
             self.conv_norm = None
......
@@ -3,6 +3,7 @@ import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from itertools import groupby
 from fairseq.data.data_utils import lengths_to_padding_mask
 from fairseq.modules import LayerNorm
@@ -61,9 +62,12 @@ class Adapter(nn.Module):
         super().__init__()
         dim = dim
         self.adapter_type = adapter_type
+        self.cal_linear = False
+        self.cal_context = False
         if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
+            self.cal_linear = True
             self.linear_adapter = nn.Sequential(
                 nn.Linear(dim, dim),
                 LayerNorm(dim),
@@ -71,14 +75,10 @@ class Adapter(nn.Module):
             )
         if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
+            self.cal_context = True
             self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False)  # reverse for initialization
             if embed_tokens is not None:
                 self.embed_adapter.weight = embed_tokens.weight
-        # if embed_tokens is None:
-        #     num_embeddings = len(dictionary)
-        #     self.embed_adapter = nn.Linear(num_embeddings, dim)  # Embedding(num_embeddings, dim, dictionary.pad())
-        # else:
-        #     self.embed_adapter = embed_tokens
         if self.adapter_type == "gated_league":
             self.gate_linear = nn.Linear(2 * dim, dim)
@@ -86,14 +86,22 @@ class Adapter(nn.Module):
             self.gate_linear1 = nn.Linear(dim, dim)
             self.gate_linear2 = nn.Linear(dim, dim)
+        # additional strategy
         if self.adapter_type == "shrink":
             assert strategy is not None
-            self.ctc_compress = getattr(CTCCompressStrategy, strategy)
-            logger.info("CTC Compress Strategy: %s" % strategy)
-        elif self.adapter_type == "league":
-            self.distribution_cutoff = strategy
+            ctc_compress_strategy = getattr(strategy, "ctc_compress_strategy", "avg")
+            self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
+            logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
+        if "league" in self.adapter_type:
+            self.distribution_cutoff = strategy.get("distribution_cutoff", None)
             if self.distribution_cutoff is not None:
-                logger.info("Distribution cutoff: %d" % int(strategy))
+                self.distribution_cutoff = int(self.distribution_cutoff)
+                logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
+            self.drop_prob = strategy.get("drop_prob", 0)
+            if self.drop_prob != 0:
+                logger.info("Adapter drop probability: %f" % self.drop_prob)
     def forward(self, x, padding=None):
@@ -103,14 +111,11 @@ class Adapter(nn.Module):
         org_distribution = distribution
         distribution = distribution.contiguous().view(-1, distribution.size(-1))
-        if self.adapter_type == "linear":
-            out = self.linear_adapter(representation)
-        elif self.adapter_type == "context":
-            out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
-        elif self.adapter_type == "league":
-            linear_out = self.linear_adapter(representation)
+        linear_out = None
+        soft_out = None
+        if self.cal_linear:
+            linear_out = self.linear_adapter(representation)
+        if self.cal_context:
             if self.distribution_cutoff is not None:
                 cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
                 threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
@@ -120,24 +125,33 @@ class Adapter(nn.Module):
                 distribution = distribution.view(-1, distribution.size(-1))
             soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
-            out = linear_out + soft_out
-        elif self.adapter_type == "gated_league":
-            linear_out = self.linear_adapter(representation)
-            soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
+        if self.adapter_type == "linear":
+            out = linear_out
+        elif self.adapter_type == "context":
+            out = soft_out
+        elif self.adapter_type == "league":
+            if self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
+                if torch.rand(1).uniform_() < 0.5:
+                    out = linear_out
+                else:
+                    out = soft_out
+            else:
+                out = linear_out + soft_out
+        elif self.adapter_type == "gated_league":
             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
             out = coef * linear_out + (1 - coef) * soft_out
         elif self.adapter_type == "inter_league":
-            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
             out = representation + soft_out
         elif self.adapter_type == "none":
             out = representation
         elif self.adapter_type == "shrink":
-            from itertools import groupby
             lengths = (~padding).long().sum(-1)
             with torch.no_grad():
......
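The refactored Adapter now computes the linear and context ("soft") branches once and combines them per adapter type, and the new drop_prob randomly keeps only one branch during league training. A self-contained sketch of just the league path (plain torch; the class name and constructor arguments are illustrative):

    import torch
    from torch import nn

    class LeagueAdapter(nn.Module):
        def __init__(self, dim, vocab_size, drop_prob=0.0):
            super().__init__()
            self.linear_adapter = nn.Sequential(
                nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU())
            self.embed_adapter = nn.Linear(dim, vocab_size, bias=False)
            self.drop_prob = drop_prob

        def forward(self, x, ctc_distribution):
            # x: (seq_len, batch, dim); ctc_distribution: (seq_len, batch, vocab)
            seq_len, bsz, _ = x.size()
            linear_out = self.linear_adapter(x)
            flat = ctc_distribution.contiguous().view(-1, ctc_distribution.size(-1))
            # re-embed the CTC distribution through the (possibly tied) embedding
            soft_out = torch.mm(flat, self.embed_adapter.weight).view(seq_len, bsz, -1)
            if self.training and self.drop_prob > 0 and torch.rand(1).item() < self.drop_prob:
                # branch drop: keep exactly one of the two branches
                return linear_out if torch.rand(1).item() < 0.5 else soft_out
            return linear_out + soft_out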
@@ -13,16 +13,14 @@ logger = logging.getLogger(__name__)
 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
         super(CTC, self).__init__()
         self.embed_dim = embed_dim
-        self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)
-        nn.init.normal_(
-            self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5
-        )
+        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
+        # nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
         self.ctc_dropout_module = FairseqDropout(
             p=dropout, module_name=self.__class__.__name__
@@ -46,4 +44,3 @@ class CTC(nn.Module):
     def argmax(self, x):
         return torch.argmax(self.ctc_projection(x), dim=-1)
@@ -191,7 +191,8 @@ class Conv2dSubsampling(nn.Module):
                     filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
                     kernel_size,
                     stride=stride,
-                    padding=(kernel_size - 1) // 2),
+                    # padding=(kernel_size - 1) // 2
+                ),
                 get_norm(norm,
                          filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
                          transpose=True if norm == "layer" else False),
@@ -214,6 +215,8 @@ class Conv2dSubsampling(nn.Module):
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
+        assert subsampled_length == max(x_len), "The lengths are mismatched."
         x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
         x = self.linear(x)
......
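The new assert checks that the convolution's output length agrees with the subsampled lengths computed from x_len. For each conv layer the standard output-length arithmetic is as below (a sketch; with the padding argument commented out above, padding defaults to 0):

    def conv_out_len(length, kernel_size, stride, padding=0):
        # standard convolution output-length formula
        return (length + 2 * padding - kernel_size) // stride + 1

    # two layers, kernel 5, stride 2, no padding:
    # conv_out_len(conv_out_len(100, 5, 2), 5, 2) -> 22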
@@ -156,22 +156,22 @@ def main(cfg: FairseqConfig) -> None:
         )
         break
-    if getattr(cfg.model, "cl_dropout", False):
-        cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
-        cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
-        dropout = getattr(cfg.model, "dropout", False)
-        assert cl_dropout_epoch > 0
-        curr_epoch = epoch_itr.epoch
-        if curr_epoch <= cl_dropout_epoch:
-            if curr_epoch == cl_dropout_epoch:
-                curr_dropout = dropout
-            else:
-                curr_dropout = curr_epoch / cl_dropout_epoch * dropout
-            logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))
-            for name, module in trainer.model.named_modules():
-                from fairseq.modules.fairseq_dropout import FairseqDropout
-                if isinstance(module, FairseqDropout):
-                    module.p = curr_dropout
+    # if getattr(cfg.model, "cl_dropout", False):
+    #     cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
+    #     cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
+    #     dropout = getattr(cfg.model, "dropout", False)
+    #     assert cl_dropout_epoch > 0
+    #     curr_epoch = epoch_itr.epoch
+    #     if curr_epoch <= cl_dropout_epoch:
+    #         if curr_epoch == cl_dropout_epoch:
+    #             curr_dropout = dropout
+    #         else:
+    #             curr_dropout = curr_epoch / cl_dropout_epoch * dropout
+    #         logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))
+    #         for name, module in trainer.model.named_modules():
+    #             from fairseq.modules.fairseq_dropout import FairseqDropout
+    #             if isinstance(module, FairseqDropout):
+    #                 module.p = curr_dropout
     # train for one epoch
     valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
......