Commit 1d60b3a6 by xuchen

big update! I integrate the latest updates of shell scripts, optimize the…

big update! I integrate the latest updates of shell scripts, optimize the implementation of sae and fix some bugs.
parent 1288e535
...@@ -13,7 +13,7 @@ label_smoothing: 0.1 ...@@ -13,7 +13,7 @@ label_smoothing: 0.1
subsampling-type: conv1d subsampling-type: conv1d
subsampling-layers: 2 subsampling-layers: 2
subsampling-filter: 2048 subsampling-filter: 1024
subsampling-kernel: 5 subsampling-kernel: 5
subsampling-stride: 2 subsampling-stride: 2
subsampling-norm: none subsampling-norm: none
......
arch: s2t_transformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 512
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: relu
dropout: 0.15
activation-fn: relu
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 15
encoder-attention-type: rel_pos encoder-attention-type: rel_pos
encoder-activation-fn: swish encoder-activation-fn: swish
\ No newline at end of file
...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8 ...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256 encoder-embed-dim: 256
pds-stages: 4 pds-stages: 4
ctc-layer: 12 #ctc-layer: 12
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
pds-fusion: True pds-fusion: True
......
#! /bin/bash #! /bin/bash
# Processing ASR Datasets # Processing aishell ASR Datasets
# Copyright 2021 Natural Language Processing Laboratory # Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com) # Xu Chen (xuchenneu@163.com)
...@@ -323,7 +323,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -323,7 +323,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -n ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ } test_subset=${test_subset//,/ }
...@@ -352,6 +361,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -352,6 +361,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file} tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi fi
done done
cat ${result_file} cat ${result_file}
......
...@@ -13,12 +13,12 @@ extra_parameter= ...@@ -13,12 +13,12 @@ extra_parameter=
exp_tag= exp_tag=
#config_list=(base) #config_list=(base ctc)
#config_list=(ctc) config_list=(base ctc conformer)
#config_list=(base conformer) config_list=(big ctc conformer)
#config_list=(pds_base_16) #config_list=(pds_base_16)
config_list=(pds_base_16 conformer rpr) config_list=(pds_base_16 conformer)
# exp full name # exp full name
exp_name= exp_name=
......
...@@ -14,7 +14,7 @@ sacrebleu=0 ...@@ -14,7 +14,7 @@ sacrebleu=0
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=80000 max_tokens=20000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
......
...@@ -25,6 +25,7 @@ pds-kernel-sizes: 5_5_5_5 ...@@ -25,6 +25,7 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4 pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8 pds-attn-heads: 4_6_6_8
fp16-scale-tolerance: 0.25
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
......
...@@ -18,7 +18,7 @@ exp_tag= ...@@ -18,7 +18,7 @@ exp_tag=
#config_list=(base conformer) #config_list=(base conformer)
config_list=(pds_base_8 ctc) config_list=(pds_base_8 ctc)
#config_list=(pds_base_16 conformer rpr) #config_list=(pds_base_16 conformer)
# exp full name # exp full name
exp_name= exp_name=
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
gpu_num=1 gpu_num=1
data_dir= data_dir=
test_subset=(test) test_subset=(valid test)
exp_name= exp_name=
if [ "$#" -eq 1 ]; then if [ "$#" -eq 1 ]; then
exp_name=$1 exp_name=$1
fi fi
sacrebleu=1 sacrebleu=0
n_average=10 n_average=5
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=80000 max_tokens=80000
......
...@@ -3,14 +3,16 @@ arch: s2t_dual ...@@ -3,14 +3,16 @@ arch: s2t_dual
asr-encoder: pds asr-encoder: pds
mt-encoder-layers: 30 mt-encoder-layers: 30
encoder-drop-net: True
encoder-drop-net-prob: 0.8
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 1000 warmup-updates: 1000
lr: 5e-4 lr: 1e-3
#lr: 1e-5
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss criterion: join_speech_and_text_loss
...@@ -56,9 +58,9 @@ pds-kernel-sizes: 5_5_5_5 ...@@ -56,9 +58,9 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4 pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8 pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt #load-pretrained-encoder-from:
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt #load-pretrained-decoder-from:
load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt #load-pretrained-asr-encoder-from:
load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt #load-pretrained-mt-encoder-from:
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt #load-pretrained-decoder-from:
...@@ -6,7 +6,6 @@ lr-scheduler: inverse_sqrt ...@@ -6,7 +6,6 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 1000 warmup-updates: 1000
lr: 5e-4 lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
...@@ -52,9 +51,7 @@ pds-kernel-sizes: 5_5_5_5 ...@@ -52,9 +51,7 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4 pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8 pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt #load-pretrained-encoder-from:
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt #load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt #load-pretrained-decoder-from:
load-pretrained-text-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt \ No newline at end of file
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
...@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then ...@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then
exp_name=$1 exp_name=$1
fi fi
sacrebleu=0 sacrebleu=1
n_average=1 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=80000 max_tokens=80000
......
...@@ -14,13 +14,11 @@ extra_parameter= ...@@ -14,13 +14,11 @@ extra_parameter=
exp_tag= exp_tag=
#config_list=(base) #config_list=(base)
#config_list=(sate ctc) #config_list=(base ctc)
#config_list=(ctc conformer rpr) #config_list=(pds_base_8 conformer)
#config_list=(base sate)
#config_list=(sate ctc)
config_list=(sate_pds ctc) config_list=(sate_pds ctc)
#config_list=(pds_base_8)
#config_list=(pds_base_8 conformer)
# exp full name # exp full name
exp_name= exp_name=
......
arch: s2t_ctc arch: s2t_sate
encoder-type: transformer share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 0.0015 weight-decay: 1e-6
lr: 2e-3
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 1.0 label_smoothing: 0.1
subsampling-type: conv2d subsampling-type: conv1d
subsampling-layers: 2 subsampling-layers: 2
subsampling-filter: 176 subsampling-filter: 1024
subsampling-kernel: 3 subsampling-kernel: 5
subsampling-stride: 2 subsampling-stride: 2
subsampling-norm: batch2d subsampling-norm: none
subsampling-activation: swish subsampling-activation: glu
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 176 encoder-embed-dim: 256
encoder-ffn-embed-dim: 704 encoder-ffn-embed-dim: 2048
encoder-layers: 16 encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4 encoder-attention-heads: 4
macaron-style: True decoder-embed-dim: 256
use-cnn-module: True decoder-ffn-embed-dim: 2048
cnn-module-kernel: 31 decoder-attention-heads: 4
encoder-activation-fn: swish attention-dropout: 0.1
encoder-attention-type: rel_pos activation-dropout: 0.1
\ No newline at end of file
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#inter_mixup: True
#inter_mixup_layer: -1
#inter_mixup_ratio: 0.2
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-temperature: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
target-interleaved-ctc-weight: 0.1
target-interleaved-ctc-layers: 2,4
sae-adapter: league
share-ctc-and-sae: False
sae-drop-prob: 0.2
interleaved-ctc-drop-prob: 0.2
sae-distribution-cutoff: 10
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
#! /bin/bash #! /bin/bash
# Processing LibriSpeech En-Fr Datasets # Processing LibriSpeech En-Fr ST Datasets
# Copyright 2021 Natural Language Processing Laboratory # Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com) # Xu Chen (xuchenneu@163.com)
......
...@@ -13,11 +13,11 @@ extra_parameter= ...@@ -13,11 +13,11 @@ extra_parameter=
exp_tag= exp_tag=
#config_list=(base) #config_list=(base ctc)
#config_list=(base conformer) #config_list=(base ctc conformer)
#config_list=(pds_base_8) #config_list=(pds_base_8 ctc)
config_list=(pds_base_8 conformer rpr) config_list=(pds_base_8 conformer)
# exp full name # exp full name
exp_name= exp_name=
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -8,6 +8,7 @@ best-checkpoint-metric: loss ...@@ -8,6 +8,7 @@ best-checkpoint-metric: loss
maximize-best-checkpoint-metric: False maximize-best-checkpoint-metric: False
post-process: sentencepiece post-process: sentencepiece
validate-interval: 1
no-epoch-checkpoints: True no-epoch-checkpoints: True
#keep-last-epochs: 10 #keep-last-epochs: 10
keep-best-checkpoints: 10 keep-best-checkpoints: 10
......
...@@ -5,7 +5,7 @@ clip-norm: 10.0 ...@@ -5,7 +5,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0014
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
......
...@@ -38,11 +38,11 @@ post-process: sentencepiece ...@@ -38,11 +38,11 @@ post-process: sentencepiece
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-layers: 16 encoder-layers: 12
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 15 cnn-module-kernel: 31
encoder-activation-fn: swish encoder-activation-fn: swish
encoder-attention-type: rel_pos encoder-attention-type: rel_pos
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 5_5_5
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 15
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_32 ...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_32
encoder-embed-dim: 256 encoder-embed-dim: 256
pds-stages: 5 pds-stages: 5
ctc-layer: 12 #ctc-layer: 12
pds-layers: 2_2_3_3_2 pds-layers: 2_2_3_3_2
pds-ratios: 2_2_2_2_2 pds-ratios: 2_2_2_2_2
pds-fusion: True pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_1
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8 ...@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256 encoder-embed-dim: 256
pds-stages: 4 pds-stages: 4
ctc-layer: 12 #ctc-layer: 12
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
pds-fusion: True pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
...@@ -21,7 +21,7 @@ clip-norm: 10.0 ...@@ -21,7 +21,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0014
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
......
...@@ -2,7 +2,6 @@ arch: pdss2t_transformer_m_32 ...@@ -2,7 +2,6 @@ arch: pdss2t_transformer_m_32
encoder-embed-dim: 512 encoder-embed-dim: 512
pds-stages: 5 pds-stages: 5
#pds-dropout: 0
pds-layers: 2_2_3_3_2 pds-layers: 2_2_3_3_2
pds-ratios: 2_2_2_2_2 pds-ratios: 2_2_2_2_2
pds-fusion: True pds-fusion: True
...@@ -21,7 +20,7 @@ clip-norm: 10.0 ...@@ -21,7 +20,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0014
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
......
...@@ -20,7 +20,7 @@ clip-norm: 10.0 ...@@ -20,7 +20,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0014
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
...@@ -12,7 +12,7 @@ encoder-type: pds ...@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 256 encoder-embed-dim: 256
pds-stages: 4 pds-stages: 4
ctc-layer: 12 #ctc-layer: 12
pds-layers: 2_2_6_2 pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2 pds-ratios: 2_2_2_2
pds-fusion: True pds-fusion: True
......
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 320
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
...@@ -12,7 +12,7 @@ encoder-type: pds ...@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 256 encoder-embed-dim: 256
pds-stages: 4 pds-stages: 4
ctc-layer: 12 #ctc-layer: 12
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
pds-fusion: True pds-fusion: True
......
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
gpu_num=1 gpu_num=1
data_dir= data_dir=
test_subset=(dev-clean dev-other test-clean test-other) test_subset=(dev-clean dev-other test-clean test-other all)
exp_name= exp_name=
if [ "$#" -eq 1 ]; then if [ "$#" -eq 1 ]; then
...@@ -13,7 +13,7 @@ fi ...@@ -13,7 +13,7 @@ fi
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=80000 max_tokens=100000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
......
...@@ -55,7 +55,7 @@ exp_tag=baseline ...@@ -55,7 +55,7 @@ exp_tag=baseline
exp_name= exp_name=
# config # config
train_config=ctc train_config=base
data_config=config.yaml data_config=config.yaml
# training setting # training setting
...@@ -190,32 +190,19 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -190,32 +190,19 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--fp16" --fp16"
fi fi
if [[ $step_valid -eq 1 ]]; then if [[ $step_valid -eq 1 ]]; then
validate_interval=1
save_interval=1 save_interval=1
keep_last_epochs=10
no_epoch_checkpoints=0 no_epoch_checkpoints=0
save_interval_updates=500 save_interval_updates=500
keep_interval_updates=10 keep_interval_updates=10
else
validate_interval=1
keep_last_epochs=10
fi fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd cmd="$cmd
--no-epoch-checkpoints" --no-epoch-checkpoints"
fi fi
if [[ -n $validate_interval ]]; then
cmd="${cmd}
--validate-interval $validate_interval "
fi
if [[ -n $save_interval ]]; then if [[ -n $save_interval ]]; then
cmd="${cmd} cmd="${cmd}
--save-interval $save_interval " --save-interval $save_interval "
fi fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then if [[ -n $save_interval_updates ]]; then
cmd="${cmd} cmd="${cmd}
--save-interval-updates $save_interval_updates" --save-interval-updates $save_interval_updates"
...@@ -271,10 +258,19 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -271,10 +258,19 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -n ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ }) test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset} subset=${subset}
cmd="python ${code_dir}/fairseq_cli/generate.py cmd="python ${code_dir}/fairseq_cli/generate.py
...@@ -293,6 +289,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -293,6 +289,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file} tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi fi
done done
cat ${result_file} cat ${result_file}
......
...@@ -12,14 +12,45 @@ extra_parameter= ...@@ -12,14 +12,45 @@ extra_parameter=
#extra_parameter="${extra_parameter} " #extra_parameter="${extra_parameter} "
exp_tag= exp_tag=
# Transformer
#config_list=(base) #config_list=(base)
#config_list=(pds_base_16)
#config_list=(pds_base_8)
# CTC
#config_list=(purectc_base)
#config_list=(purectc_pds_base_8)
#config_list=(purectc_pds_base_8_growth)
#config_list=(purectc_pds_base_8_growth_fusion256)
#config_list=(purectc_pds_base_16)
#config_list=(purectc_pds_base_16_growth)
#config_list=(purectc_pds_base_16_growth_fusion256)
#config_list=(purectc_pds_base_16_growth_fusion320)
# conformer
#config_list=(base conformer) #config_list=(base conformer)
#config_list=(ConformerCTCSmall) #config_list=(big conformer)
#config_list=(pds_base_4 conformer)
#config_list=(pds_base_16 conformer)
config_list=(pds_base_32 conformer)
#config_list=(pds_big_8 conformer)
#config_list=(pds_big_16 conformer)
#config_list=(pds_big_32 conformer)
#config_list=(pds_base_8_growth_fusion256 conformer)
# growth validation
#config_list=(pds_base_8_growth)
#config_list=(pds_base_8_growth_fusion256)
#config_list=(pds_base_16_growth_fusion256)
#config_list=(pds_base_16_growth)
config_list=(purectc_pds_base_16) # compare with Effective
#config_list=(pds_base) #config_list=(purectc_base_compare)
#config_list=(pds_big) #config_list=(purectc_pds_base_8_compare)
#config_list=(pds_deep) #config_list=(purectc_pds_base_8_compare2)
#config_list=(EffecientConformerCTCSmall)
#config_list=(purectc_pds_base_16)
# exp full name # exp full name
exp_name= exp_name=
......
...@@ -17,3 +17,4 @@ log-interval: 100 ...@@ -17,3 +17,4 @@ log-interval: 100
seed: 1 seed: 1
report-accuracy: True report-accuracy: True
skip-invalid-size-inputs-valid-test: True skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.2 ctc-weight: 0.2
intermedia-ctc-layers: 6,9 interleaved-ctc-weight: 0.1
intermedia-adapter: league interleaved-ctc-layers: 6,9
intermedia-ctc-weight: 0.1 interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
sae-drop-prob: 0.2
sae-distribution-cutoff: 10
share-ctc-and-sae: False
ctc-self-distill-weight: 0 ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
inter_mixup: True inter_mixup: True
inter_mixup_layer: 0 inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2 inter_mixup_ratio: 0.2
\ No newline at end of file
...@@ -240,13 +240,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -240,13 +240,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ $step_valid -eq 1 ]]; then if [[ $step_valid -eq 1 ]]; then
validate_interval=1 validate_interval=1
save_interval=1 save_interval=1
keep_last_epochs=10
no_epoch_checkpoints=0 no_epoch_checkpoints=0
save_interval_updates=500 save_interval_updates=500
keep_interval_updates=10 keep_interval_updates=10
else
validate_interval=1
keep_last_epochs=10
fi fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd cmd="$cmd
...@@ -260,10 +256,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -260,10 +256,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cmd="${cmd} cmd="${cmd}
--save-interval $save_interval " --save-interval $save_interval "
fi fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then if [[ -n $save_interval_updates ]]; then
cmd="${cmd} cmd="${cmd}
--save-interval-updates $save_interval_updates" --save-interval-updates $save_interval_updates"
...@@ -282,11 +274,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -282,11 +274,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
mv tmp.log $log mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &" log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
sleep 2s sleep 2s
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi fi
fi fi
wait wait
...@@ -319,7 +312,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -319,7 +312,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -z ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ } test_subset=${test_subset//,/ }
...@@ -351,8 +353,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -351,8 +353,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file} tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi fi
done done
cat ${result_file} cat ${result_file}
fi fi
train-subset: train train-subset: train
valid-subset: dev valid-subset: valid
max-epoch: 50 max-epoch: 50
max-update: 100000 max-update: 100000
......
...@@ -351,10 +351,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -351,10 +351,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ }) test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
cmd="python ${code_dir}/fairseq_cli/generate.py cmd="python ${code_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
...@@ -385,6 +389,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -385,6 +389,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file} tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi fi
done done
cat ${result_file} cat ${result_file}
......
...@@ -17,3 +17,4 @@ log-interval: 100 ...@@ -17,3 +17,4 @@ log-interval: 100
seed: 1 seed: 1
report-accuracy: True report-accuracy: True
skip-invalid-size-inputs-valid-test: True skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
\ No newline at end of file
...@@ -45,6 +45,6 @@ decoder-ffn-embed-dim: 2048 ...@@ -45,6 +45,6 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt #load-pretrained-asr-encoder-from:
#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt #load-pretrained-mt-encoder-from:
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt #load-pretrained-decoder-from:
\ No newline at end of file \ No newline at end of file
ctc-weight: 0.2 ctc-weight: 0.2
intermedia-ctc-weight: 0.1 interleaved-ctc-weight: 0.1
intermedia-ctc-layers: 6,9 interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
#target-ctc-weight: 0.3 #target-ctc-weight: 0.3
#target-ctc-layer: 6 #target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1 target-interleaved-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4 target-interleaved-ctc-layers: 2,4
intermedia-adapter: league sae-adapter: league
#intermedia-drop-prob: 0.2 sae-drop-prob: 0.0
#intermedia-temperature: 5 #sae-distribution-cutoff: 10
share-ctc-and-sae: False
share-target-ctc-and-sae: False
ctc-self-distill-weight: 0 ctc-self-distill-weight: 0
\ No newline at end of file
post-process: sentencepiece
\ No newline at end of file
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
...@@ -369,7 +369,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -369,7 +369,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
if [[ ${sacrebleu} -eq 1 ]]; then
suffix=${suffix}_sacrebleu
else
suffix=${suffix}_multibleu
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ } test_subset=${test_subset//,/ }
...@@ -402,6 +411,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -402,6 +411,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file} tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi fi
done done
cat ${result_file} cat ${result_file}
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
...@@ -14,7 +14,7 @@ sacrebleu=0 ...@@ -14,7 +14,7 @@ sacrebleu=0
n_average=5 n_average=5
beam_size=4 beam_size=4
len_penalty=0.6 len_penalty=0.6
max_tokens=80000 max_tokens=4000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
......
...@@ -12,6 +12,7 @@ extra_parameter= ...@@ -12,6 +12,7 @@ extra_parameter=
#extra_parameter="${extra_parameter} " #extra_parameter="${extra_parameter} "
exp_tag=baseline exp_tag=baseline
config_list=(base)
config_list=(deep) config_list=(deep)
# exp full name # exp full name
......
...@@ -484,7 +484,7 @@ def main(): ...@@ -484,7 +484,7 @@ def main():
default="unigram", default="unigram",
required=True, required=True,
type=str, type=str,
choices=["bpe", "unigram", "char"], choices=["word", "bpe", "unigram", "char"],
), ),
parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--vocab-size", default=8000, type=int)
parser.add_argument("--share", action="store_true", parser.add_argument("--share", action="store_true",
......
...@@ -296,6 +296,7 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): ...@@ -296,6 +296,7 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
if arg_overrides is not None: if arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides) overwrite_args_by_name(state["cfg"], arg_overrides)
if len(state.keys()) != 1:
state = _upgrade_state_dict(state) state = _upgrade_state_dict(state)
return state return state
......
...@@ -43,17 +43,17 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -43,17 +43,17 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "weight of CTC entropy"}, metadata={"help": "weight of CTC entropy"},
) )
intermedia_ctc_weight: float = field( interleaved_ctc_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of intermedia CTC loss"}, metadata={"help": "weight of interleaved CTC loss"},
) )
target_ctc_weight: float = field( target_ctc_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of CTC loss for target sentence"}, metadata={"help": "weight of CTC loss for target sentence"},
) )
target_intermedia_ctc_weight: float = field( target_interleaved_ctc_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"}, metadata={"help": "weight of interleaved CTC loss for target sentence"},
) )
ctc_self_distill_weight: float = field( ctc_self_distill_weight: float = field(
default=0.0, default=0.0,
...@@ -127,13 +127,13 @@ class CtcCriterion(FairseqCriterion): ...@@ -127,13 +127,13 @@ class CtcCriterion(FairseqCriterion):
self.sentence_avg = cfg.sentence_avg self.sentence_avg = cfg.sentence_avg
self.ctc_weight = ctc_weight self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight self.target_ctc_weight = cfg.target_ctc_weight
self.target_intermedia_ctc_weight = cfg.target_intermedia_ctc_weight self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy self.ctc_entropy = cfg.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + \ self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_intermedia_ctc_weight + \ self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0: if self.all_ctc_weight > 0:
...@@ -188,6 +188,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -188,6 +188,7 @@ class CtcCriterion(FairseqCriterion):
pad_mask = (tokens != self.pad_idx) & ( pad_mask = (tokens != self.pad_idx) & (
tokens != self.eos_idx tokens != self.eos_idx
) )
if mixup: if mixup:
mask1 = pad_mask[mixup_idx1] mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2] mask2 = pad_mask[mixup_idx2]
...@@ -222,19 +223,20 @@ class CtcCriterion(FairseqCriterion): ...@@ -222,19 +223,20 @@ class CtcCriterion(FairseqCriterion):
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True) # ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
# cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100] # cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum() # ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum() ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data) logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data) logging_output["ctc_loss"] = utils.item(ctc_loss.data)
intermedia_ctc_num = 0 interleaved_ctc_num = 0
intermedia_ctc_loss = 0 interleaved_ctc_loss = 0
if "intermedia_ctc_logits" in net_output: if "interleaved_ctc_logits" in net_output:
intermedia_ctc_num = len(net_output["intermedia_ctc_logits"]) interleaved_ctc_num = len(net_output["interleaved_ctc_logits"])
# calculate the intermedia CTC loss # calculate the interleaved CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0: if self.interleaved_ctc_weight > 0 and interleaved_ctc_num > 0:
for i in range(intermedia_ctc_num): for i in range(interleaved_ctc_num):
out = net_output["intermedia_ctc_logits"][i] out = net_output["interleaved_ctc_logits"][i]
if type(out) == list: if type(out) == list:
inter_ctc_logit = out[0] inter_ctc_logit = out[0]
padding = ~out[1] padding = ~out[1]
...@@ -249,19 +251,19 @@ class CtcCriterion(FairseqCriterion): ...@@ -249,19 +251,19 @@ class CtcCriterion(FairseqCriterion):
inter_lprobs.batch_first = False inter_lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef): for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
intermedia_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef interleaved_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
intermedia_ctc_loss /= intermedia_ctc_num interleaved_ctc_loss /= interleaved_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data) logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data)
if lprobs is None: if lprobs is None:
lprobs = inter_lprobs lprobs = inter_lprobs
target_ctc_loss = 0 target_ctc_loss = 0
target_intermedia_ctc_loss = 0 target_interleaved_ctc_loss = 0
# calculate the target CTC loss # calculate the target CTC loss
if self.target_ctc_weight > 0 or self.target_intermedia_ctc_weight: if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight:
target = sample["target"] target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx) pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
...@@ -292,12 +294,12 @@ class CtcCriterion(FairseqCriterion): ...@@ -292,12 +294,12 @@ class CtcCriterion(FairseqCriterion):
for flat, lengths, coef in zip(target_flat, target_length, loss_coef): for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
target_intermedia_ctc_num = 0 target_interleaved_ctc_num = 0
if "target_intermedia_ctc_logits" in net_output: if "target_interleaved_ctc_logits" in net_output:
target_intermedia_ctc_num = len(net_output["target_intermedia_ctc_logits"]) target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
for i in range(target_intermedia_ctc_num): for i in range(target_interleaved_ctc_num):
out = net_output["target_intermedia_ctc_logits"][i] out = net_output["target_interleaved_ctc_logits"][i]
if type(out) == list: if type(out) == list:
inter_ctc_logit = out[0] inter_ctc_logit = out[0]
padding = ~out[1] padding = ~out[1]
...@@ -312,17 +314,17 @@ class CtcCriterion(FairseqCriterion): ...@@ -312,17 +314,17 @@ class CtcCriterion(FairseqCriterion):
tgt_inter_lprobs.batch_first = False tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef): for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_intermedia_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
target_intermedia_ctc_loss /= target_intermedia_ctc_num target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_intermedia_ctc_loss"] = utils.item(target_intermedia_ctc_loss.data) logging_output["target_interleaved_ctc_loss"] = utils.item(target_interleaved_ctc_loss.data)
# calculate the self distillation CTC loss # calculate the self distillation CTC loss
ctc_self_distill_loss = 0 ctc_self_distill_loss = 0
ctc_self_distill_num = 0 ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0: if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
for i in range(intermedia_ctc_num): for i in range(interleaved_ctc_num):
out = net_output["intermedia_ctc_logits"][i] out = net_output["interleaved_ctc_logits"][i]
if type(out) == list: if type(out) == list:
inter_ctc_logit = out[0] inter_ctc_logit = out[0]
padding = ~out[1] padding = ~out[1]
...@@ -347,9 +349,9 @@ class CtcCriterion(FairseqCriterion): ...@@ -347,9 +349,9 @@ class CtcCriterion(FairseqCriterion):
loss = \ loss = \
self.ctc_weight * ctc_loss + \ self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \ self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \ self.target_ctc_weight * target_ctc_loss + \
self.target_intermedia_ctc_weight * target_intermedia_ctc_loss + \ self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \ self.ctc_self_distill_weight * ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy self.ctc_entropy * ctc_entropy
...@@ -359,8 +361,8 @@ class CtcCriterion(FairseqCriterion): ...@@ -359,8 +361,8 @@ class CtcCriterion(FairseqCriterion):
logger.warning("Illegal loss %f!" % loss) logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0: if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss) logger.warning("CTC loss %f!" % ctc_loss)
if self.intermedia_ctc_weight != 0: if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % intermedia_ctc_loss) logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0: if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss) logger.warning("Target CTC loss %f!" % target_ctc_loss)
...@@ -448,13 +450,13 @@ class CtcCriterion(FairseqCriterion): ...@@ -448,13 +450,13 @@ class CtcCriterion(FairseqCriterion):
sum(log.get("ctc_entropy", 0) for log in logging_outputs) sum(log.get("ctc_entropy", 0) for log in logging_outputs)
) )
inter_ctc_loss_sum = utils.item( inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs) sum(log.get("interleaved_ctc_loss", 0) for log in logging_outputs)
) )
target_ctc_loss_sum = utils.item( target_ctc_loss_sum = utils.item(
sum(log.get("target_ctc_loss", 0) for log in logging_outputs) sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
) )
target_intermedia_ctc_loss_sum = utils.item( target_interleaved_ctc_loss_sum = utils.item(
sum(log.get("target_intermedia_ctc_loss", 0) for log in logging_outputs) sum(log.get("target_interleaved_ctc_loss", 0) for log in logging_outputs)
) )
ctc_self_distill_loss_sum = utils.item( ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs) sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
...@@ -505,7 +507,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -505,7 +507,7 @@ class CtcCriterion(FairseqCriterion):
) )
if inter_ctc_loss_sum > 0: if inter_ctc_loss_sum > 0:
metrics.log_scalar( metrics.log_scalar(
"intermedia_ctc_loss", "interleaved_ctc_loss",
inter_ctc_loss_sum / sample_size / math.log(2), inter_ctc_loss_sum / sample_size / math.log(2),
sample_size, sample_size,
round=3, round=3,
...@@ -517,10 +519,10 @@ class CtcCriterion(FairseqCriterion): ...@@ -517,10 +519,10 @@ class CtcCriterion(FairseqCriterion):
sample_size, sample_size,
round=3, round=3,
) )
if target_intermedia_ctc_loss_sum > 0: if target_interleaved_ctc_loss_sum > 0:
metrics.log_scalar( metrics.log_scalar(
"target_intermedia_ctc_loss", "target_interleaved_ctc_loss",
target_intermedia_ctc_loss_sum / sample_size / math.log(2), target_interleaved_ctc_loss_sum / sample_size / math.log(2),
sample_size, sample_size,
round=3, round=3,
) )
......
...@@ -89,6 +89,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -89,6 +89,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
if self.ctc_criterion.all_ctc_weight > 0: if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output) ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.ctc_weight) * loss + ctc_loss loss = (1 - self.ctc_weight) * loss + ctc_loss
# if hasattr(model.encoder, "get_loss"):
# encoder_loss = model.encoder.get_loss()
# if encoder_loss != 0:
# loss += encoder_loss * sample_size
# logging_output["encoder_loss"] = utils.item(encoder_loss.data)
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output return loss, sample_size, logging_output
...@@ -103,6 +109,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -103,6 +109,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item( nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs) sum(log.get("nll_loss", 0) for log in logging_outputs)
) )
enc_loss_sum = utils.item(
sum(log.get("encoder_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
...@@ -121,6 +130,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -121,6 +130,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_derived( metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
) )
if enc_loss_sum != 0:
metrics.log_scalar("enc_loss", enc_loss_sum, sample_size, round=3)
if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]: if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]:
CtcCriterion.reduce_metrics(logging_outputs) CtcCriterion.reduce_metrics(logging_outputs)
......
...@@ -13,7 +13,7 @@ from fairseq.models import ( ...@@ -13,7 +13,7 @@ from fairseq.models import (
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.models.speech_to_text import S2TTransformerModel from .s2t_transformer import S2TTransformerModel
from fairseq.modules.speech_to_text import CTC, Adapter from fairseq.modules.speech_to_text import CTC, Adapter
from fairseq.modules import ( from fairseq.modules import (
...@@ -141,334 +141,13 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -141,334 +141,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# subsampling S2TTransformerModel.add_args(parser)
parser.add_argument( PDSS2TTransformerModel.add_specific_args(parser)
"--subsampling-type",
type=str,
help="subsampling type, like conv1d and conv2d",
)
parser.add_argument(
"--subsampling-layers",
type=int,
help="subsampling layers",
)
parser.add_argument(
"--subsampling-filter",
type=int,
help="subsampling filter",
)
parser.add_argument(
"--subsampling-kernel",
type=int,
help="subsampling kernel",
)
parser.add_argument(
"--subsampling-stride",
type=int,
help="subsampling stride",
)
parser.add_argument(
"--subsampling-norm",
type=str,
default="none",
help="subsampling normalization type",
)
parser.add_argument(
"--subsampling-activation",
type=str,
default="none",
help="subsampling activation function type",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"rel_pos_legacy",
"rel_pos",
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
# transfer
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
# reduced attention @staticmethod
parser.add_argument( def add_specific_args(parser):
"--attention-reduced-method", """Add specific arguments to the parser."""
type=str, # PDS setting
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# pds setting
parser.add_argument( parser.add_argument(
"--pds-stages", "--pds-stages",
type=int, type=int,
...@@ -561,69 +240,6 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -561,69 +240,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help="use the ctc after each stage", help="use the ctc after each stage",
) )
# intermedia ctc
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
# mixup
parser.add_argument(
"--inter-mixup",
action="store_true",
help="use mixup or not",
)
parser.add_argument(
"--inter-mixup-layer",
default=None,
type=int,
help="the layers for mixup",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float,
help="the coefficient beta for mixup",
)
parser.add_argument(
"--inter-mixup-prob",
default=1,
type=float,
help="the probability for mixup",
)
parser.add_argument(
"--inter-mixup-ratio",
default=1,
type=float,
help="the ratio for mixup",
)
pass
@classmethod @classmethod
def build_encoder(cls, args, task=None, embed_tokens=None): def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = PDSS2TTransformerEncoder(args, task, embed_tokens) encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
...@@ -707,7 +323,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -707,7 +323,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
args.pds_ctc = getattr(args, "pds_ctc", None) args.pds_ctc = getattr(args, "pds_ctc", None)
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None
inter_ctc_module = None inter_ctc_module = None
inter_adapter = None sae_adapter = None
for i in range(self.pds_stages): for i in range(self.pds_stages):
num_layers = self.pds_layers[i] num_layers = self.pds_layers[i]
...@@ -833,7 +449,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -833,7 +449,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else: else:
logger.error("Unsupported fusion transform!") logger.error("Unsupported fusion transform!")
# intermedia modules for each stage # interleaved modules for each stage
if use_ctc: if use_ctc:
if inter_ctc_module is None: if inter_ctc_module is None:
ctc = CTC(embed_dim, ctc = CTC(embed_dim,
...@@ -847,7 +463,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -847,7 +463,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else: else:
ctc = inter_ctc_module ctc = inter_ctc_module
if i != self.pds_stages - 1: if i != self.pds_stages - 1:
if inter_adapter is None: if sae_adapter is None:
strategy = None strategy = None
if args.intermedia_adapter == "shrink": if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg") strategy = getattr(args, "ctc_compress_strategy", "avg")
...@@ -877,10 +493,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -877,10 +493,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0)) self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True) self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
# self.use_ctc = "sate" in args.arch or \
# (getattr(args, "criterion", "") == "ctc") or \
# (("ctc" in getattr(args, "criterion", "")) and
# (getattr(args, "ctc_weight", False) > 0))
self.use_ctc = "sate" in args.arch or (getattr(args, "ctc_weight", 0) > 0) self.use_ctc = "sate" in args.arch or (getattr(args, "ctc_weight", 0) > 0)
if self.use_ctc: if self.use_ctc:
# self.ctc_layer = (args.ctc_layer + self.layers) % self.layers # self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
...@@ -890,7 +502,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -890,7 +502,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 else False self.inter_ctc = True if self.ctc_layer != 0 else False
if self.inter_ctc: if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer) logger.info("Interleaved CTC loss in layer %d" % self.ctc_layer)
# embed_dim = self.pds_embed_dims[-1] # embed_dim = self.pds_embed_dims[-1]
embed_dim = self.embed_dim embed_dim = self.embed_dim
...@@ -1105,6 +717,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -1105,6 +717,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
seq_len = x.size(0) seq_len = x.size(0)
for state in prev_state: for state in prev_state:
i += 1 i += 1
# padding = prev_padding[i]
# if padding is not None:
# zero_padding = padding.transpose(0, 1).unsqueeze(2)
# state.masked_fill_(zero_padding, 0.0)
fusion_downsampling = getattr(self, f"fusion_downsampling{i + 1}") fusion_downsampling = getattr(self, f"fusion_downsampling{i + 1}")
fusion_pre_layer_norm = getattr(self, f"fusion_pre_layer_norm{i + 1}") fusion_pre_layer_norm = getattr(self, f"fusion_pre_layer_norm{i + 1}")
fusion_post_layer_norm = getattr(self, f"fusion_post_layer_norm{i + 1}") fusion_post_layer_norm = getattr(self, f"fusion_post_layer_norm{i + 1}")
...@@ -1144,6 +762,22 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -1144,6 +762,22 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
"src_lengths": [], "src_lengths": [],
} }
def get_loss(self):
if not self.pds_fusion:
return 0
weight = self.fusion_weight
loss = 0
for i in range(self.fusion_stages_num - 1):
sub = weight[i] - weight[i + 1]
if sub > 0:
loss += sub
if weight[i] < 0:
loss += weight[i]
loss += (0.5 * (weight.sum() - 1.0) ** 2).mean()
return loss
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = ( new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0 [] if len(encoder_out["encoder_out"]) == 0
...@@ -1191,6 +825,7 @@ def base_architecture(args): ...@@ -1191,6 +825,7 @@ def base_architecture(args):
args.subsampling_norm = getattr(args, "subsampling_norm", "none") args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu") args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn") args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
...@@ -1211,6 +846,10 @@ def base_architecture(args): ...@@ -1211,6 +846,10 @@ def base_architecture(args):
args.activation_fn = getattr(args, "activation_fn", "relu") args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr( args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False args, "share_decoder_input_output_embed", False
) )
...@@ -1219,6 +858,7 @@ def base_architecture(args): ...@@ -1219,6 +858,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False args, "no_token_positional_embeddings", False
) )
args.adaptive_input = getattr(args, "adaptive_input", False) args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr( args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim args, "decoder_output_dim", args.decoder_embed_dim
...@@ -1227,14 +867,41 @@ def base_architecture(args): ...@@ -1227,14 +867,41 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1) args.embed_linear = getattr(args, "embed_linear", False)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True) # CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer # Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False) args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS # PDS
args.pds_stages = getattr(args, "pds_stages", None) args.pds_stages = getattr(args, "pds_stages", None)
...@@ -1254,23 +921,10 @@ def base_architecture(args): ...@@ -1254,23 +921,10 @@ def base_architecture(args):
args.pds_conv_strides = getattr(args, "pds_conv_strides", None) args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None) args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout) args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False) args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv") args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 0.5)
def set_pds_base_8(args): def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4) args.pds_stages = getattr(args, "pds_stages", 4)
......
...@@ -12,6 +12,8 @@ from fairseq.models import ( ...@@ -12,6 +12,8 @@ from fairseq.models import (
register_model_architecture, register_model_architecture,
) )
from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
from torch import Tensor from torch import Tensor
...@@ -27,465 +29,8 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -27,465 +29,8 @@ class S2TCTCModel(FairseqEncoderModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# subsampling S2TTransformerModel.add_args(parser)
parser.add_argument( PDSS2TTransformerModel.add_specific_args(parser)
"--subsampling-type",
type=str,
help="subsampling type, like conv1d and conv2d",
)
parser.add_argument(
"--subsampling-layers",
type=int,
help="subsampling layers",
)
parser.add_argument(
"--subsampling-filter",
type=int,
help="subsampling filter",
)
parser.add_argument(
"--subsampling-kernel",
type=int,
help="subsampling kernel",
)
parser.add_argument(
"--subsampling-stride",
type=int,
help="subsampling stride",
)
parser.add_argument(
"--subsampling-norm",
type=str,
default="none",
help="subsampling normalization type",
)
parser.add_argument(
"--subsampling-activation",
type=str,
default="none",
help="subsampling activation function type",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"rel_pos",
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
# interleaved dropout
parser.add_argument('--interleave-dropout', type=int,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout',
action="store_true",
default=False,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-epoch',
type=int,
default=None,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# pds setting
parser.add_argument(
"--pds-stages",
type=int,
help="the number of the stage",
)
parser.add_argument(
"--pds-layers",
type=str,
help="the number of the encoder layers in each stage",
)
parser.add_argument(
"--pds-ratios",
type=str,
help="the ratio of the down-sampling in each stage",
)
parser.add_argument(
"--pds-ds-method",
type=str,
choices=["glu", "conv", "proj", "fusion"],
help="the down-sampling method",
)
parser.add_argument(
"--pds-embed-dims",
type=str,
help="the embedding dimension in each stage",
)
parser.add_argument(
"--pds-kernel-sizes",
type=str,
help="the kernel size of the down-sampling module in each stage",
)
parser.add_argument(
"--pds-embed-norm",
action="store_true",
help="use layer norm in the down-sampling module",
)
parser.add_argument(
"--pds-position-embed",
type=str,
help="use the position embedding or not before each encoding",
)
parser.add_argument(
"--pds-attn-heads",
type=str,
help="the number of the attention heads in each stage",
)
parser.add_argument(
"--pds-attn-ds-ratios",
type=str,
help="the ratio of the down-sampling in the self attention module",
)
parser.add_argument(
"--pds-ffn-ratios",
type=str,
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-conv-strides",
type=str,
help="the strides of the convolutional module (conformer) in each stage",
)
parser.add_argument(
"--pds-attn-strides",
type=str,
help="the strides of the attention module (conformer) in each stage",
)
parser.add_argument(
"--pds-fusion",
action="store_true",
help="use the representation fusion method",
)
parser.add_argument(
"--pds-fusion-method",
type=str,
help="the fusion method",
)
parser.add_argument(
"--pds-dropout",
type=float,
help="dropout in each stage",
)
parser.add_argument(
"--pds-ctc",
type=str,
help="use the ctc after each stage",
)
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
# encoder # encoder
parser.add_argument( parser.add_argument(
...@@ -497,7 +42,7 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -497,7 +42,7 @@ class S2TCTCModel(FairseqEncoderModel):
pass pass
@classmethod @classmethod
def build_encoder(cls, args, task=None, embed_tokens=None): def build_encoder(cls, args, task=None):
encoder = S2TCTCEncoder(args, task) encoder = S2TCTCEncoder(args, task)
if getattr(args, "load_pretrained_encoder_from", None): if getattr(args, "load_pretrained_encoder_from", None):
logger.info( logger.info(
...@@ -561,10 +106,8 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -561,10 +106,8 @@ class S2TCTCEncoder(FairseqEncoder):
setattr(args, "ctc_weight", 1.0) setattr(args, "ctc_weight", 1.0)
encoder_type = getattr(args, "encoder_type", "transformer") encoder_type = getattr(args, "encoder_type", "transformer")
if encoder_type == "transformer": if encoder_type == "transformer":
from .s2t_transformer import S2TTransformerEncoder
self.encoder = S2TTransformerEncoder(args, task) self.encoder = S2TTransformerEncoder(args, task)
elif encoder_type == "pds": elif encoder_type == "pds":
from .pdss2t_transformer import PDSS2TTransformerEncoder
self.encoder = PDSS2TTransformerEncoder(args, task) self.encoder = PDSS2TTransformerEncoder(args, task)
else: else:
logger.error("Unsupported architecture: %s." % encoder_type) logger.error("Unsupported architecture: %s." % encoder_type)
...@@ -701,9 +244,11 @@ def base_architecture(args): ...@@ -701,9 +244,11 @@ def base_architecture(args):
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
# Conformer # Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False) args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL # settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False) args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
...@@ -724,11 +269,23 @@ def base_architecture(args): ...@@ -724,11 +269,23 @@ def base_architecture(args):
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0) args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0) args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout # interleaved CTC
args.interleave_dropout = getattr(args, "interleave_dropout", None) args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.cl_dropout = getattr(args, "cl_dropout", False) args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None) args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS # PDS
args.pds_stages = getattr(args, "pds_stages", None) args.pds_stages = getattr(args, "pds_stages", None)
...@@ -737,26 +294,22 @@ def base_architecture(args): ...@@ -737,26 +294,22 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv") args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None) args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True) args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None) args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None) args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None) args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None) args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1") args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1") args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1") args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout) args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False) args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv") args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
@register_model_architecture("s2t_ctc", "s2t_ctc_s") @register_model_architecture("s2t_ctc", "s2t_ctc_s")
......
...@@ -76,10 +76,10 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -76,10 +76,10 @@ class S2TSATEModel(S2TTransformerModel):
help="share the projection weights of the ctc and adapter", help="share the projection weights of the ctc and adapter",
) )
parser.add_argument( parser.add_argument(
"--temperature", "--adapter-temperature",
default=1.0, default=1.0,
type=float, type=float,
help="temperature of the CTC softmax", help="temperature of the CTC softmax in adapter",
) )
parser.add_argument( parser.add_argument(
"--acoustic-encoder", "--acoustic-encoder",
...@@ -103,14 +103,19 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -103,14 +103,19 @@ class S2TSATEModel(S2TTransformerModel):
parser.add_argument( parser.add_argument(
"--target-ctc-layer", "--target-ctc-layer",
default=None, default=None,
type=str, type=int,
help="ctc layer for target sentence", help="ctc layer for target sentence",
) )
parser.add_argument( parser.add_argument(
"--target-intermedia-ctc-layers", "--target-interleaved-ctc-layers",
default=None, default=None,
type=str, type=str,
help="intermedia ctc layers for target sentence", help="interleaved ctc layers for target sentence",
)
parser.add_argument(
"--share-target-ctc-and-sae",
action="store_true",
help="share the weight of target ctc and sae",
) )
# freeze # freeze
parser.add_argument( parser.add_argument(
...@@ -225,38 +230,42 @@ class TextEncoder(FairseqEncoder): ...@@ -225,38 +230,42 @@ class TextEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = [] self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.target_intermedia_ctc_layers = getattr(args, "target_intermedia_ctc_layers", None) self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
if self.target_intermedia_ctc_layers is not None: self.interleaved_ctc_layers = []
target_intermedia_ctc_layers = self.target_intermedia_ctc_layers.split(",") self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
for layer_idx in target_intermedia_ctc_layers: if self.target_interleaved_ctc_layers is not None:
target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
for layer_idx in target_interleaved_ctc_layers:
layer_idx = int(layer_idx) layer_idx = int(layer_idx)
assert layer_idx <= layer_num, (layer_idx, layer_num) assert layer_idx <= layer_num, (layer_idx, layer_num)
if layer_idx <= 0: if layer_idx <= 0:
layer_idx += layer_num layer_idx += layer_num
self.intermedia_ctc_layers.append(layer_idx) self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia target CTC loss in layer %d" % layer_idx) logger.info("Interleaved target CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim, self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary), dictionary_size=len(dictionary),
dropout=args.dropout) dropout=args.dropout)
if embed_tokens is not None: if embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None strategy = {
if args.intermedia_adapter == "shrink": "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
strategy = getattr(args, "ctc_compress_strategy", None) "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
elif args.intermedia_adapter == "league": "drop_prob": getattr(args, "sae_drop_prob", 0),
strategy = getattr(args, "intermedia_distribution_cutoff", None) }
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
len(dictionary), len(dictionary),
# embed_tokens=embed_tokens,
strategy=strategy) strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) if args.share_target_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1) self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None): def forward(self, x, encoder_padding_mask=None, history=None):
...@@ -266,7 +275,7 @@ class TextEncoder(FairseqEncoder): ...@@ -266,7 +275,7 @@ class TextEncoder(FairseqEncoder):
x = self.dropout_module(x) x = self.dropout_module(x)
target_ctc_logit = None target_ctc_logit = None
target_intermedia_ctc_logits = [] target_interleaved_ctc_logits = []
layer_idx = 0 layer_idx = 0
for layer in self.layers: for layer in self.layers:
if history is not None: if history is not None:
...@@ -277,18 +286,18 @@ class TextEncoder(FairseqEncoder): ...@@ -277,18 +286,18 @@ class TextEncoder(FairseqEncoder):
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx: if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone()) target_ctc_logit = self.ctc(x.clone())
if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers: if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
if self.intermedia_drop_prob > 0: if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_() p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob: if p < self.interleaved_ctc_drop_prob:
break break
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
logit = self.ctc(norm_x) logit = self.ctc(norm_x)
target_intermedia_ctc_logits.append(logit) target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1) prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask) x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
if history is not None: if history is not None:
history.push(x) history.push(x)
...@@ -302,7 +311,7 @@ class TextEncoder(FairseqEncoder): ...@@ -302,7 +311,7 @@ class TextEncoder(FairseqEncoder):
if self.use_ctc and target_ctc_logit is None: if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x) target_ctc_logit = self.ctc(x)
return x, target_ctc_logit, target_intermedia_ctc_logits return x, target_ctc_logit, target_interleaved_ctc_logits
class S2TSATEEncoder(FairseqEncoder): class S2TSATEEncoder(FairseqEncoder):
...@@ -322,13 +331,12 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -322,13 +331,12 @@ class S2TSATEEncoder(FairseqEncoder):
logging.error("Unsupported model arch {}!".format(acoustic_encoder_type)) logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
# adapter # adapter
self.temperature = args.temperature self.adapter_temperature = args.adapter_temperature
strategy = {
strategy = None "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
if args.adapter == "shrink": "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
strategy = getattr(args, "ctc_compress_strategy", "avg") "drop_prob": getattr(args, "sae_drop_prob", 0),
elif args.adapter == "league": }
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(args.encoder_embed_dim, self.adapter = Adapter(args.encoder_embed_dim,
args.adapter, args.adapter,
...@@ -341,8 +349,7 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -341,8 +349,7 @@ class S2TSATEEncoder(FairseqEncoder):
acoustic_encoder_attention_type = args.encoder_attention_type acoustic_encoder_attention_type = args.encoder_attention_type
args.encoder_attention_type = args.text_attention_type args.encoder_attention_type = args.text_attention_type
# textual encoder
# text encoder
self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens) self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens)
args.encoder_attention_type = acoustic_encoder_attention_type args.encoder_attention_type = acoustic_encoder_attention_type
...@@ -369,10 +376,14 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -369,10 +376,14 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_out = acoustic_encoder_out["encoder_out"][0] encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0] encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_padding_mask = encoder_padding_mask ctc_padding_mask = encoder_padding_mask
if "mixup" in encoder_out:
mixup = encoder_out["mixup"]
else:
mixup = None
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0: if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0] ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.temperature, dim=-1, dtype=torch.float32) ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
else: else:
ctc_logit = None ctc_logit = None
ctc_prob = None ctc_prob = None
...@@ -392,18 +403,20 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -392,18 +403,20 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder: if self.freeze_textual_encoder:
with torch.no_grad(): with torch.no_grad():
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history) x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
else: else:
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history) x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C "ctc_logit": [ctc_logit], # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C "interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []), # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C "target_ctc_logit": target_ctc_logit, # B x T x C
"target_intermedia_ctc_logits": target_intermedia_ctc_logits, # B x T x C "target_interleaved_ctc_logits": target_interleaved_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T "ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"mixup": mixup,
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
"src_tokens": [], "src_tokens": [],
...@@ -458,7 +471,7 @@ def base_architecture(args): ...@@ -458,7 +471,7 @@ def base_architecture(args):
args.subsampling_norm = getattr(args, "subsampling_norm", "none") args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu") args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# transformer # Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12) args.encoder_layers = getattr(args, "encoder_layers", 12)
...@@ -480,6 +493,10 @@ def base_architecture(args): ...@@ -480,6 +493,10 @@ def base_architecture(args):
args.activation_fn = getattr(args, "activation_fn", "relu") args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr( args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False args, "share_decoder_input_output_embed", False
) )
...@@ -497,23 +514,58 @@ def base_architecture(args): ...@@ -497,23 +514,58 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1) args.embed_linear = getattr(args, "embed_linear", False)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True) # CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer # Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False) args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.decoder_learnable = getattr(args, 'decoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# SATE # local modeling
args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer") args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.adapter = getattr(args, "adapter", "league") args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg") args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
args.temperature = getattr(args, "temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) # interleaved CTC
args.text_attention_type = getattr(args, "text_attention_type", "selfattn") args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False) args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_target_ctc_and_sae = getattr(args, "share_target_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS # PDS
args.pds_stages = getattr(args, "pds_stages", None) args.pds_stages = getattr(args, "pds_stages", None)
...@@ -539,10 +591,14 @@ def base_architecture(args): ...@@ -539,10 +591,14 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False) args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv") args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC # SATE
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0") args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none") args.adapter = getattr(args, "adapter", "league")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg")
args.adapter_temperature = getattr(args, "adapter_temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
@register_model_architecture("s2t_sate", "s2t_sate_s") @register_model_architecture("s2t_sate", "s2t_sate_s")
......
...@@ -2,6 +2,7 @@ import logging ...@@ -2,6 +2,7 @@ import logging
import math import math
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -139,9 +140,35 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -139,9 +140,35 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
"rel_pos", "rel_pos",
"rope", "rope",
"abs", "abs",
"transfer",
"reduced_rel_pos",
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
# transfer
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
# reduced attention
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not"
)
parser.add_argument( parser.add_argument(
"--encoder-attention-heads", "--encoder-attention-heads",
type=int, type=int,
...@@ -286,6 +313,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -286,6 +313,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=int, type=int,
help="the position of the ctc loss", help="the position of the ctc loss",
) )
parser.add_argument(
"--share-ctc-and-embed",
action="store_true",
help="share the weight of ctc and embedding",
)
# local modeling # local modeling
parser.add_argument( parser.add_argument(
...@@ -349,63 +381,97 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -349,63 +381,97 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="Use convolution module or not", help="Use convolution module or not",
) )
parser.add_argument( parser.add_argument(
"--cnn-module-norm",
default="batch_norm",
type=str,
help="normalization type of cnn module",
)
parser.add_argument(
"--cnn-module-kernel", "--cnn-module-kernel",
default=31, default=31,
type=int, type=int,
help="Kernel size of convolution module.", help="Kernel size of convolution module.",
) )
# Simultaneous speech translation
parser.add_argument( parser.add_argument(
"--simul", "--embed-linear",
default=False,
action="store_true", action="store_true",
help="Simultaneous speech translation or not", help="use linear transform after down-sampling",
) )
# interleaved dropout
parser.add_argument('--interleave-dropout', type=int, # interleaved CTC layers
help='interleaved dropout probability')
parser.add_argument('--cl-dropout',
action="store_true",
default=False,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-epoch',
type=int,
default=None,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# intermedia CTC loss
parser.add_argument( parser.add_argument(
"--intermedia-ctc-layers", "--interleaved-ctc-layers",
default=None, default=None,
type=str, type=str,
help="the position of the ctc loss, separated by comma ", help="the position of interleaved ctc layers, separated by comma ",
) )
parser.add_argument( parser.add_argument(
"--intermedia-adapter", "--interleaved-ctc-temperature",
default=1,
type=float,
help="temperature of the CTC probability in sae",
)
parser.add_argument(
"--interleaved-ctc-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
# Semantics-augmented Encoding (SAE)
parser.add_argument(
"--sae-adapter",
default="none", default="none",
type=str, type=str,
help="type of intermedia adapter", help="adapter type of sae ",
)
parser.add_argument(
"--sae-drop-prob",
default=0,
type=float,
help="dropping one input in sae with a probability",
) )
parser.add_argument( parser.add_argument(
"--intermedia-distribution-cutoff", "--sae-distribution-cutoff",
default=None, default=None,
type=int, type=int,
help="cutoff of the distribution", help="cutoff of the distribution in sae",
) )
parser.add_argument( parser.add_argument(
"--intermedia-drop-prob", "--share-ctc-and-sae",
default=0, action="store_true",
help="share the weight of ctc and sae",
)
# Mixup
parser.add_argument(
"--inter-mixup",
action="store_true",
help="use mixup or not",
)
parser.add_argument(
"--inter-mixup-layer",
default=None,
type=int,
help="the layers to apply mixup",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float, type=float,
help="probability of dropping the followed layers", help="the coefficient beta of mixup",
) )
parser.add_argument( parser.add_argument(
"--intermedia-temperature", "--inter-mixup-prob",
default=1, default=1,
type=float, type=float,
help="temperature of the intermedia ctc probability", help="the probability of mixup",
)
parser.add_argument(
"--inter-mixup-ratio",
default=1,
type=float,
help="the ratio of mixup",
) )
pass pass
...@@ -513,10 +579,11 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -513,10 +579,11 @@ class S2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1 self.padding_idx = 1
self.subsample = subsampling(args) self.subsample = subsampling(args)
# self.linear = nn.Linear(dim, dim) self.embed_linear = getattr(args, "embed_linear", False)
if self.embed_linear:
self.linear = nn.Linear(dim, dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn") self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos": if self.attn_type == "rel_pos":
self.embed_positions = RelPositionalEncoding( self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim args.max_source_positions, args.encoder_embed_dim
...@@ -546,9 +613,6 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -546,9 +613,6 @@ class S2TTransformerEncoder(FairseqEncoder):
else: else:
self.history = None self.history = None
# self.use_ctc = "sate" in args.arch or \
# (getattr(args, "criterion", "") == "ctc") or \
# (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0 self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
if self.use_ctc: if self.use_ctc:
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
...@@ -560,47 +624,63 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -560,47 +624,63 @@ class S2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout, dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False) need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary and \ if getattr(args, "share_ctc_and_embed", False) and \
task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim: embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_dropout = getattr(args, "interleave_dropout", None) self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.gather_cos_sim = getattr(args, "gather_cos_sim", False) self.interleaved_ctc_layers = []
# self.gather_cos_sim = True if args.interleaved_ctc_layers is not None:
self.dis = 2 interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
self.cos_sim = dict() for layer_idx in interleaved_ctc_layers:
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx) layer_idx = int(layer_idx)
if layer_idx <= 0: if layer_idx <= 0:
layer_idx += args.encoder_layers layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx) self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx) logger.info("Interleaved CTC loss in layer %d" % layer_idx)
if not self.use_ctc: if not self.use_ctc:
self.ctc = CTC(dim, self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary), dictionary_size=len(task.source_dictionary),
dropout=args.dropout) dropout=args.dropout)
if getattr(args, "share_ctc_and_embed", False) and \
if task.source_dictionary == task.target_dictionary and embed_tokens is not None: task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None strategy = {
if args.intermedia_adapter == "shrink": "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
strategy = getattr(args, "ctc_compress_strategy", None) "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
elif args.intermedia_adapter == "league": "drop_prob": getattr(args, "sae_drop_prob", 0),
strategy = getattr(args, "intermedia_distribution_cutoff", None) }
self.adapter = Adapter(dim, args.intermedia_adapter,
len(task.source_dictionary), strategy=strategy) self.sae_adapter = Adapter(dim, args.sae_adapter,
# embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection) len(task.source_dictionary),
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) strategy=strategy,
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1) )
if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
# mixup
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_layer = int(args.inter_mixup_layer)
self.mixup_prob = float(args.inter_mixup_prob)
self.mixup_ratio = float(args.inter_mixup_ratio)
beta = float(args.inter_mixup_beta)
from torch.distributions import Beta
self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f." % (
self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio))
# gather cosine similarity
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.dis = 2
self.cos_sim = dict()
@staticmethod @staticmethod
def pooling_ratio(): def pooling_ratio():
...@@ -624,21 +704,67 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -624,21 +704,67 @@ class S2TTransformerEncoder(FairseqEncoder):
self.cos_sim[idx] = [] self.cos_sim[idx] = []
self.cos_sim[idx].append(float(sim)) self.cos_sim[idx].append(float(sim))
def apply_mixup(self, x, encoder_padding_mask):
batch = x.size(1)
indices = np.random.permutation(batch)
if self.mixup_ratio == 1:
if len(indices) % 2 != 0:
indices = np.append(indices, (indices[-1]))
idx1 = indices[0::2]
idx2 = indices[1::2]
else:
mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
mix_indices = indices[: mix_size]
idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
idx1 = torch.from_numpy(idx1).to(x.device)
idx2 = torch.from_numpy(idx2).to(x.device)
x1 = x[:, idx1]
x2 = x[:, idx2]
coef = self.beta.sample().to(x.device).type_as(x)
x = (coef * x1 + (1 - coef) * x2)
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 + pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
"coef": coef,
"index1": idx1,
"index2": idx2,
}
return x, encoder_padding_mask, input_lengths, mixup
def forward(self, src_tokens, src_lengths): def forward(self, src_tokens, src_lengths):
layer_idx = -1
mixup = None
if self.history is not None: if self.history is not None:
self.history.clean() self.history.clean()
# (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1)
input_lengths = src_lengths
# gather cosine similarity # gather cosine similarity
cos_sim_idx = -1 cos_sim_idx = -1
dis = self.dis dis = self.dis
if self.gather_cos_sim: if self.gather_cos_sim:
self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx) self.add_to_dict(x, dis, cos_sim_idx)
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
# down-sampling # down-sampling
# (B, T, D) -> (T, B, D) x, input_lengths = self.subsample(x, input_lengths)
x = src_tokens.transpose(0, 1)
x, input_lengths = self.subsample(x, src_lengths)
# embedding scaling # embedding scaling
x = self.embed_scale * x x = self.embed_scale * x
...@@ -657,7 +783,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -657,7 +783,8 @@ class S2TTransformerEncoder(FairseqEncoder):
x += positions x += positions
positions = None positions = None
# x = self.linear(x) if self.embed_linear:
x = self.linear(x)
x = self.dropout_module(x) x = self.dropout_module(x)
# add emb into history # add emb into history
...@@ -670,43 +797,46 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -670,43 +797,46 @@ class S2TTransformerEncoder(FairseqEncoder):
cos_sim_idx += 1 cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx) self.add_to_dict(x, dis, cos_sim_idx)
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
layer_idx += 1 layer_idx += 1
ctc_logit = None
interleaved_ctc_logits = []
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
for layer in self.layers:
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
if layer_idx != len(self.layers) \
and self.interleaved_dropout is not None \
and layer_idx % self.interleaved_dropout == 0:
x = self.dropout_module(x)
# encoder layer # encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx: if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone()) ctc_logit = self.ctc(x.clone())
# interleave CTC # interleaved CTC
if layer_idx in self.intermedia_ctc_layers: if layer_idx in self.interleaved_ctc_layers:
if self.intermedia_drop_prob > 0: if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_() p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob: if p < self.interleaved_ctc_drop_prob:
break break
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
logit = self.ctc(norm_x) logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit) interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4, logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4) max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1) prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask) x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
# gather cosine similarity # gather cosine similarity
if self.gather_cos_sim: if self.gather_cos_sim:
...@@ -728,8 +858,9 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -728,8 +858,9 @@ class S2TTransformerEncoder(FairseqEncoder):
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C "ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C "interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"mixup": mixup,
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
"src_tokens": [], "src_tokens": [],
...@@ -872,14 +1003,18 @@ def base_architecture(args): ...@@ -872,14 +1003,18 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.embed_linear = getattr(args, "embed_linear", False)
# CTC # CTC
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer # Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu") args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False) args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL # settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False) args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
...@@ -902,16 +1037,23 @@ def base_architecture(args): ...@@ -902,16 +1037,23 @@ def base_architecture(args):
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0) args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0) args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout # interleaved CTC
args.interleave_dropout = getattr(args, "interleave_dropout", None) args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.cl_dropout = getattr(args, "cl_dropout", False) args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None) args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# Semantics-augmented Encoding (sae)
# intermedia CTC args.sae_adapter = getattr(args, "sae_adapter", "none")
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None) args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None) args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
@register_model_architecture("s2t_transformer", "s2t_transformer_s") @register_model_architecture("s2t_transformer", "s2t_transformer_s")
......
...@@ -286,9 +286,6 @@ class TransformerModel(FairseqEncoderDecoderModel): ...@@ -286,9 +286,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
help="freeze the module of the decoder", help="freeze the module of the decoder",
) )
parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
help='interleaved dropout probability')
parser.add_argument( parser.add_argument(
"--squeeze-excitation", "--squeeze-excitation",
default=False, default=False,
......
...@@ -286,50 +286,61 @@ class TransformerCTCModel(FairseqEncoderDecoderModel): ...@@ -286,50 +286,61 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="freeze the module of the decoder", help="freeze the module of the decoder",
) )
parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
help='interleaved dropout probability')
parser.add_argument( parser.add_argument(
"--squeeze-excitation", "--squeeze-excitation",
default=False, default=False,
action='store_true', action='store_true',
help="use squeeze and excitation method", help="use squeeze and excitation method",
) )
# CTC # interleaved CTC layers
parser.add_argument(
"--ctc-layer",
type=int,
help="ctc layers for target sentence",
)
parser.add_argument( parser.add_argument(
"--intermedia-ctc-layers", "--interleaved-ctc-layers",
default=None, default=None,
type=str, type=str,
help="the position of the ctc loss, separated by comma ", help="the position of interleaved ctc layers, separated by comma",
) )
parser.add_argument( parser.add_argument(
"--intermedia-adapter", "--interleaved-ctc-upsampling-ratio",
default="none", default=2,
type=str, type=int,
help="type of intermedia adapter", help="upsampling ratio of the representation for CTC calculation",
) )
parser.add_argument( parser.add_argument(
"--intermedia-distribution-cutoff", "--interleaved-ctc-temperature",
default=None, default=1,
type=int, type=float,
help="cutoff of the distribution", help="temperature of the CTC probability in sae",
) )
parser.add_argument( parser.add_argument(
"--intermedia-drop-prob", "--interleaved-ctc-drop-prob",
default=0, default=0,
type=float, type=float,
help="probability of dropping the followed layers", help="probability of dropping the followed layers",
) )
# Semantics-augmented Encoding (SAE)
parser.add_argument( parser.add_argument(
"--intermedia-temperature", "--sae-adapter",
default=1, default="none",
type=str,
help="adapter type of sae ",
)
parser.add_argument(
"--sae-drop-prob",
default=0,
type=float, type=float,
help="temperature of the intermedia ctc probability", help="dropping one input in sae with a probability",
)
parser.add_argument(
"--sae-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--share-ctc-and-sae",
action="store_true",
help="share the weight of ctc and sae",
) )
# fmt: on # fmt: on
...@@ -574,6 +585,7 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -574,6 +585,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC # CTC
self.use_ctc = getattr(args, "ctc_weight", 0) > 0 self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc: if self.use_ctc:
assert decoder_embed_tokens is not None
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc: if self.inter_ctc:
...@@ -583,35 +595,41 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -583,35 +595,41 @@ class TransformerCTCEncoder(FairseqEncoder):
dropout=args.dropout, dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False) need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.intermedia_ctc_layers = [] self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
if args.intermedia_ctc_layers is not None: self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",") self.interleaved_ctc_upsampling_ratio = args.interleaved_ctc_upsampling_ratio
for layer_idx in intermedia_ctc_layers: self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
for layer_idx in interleaved_ctc_layers:
layer_idx = int(layer_idx) layer_idx = int(layer_idx)
if layer_idx <= 0: if layer_idx <= 0:
layer_idx += args.encoder_layers layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx) self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx) logger.info("Interleaved CTC loss in layer %d" % layer_idx)
if not self.use_ctc: if not self.use_ctc:
self.ctc = CTC(embed_dim, self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings, dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout) dropout=args.dropout)
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
strategy = {
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
strategy = None self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
if args.intermedia_adapter == "shrink": decoder_embed_tokens.num_embeddings,
strategy = getattr(args, "ctc_compress_strategy", None) strategy=strategy
elif args.intermedia_adapter == "league": )
strategy = getattr(args, "intermedia_distribution_cutoff", None) if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.adapter = Adapter(embed_dim, args.intermedia_adapter, self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
decoder_embed_tokens.num_embeddings, embed_tokens=decoder_embed_tokens, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
def build_encoder_layer(self, args): def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args) layer = TransformerEncoderLayer(args)
...@@ -672,12 +690,13 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -672,12 +690,13 @@ class TransformerCTCEncoder(FairseqEncoder):
return_all_hiddens, return_all_hiddens,
token_embeddings) token_embeddings)
def upsample(self, x, ratio=2): def upsampling(self, x):
ratio = self.interleaved_ctc_upsampling_ratio
if ratio <= 1: if ratio <= 1:
return x return x
seq_len, bsz, dim = x.size() seq_len, bsz, dim = x.size()
x = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim) x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
return x return x
# TorchScript doesn't support super() method so that the scriptable Subclass # TorchScript doesn't support super() method so that the scriptable Subclass
...@@ -742,7 +761,7 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -742,7 +761,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# encoder layers # encoder layers
layer_idx = 0 layer_idx = 0
ctc_logit = None ctc_logit = None
intermedia_ctc_logits = [] interleaved_ctc_logits = []
for layer in self.layers: for layer in self.layers:
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
...@@ -757,24 +776,28 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -757,24 +776,28 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC # CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx: if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(self.upsample(x.clone())) ctc_logit = self.ctc(self.upsampling(x.clone()))
# Intermedia CTC # Intermedia CTC
if layer_idx in self.intermedia_ctc_layers: if layer_idx in self.interleaved_ctc_layers:
if self.intermedia_drop_prob > 0: if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_() p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob: if p < self.interleaved_ctc_drop_prob:
break break
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
up_x = self.upsample(norm_x) up_x = self.upsampling(norm_x)
up_logit = self.ctc(up_x) up_logit = self.ctc(up_x)
intermedia_ctc_logits.append(up_logit) interleaved_ctc_logits.append(up_logit)
up_prob = utils.softmax(up_logit / self.intermedia_temperature, dim=-1) up_prob = utils.softmax(up_logit / self.interleaved_ctc_temperature, dim=-1)
up_prob = up_prob.permute(1, 2, 0) up_prob = up_prob.permute(1, 2, 0)
prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2) prob = nn.functional.max_pool1d(up_prob,
kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
prob = prob.permute(2, 0, 1) prob = prob.permute(2, 0, 1)
x, _ = self.adapter([x, prob]) x, _ = self.adapter([x, prob])
if self.history is not None: if self.history is not None:
...@@ -787,12 +810,13 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -787,12 +810,13 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x) x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None: if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(self.upsample(x)) ctc_logit = self.ctc(self.upsampling(x))
ctc_padding_mask = encoder_padding_mask ctc_padding_mask = encoder_padding_mask
if ctc_logit is not None or len(intermedia_ctc_logits) != 0: if ctc_logit is not None or len(interleaved_ctc_logits) != 0:
bsz = encoder_padding_mask.size(0) bsz = encoder_padding_mask.size(0)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1) ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).\
expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead. # `forward` so we use a dictionary instead.
...@@ -802,7 +826,7 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -802,7 +826,7 @@ class TransformerCTCEncoder(FairseqEncoder):
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C "ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"ctc_padding_mask": [ctc_padding_mask], "ctc_padding_mask": [ctc_padding_mask],
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C "interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C "encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C] "encoder_states": encoder_states, # List[T x B x C]
...@@ -1457,9 +1481,17 @@ def base_architecture(args): ...@@ -1457,9 +1481,17 @@ def base_architecture(args):
# CTC # CTC
args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers) args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None) # interleaved CTC
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
@register_model_architecture("transformer_ctc", "transformer_ctc_relative") @register_model_architecture("transformer_ctc", "transformer_ctc_relative")
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from torch import nn from torch import nn
from fairseq.modules.activations import get_activation_class from fairseq.modules.activations import get_activation_class
from fairseq.modules.layer_norm import LayerNorm
class ConvolutionModule(nn.Module): class ConvolutionModule(nn.Module):
...@@ -17,6 +18,7 @@ class ConvolutionModule(nn.Module): ...@@ -17,6 +18,7 @@ class ConvolutionModule(nn.Module):
bias=False, bias=False,
stride=1, stride=1,
export=False, export=False,
norm_type="batch_norm"
): ):
""" """
Args: Args:
...@@ -50,7 +52,13 @@ class ConvolutionModule(nn.Module): ...@@ -50,7 +52,13 @@ class ConvolutionModule(nn.Module):
groups=expand_embed_dim, groups=expand_embed_dim,
bias=bias, bias=bias,
) )
self.batch_norm = nn.BatchNorm1d(expand_embed_dim) self.norm_type = norm_type
if norm_type == "batch_norm":
self.norm = nn.BatchNorm1d(expand_embed_dim)
elif norm_type == "layer_norm":
self.norm = LayerNorm(expand_embed_dim)
else:
assert False, "Unsupported normalization type in convolution module"
self.activation = get_activation_class(activation_fn) self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d( self.pointwise_conv2 = torch.nn.Conv1d(
expand_embed_dim, expand_embed_dim,
...@@ -62,7 +70,7 @@ class ConvolutionModule(nn.Module): ...@@ -62,7 +70,7 @@ class ConvolutionModule(nn.Module):
) )
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
def forward(self, x): def forward(self, x, mask_pad=None):
""" """
Args: Args:
x: Input of shape B X T X C x: Input of shape B X T X C
...@@ -72,23 +80,36 @@ class ConvolutionModule(nn.Module): ...@@ -72,23 +80,36 @@ class ConvolutionModule(nn.Module):
# exchange the temporal dimension and the feature dimension # exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) x = x.transpose(1, 2)
# zero_mask_pad = mask_pad.unsqueeze(1)
# # mask batch padding
# if mask_pad is not None:
# x.masked_fill_(zero_mask_pad, 0.0)
# GLU mechanism # GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*expand_embed_dim, dim) x = self.pointwise_conv1(x) # (batch, 2*expand_embed_dim, dim)
x = self.glu(x) # (batch, expand_embed_dim, dim) x = self.glu(x) # (batch, expand_embed_dim, dim)
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
x = self.batch_norm(x)
if self.norm_type == "layer_norm":
x = x.transpose(1, 2)
x = self.norm(x)
x = self.activation(x) x = self.activation(x)
if self.norm_type == "layer_norm":
x = x.transpose(1, 2)
x = self.pointwise_conv2(x) x = self.pointwise_conv2(x)
# # mask batch padding
# if zero_mask_pad is not None:
# x.masked_fill_(zero_mask_pad, 0.0)
x = x.transpose(1, 2) x = x.transpose(1, 2)
x = self.dropout(x) x = self.dropout(x)
return x return x
# class ConvolutionModule(nn.Module): # class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model.""" # """ConvolutionModule in Conformer model."""
# def __init__(self, # def __init__(self,
......
...@@ -332,7 +332,7 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -332,7 +332,7 @@ class PDSTransformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
x = self.conv_norm(x) x = self.conv_norm(x)
x = self.conv_module(x) x = self.conv_module(x, encoder_padding_mask)
x = x.transpose(0, 1) x = x.transpose(0, 1)
x = self.conv_res(residual) + x x = self.conv_res(residual) + x
......
...@@ -122,7 +122,9 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -122,7 +122,9 @@ class S2TTransformerEncoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
depthwise_kernel_size=args.cnn_module_kernel, depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout, dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish')) activation_fn=getattr(args, 'activation_fn', 'swish'),
norm_type=args.cnn_module_norm
)
self.final_norm = LayerNorm(embed_dim) self.final_norm = LayerNorm(embed_dim)
else: else:
self.conv_norm = None self.conv_norm = None
......
...@@ -3,6 +3,7 @@ import logging ...@@ -3,6 +3,7 @@ import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from itertools import groupby
from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.modules import LayerNorm from fairseq.modules import LayerNorm
...@@ -61,9 +62,12 @@ class Adapter(nn.Module): ...@@ -61,9 +62,12 @@ class Adapter(nn.Module):
super().__init__() super().__init__()
dim = dim dim = dim
self.adapter_type = adapter_type self.adapter_type = adapter_type
self.cal_linear = False
self.cal_context = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]: if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.cal_linear = True
self.linear_adapter = nn.Sequential( self.linear_adapter = nn.Sequential(
nn.Linear(dim, dim), nn.Linear(dim, dim),
LayerNorm(dim), LayerNorm(dim),
...@@ -71,14 +75,10 @@ class Adapter(nn.Module): ...@@ -71,14 +75,10 @@ class Adapter(nn.Module):
) )
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]: if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
if embed_tokens is not None: if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight self.embed_adapter.weight = embed_tokens.weight
# if embed_tokens is None:
# num_embeddings = len(dictionary)
# self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
# else:
# self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league": if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * dim, dim) self.gate_linear = nn.Linear(2 * dim, dim)
...@@ -86,14 +86,22 @@ class Adapter(nn.Module): ...@@ -86,14 +86,22 @@ class Adapter(nn.Module):
self.gate_linear1 = nn.Linear(dim, dim) self.gate_linear1 = nn.Linear(dim, dim)
self.gate_linear2 = nn.Linear(dim, dim) self.gate_linear2 = nn.Linear(dim, dim)
# additional strategy
if self.adapter_type == "shrink": if self.adapter_type == "shrink":
assert strategy is not None assert strategy is not None
self.ctc_compress = getattr(CTCCompressStrategy, strategy) ctc_compress_strategy = getattr(strategy, "ctc_compress_strategy", "avg")
logger.info("CTC Compress Strategy: %s" % strategy) self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
elif self.adapter_type == "league": logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
self.distribution_cutoff = strategy
if "league" in self.adapter_type:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
if self.distribution_cutoff is not None: if self.distribution_cutoff is not None:
logger.info("Distribution cutoff: %d" % int(strategy)) self.distribution_cutoff = int(self.distribution_cutoff)
logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
def forward(self, x, padding=None): def forward(self, x, padding=None):
...@@ -103,14 +111,11 @@ class Adapter(nn.Module): ...@@ -103,14 +111,11 @@ class Adapter(nn.Module):
org_distribution = distribution org_distribution = distribution
distribution = distribution.contiguous().view(-1, distribution.size(-1)) distribution = distribution.contiguous().view(-1, distribution.size(-1))
if self.adapter_type == "linear": linear_out = None
out = self.linear_adapter(representation) soft_out = None
if self.cal_linear:
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.distribution_cutoff is not None: if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1) cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1] threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
...@@ -120,24 +125,33 @@ class Adapter(nn.Module): ...@@ -120,24 +125,33 @@ class Adapter(nn.Module):
distribution = distribution.view(-1, distribution.size(-1)) distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.adapter_type == "linear":
out = linear_out
elif self.adapter_type == "context":
out = soft_out
elif self.adapter_type == "league":
if self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5:
out = linear_out
else:
out = soft_out
else:
out = linear_out + soft_out out = linear_out + soft_out
elif self.adapter_type == "gated_league": elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid() coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league": elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = representation + soft_out out = representation + soft_out
elif self.adapter_type == "none": elif self.adapter_type == "none":
out = representation out = representation
elif self.adapter_type == "shrink": elif self.adapter_type == "shrink":
from itertools import groupby
lengths = (~padding).long().sum(-1) lengths = (~padding).long().sum(-1)
with torch.no_grad(): with torch.no_grad():
......
...@@ -18,11 +18,9 @@ class CTC(nn.Module): ...@@ -18,11 +18,9 @@ class CTC(nn.Module):
super(CTC, self).__init__() super(CTC, self).__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False) self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
nn.init.normal_( # nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5
)
self.ctc_dropout_module = FairseqDropout( self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__ p=dropout, module_name=self.__class__.__name__
...@@ -46,4 +44,3 @@ class CTC(nn.Module): ...@@ -46,4 +44,3 @@ class CTC(nn.Module):
def argmax(self, x): def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1) return torch.argmax(self.ctc_projection(x), dim=-1)
...@@ -191,7 +191,8 @@ class Conv2dSubsampling(nn.Module): ...@@ -191,7 +191,8 @@ class Conv2dSubsampling(nn.Module):
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id], filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
kernel_size, kernel_size,
stride=stride, stride=stride,
padding=(kernel_size - 1) // 2), # padding=(kernel_size - 1) // 2
),
get_norm(norm, get_norm(norm,
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id], filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
transpose=True if norm == "layer" else False), transpose=True if norm == "layer" else False),
...@@ -214,6 +215,8 @@ class Conv2dSubsampling(nn.Module): ...@@ -214,6 +215,8 @@ class Conv2dSubsampling(nn.Module):
# (B, C, D // S, T // S) -> (B, C * D // S, T // S) # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size() batch_size, channels, subsampled_dim, subsampled_length = x.size()
assert subsampled_length == max(x_len), "The lengths are mismatched."
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1) x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
x = self.linear(x) x = self.linear(x)
......
...@@ -156,22 +156,22 @@ def main(cfg: FairseqConfig) -> None: ...@@ -156,22 +156,22 @@ def main(cfg: FairseqConfig) -> None:
) )
break break
if getattr(cfg.model, "cl_dropout", False): # if getattr(cfg.model, "cl_dropout", False):
cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None) # cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear") # cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
dropout = getattr(cfg.model, "dropout", False) # dropout = getattr(cfg.model, "dropout", False)
assert cl_dropout_epoch > 0 # assert cl_dropout_epoch > 0
curr_epoch = epoch_itr.epoch # curr_epoch = epoch_itr.epoch
if curr_epoch <= cl_dropout_epoch: # if curr_epoch <= cl_dropout_epoch:
if curr_epoch == cl_dropout_epoch: # if curr_epoch == cl_dropout_epoch:
curr_dropout = dropout # curr_dropout = dropout
else: # else:
curr_dropout = curr_epoch / cl_dropout_epoch * dropout # curr_dropout = curr_epoch / cl_dropout_epoch * dropout
logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout)) # logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))
for name, module in trainer.model.named_modules(): # for name, module in trainer.model.named_modules():
from fairseq.modules.fairseq_dropout import FairseqDropout # from fairseq.modules.fairseq_dropout import FairseqDropout
if isinstance(module, FairseqDropout): # if isinstance(module, FairseqDropout):
module.p = curr_dropout # module.p = curr_dropout
# train for one epoch # train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论