Commit 1d60b3a6 by xuchen

big update! I integrate the latest updates of shell scripts, optimize the…

big update! I integrate the latest updates of shell scripts, optimize the implementation of sae and fix some bugs.
parent 1288e535
......@@ -13,7 +13,7 @@ label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 2048
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
......
arch: s2t_transformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv2d
subsmapling-layers: 2
subsampling-filter: 512
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: relu
dropout: 0.15
activation-fn: relu
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
cnn-module-kernel: 15
encoder-attention-type: rel_pos
encoder-activation-fn: swish
\ No newline at end of file
......@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......
#! /bin/bash
# Processing ASR Datasets
# Processing aishell ASR Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
......@@ -323,7 +323,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -z ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ }
......@@ -352,6 +361,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi
done
cat ${result_file}
......
......@@ -13,12 +13,12 @@ extra_parameter=
exp_tag=
#config_list=(base)
#config_list=(ctc)
#config_list=(base conformer)
#config_list=(base ctc)
config_list=(base ctc conformer)
config_list=(big ctc conformer)
#config_list=(pds_base_16)
config_list=(pds_base_16 conformer rpr)
config_list=(pds_base_16 conformer)
# exp full name
exp_name=
......
......@@ -14,7 +14,7 @@ sacrebleu=0
n_average=10
beam_size=5
len_penalty=1.0
max_tokens=80000
max_tokens=20000
dec_model=checkpoint_best.pt
cmd="./run.sh
......
......@@ -25,6 +25,7 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
fp16-scale-tolerance: 0.25
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
......
......@@ -18,7 +18,7 @@ exp_tag=
#config_list=(base conformer)
config_list=(pds_base_8 ctc)
#config_list=(pds_base_16 conformer rpr)
#config_list=(pds_base_16 conformer)
# exp full name
exp_name=
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
......@@ -3,15 +3,15 @@
gpu_num=1
data_dir=
test_subset=(test)
test_subset=(valid test)
exp_name=
if [ "$#" -eq 1 ]; then
exp_name=$1
fi
sacrebleu=1
n_average=10
sacrebleu=0
n_average=5
beam_size=5
len_penalty=1.0
max_tokens=80000
......
......@@ -3,14 +3,16 @@ arch: s2t_dual
asr-encoder: pds
mt-encoder-layers: 30
encoder-drop-net: True
encoder-drop-net-prob: 0.8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss
......@@ -56,9 +58,9 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
#load-pretrained-asr-encoder-from:
#load-pretrained-mt-encoder-from:
#load-pretrained-decoder-from:
......@@ -6,7 +6,6 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
......@@ -52,9 +51,7 @@ pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-text-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
......@@ -10,8 +10,8 @@ if [ "$#" -eq 1 ]; then
exp_name=$1
fi
sacrebleu=0
n_average=1
sacrebleu=1
n_average=10
beam_size=5
len_penalty=1.0
max_tokens=80000
......
......@@ -14,13 +14,11 @@ extra_parameter=
exp_tag=
#config_list=(base)
#config_list=(sate ctc)
#config_list=(ctc conformer rpr)
#config_list=(base sate)
#config_list=(base ctc)
#config_list=(pds_base_8 conformer)
#config_list=(sate ctc)
config_list=(sate_pds ctc)
#config_list=(pds_base_8)
#config_list=(pds_base_8 conformer)
# exp full name
exp_name=
......
arch: s2t_ctc
encoder-type: transformer
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
weight-decay: 1e-6
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv2d
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#inter_mixup: True
#inter_mixup_layer: -1
#inter_mixup_ratio: 0.2
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-temperature: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
target-interleaved-ctc-weight: 0.1
target-interleaved-ctc-layers: 2,4
sae-adapter: league
share-ctc-and-sae: False
sae-drop-prob: 0.2
interleaved-ctc-drop-prob: 0.2
sae-distribution-cutoff: 10
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
#! /bin/bash
# Processing LibriSpeech En-Fr Datasets
# Processing LibriSpeech En-Fr ST Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
......
......@@ -13,11 +13,11 @@ extra_parameter=
exp_tag=
#config_list=(base)
#config_list=(base conformer)
#config_list=(base ctc)
#config_list=(base ctc conformer)
#config_list=(pds_base_8)
config_list=(pds_base_8 conformer rpr)
#config_list=(pds_base_8 ctc)
config_list=(pds_base_8 conformer)
# exp full name
exp_name=
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
......@@ -8,6 +8,7 @@ best-checkpoint-metric: loss
maximize-best-checkpoint-metric: False
post-process: sentencepiece
validate-interval: 1
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
......
......@@ -5,7 +5,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
lr: 0.0014
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
......
......@@ -38,11 +38,11 @@ post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 16
encoder-layers: 12
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 5_5_5
pds-ratios: 2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 15
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
......@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_32
encoder-embed-dim: 256
pds-stages: 5
ctc-layer: 12
#ctc-layer: 12
pds-layers: 2_2_3_3_2
pds-ratios: 2_2_2_2_2
pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_1
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
......@@ -2,7 +2,7 @@ arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
......@@ -21,7 +21,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
lr: 0.0014
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
......
......@@ -2,7 +2,6 @@ arch: pdss2t_transformer_m_32
encoder-embed-dim: 512
pds-stages: 5
#pds-dropout: 0
pds-layers: 2_2_3_3_2
pds-ratios: 2_2_2_2_2
pds-fusion: True
......@@ -21,7 +20,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
lr: 0.0014
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
......
......@@ -20,7 +20,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
lr: 0.0014
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
......
arch: s2t_ctc
encoder-type: transformer
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
......@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
......@@ -41,4 +41,4 @@ activation-fn: relu
encoder-layers: 12
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
#load-pretrained-decoder-from:
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_9_3
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 160_192_224_256
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 320
pds-stages: 4
#ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_256_320
pds-ds-method: conv
pds-embed-norm: True
#pds-embed-norm: False
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
......@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
......@@ -48,4 +48,4 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
#load-pretrained-decoder-from:
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 5_3_3_5
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_224_224_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 16
encoder-attention-heads: 4
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_320
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
......@@ -3,7 +3,7 @@
gpu_num=1
data_dir=
test_subset=(dev-clean dev-other test-clean test-other)
test_subset=(dev-clean dev-other test-clean test-other all)
exp_name=
if [ "$#" -eq 1 ]; then
......@@ -13,7 +13,7 @@ fi
n_average=10
beam_size=5
len_penalty=1.0
max_tokens=80000
max_tokens=100000
dec_model=checkpoint_best.pt
cmd="./run.sh
......
......@@ -55,7 +55,7 @@ exp_tag=baseline
exp_name=
# config
train_config=ctc
train_config=base
data_config=config.yaml
# training setting
......@@ -190,32 +190,19 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--fp16"
fi
if [[ $step_valid -eq 1 ]]; then
validate_interval=1
save_interval=1
keep_last_epochs=10
no_epoch_checkpoints=0
save_interval_updates=500
keep_interval_updates=10
else
validate_interval=1
keep_last_epochs=10
fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd
--no-epoch-checkpoints"
fi
if [[ -n $validate_interval ]]; then
cmd="${cmd}
--validate-interval $validate_interval "
fi
if [[ -n $save_interval ]]; then
cmd="${cmd}
--save-interval $save_interval "
fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then
cmd="${cmd}
--save-interval-updates $save_interval_updates"
......@@ -271,10 +258,19 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -z ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ })
test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do
subset=${subset}
cmd="python ${code_dir}/fairseq_cli/generate.py
......@@ -293,6 +289,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi
done
cat ${result_file}
......
......@@ -12,14 +12,45 @@ extra_parameter=
#extra_parameter="${extra_parameter} "
exp_tag=
# Transformer
#config_list=(base)
#config_list=(pds_base_16)
#config_list=(pds_base_8)
# CTC
#config_list=(purectc_base)
#config_list=(purectc_pds_base_8)
#config_list=(purectc_pds_base_8_growth)
#config_list=(purectc_pds_base_8_growth_fusion256)
#config_list=(purectc_pds_base_16)
#config_list=(purectc_pds_base_16_growth)
#config_list=(purectc_pds_base_16_growth_fusion256)
#config_list=(purectc_pds_base_16_growth_fusion320)
# conformer
#config_list=(base conformer)
#config_list=(ConformerCTCSmall)
#config_list=(big conformer)
#config_list=(pds_base_4 conformer)
#config_list=(pds_base_16 conformer)
config_list=(pds_base_32 conformer)
#config_list=(pds_big_8 conformer)
#config_list=(pds_big_16 conformer)
#config_list=(pds_big_32 conformer)
#config_list=(pds_base_8_growth_fusion256 conformer)
# growth validation
#config_list=(pds_base_8_growth)
#config_list=(pds_base_8_growth_fusion256)
#config_list=(pds_base_16_growth_fusion256)
#config_list=(pds_base_16_growth)
config_list=(purectc_pds_base_16)
#config_list=(pds_base)
#config_list=(pds_big)
#config_list=(pds_deep)
# compare with Effective
#config_list=(purectc_base_compare)
#config_list=(purectc_pds_base_8_compare)
#config_list=(purectc_pds_base_8_compare2)
#config_list=(EffecientConformerCTCSmall)
#config_list=(purectc_pds_base_16)
# exp full name
exp_name=
......
......@@ -16,4 +16,5 @@ no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
\ No newline at end of file
skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
sae-drop-prob: 0.2
sae-distribution-cutoff: 10
share-ctc-and-sae: False
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
inter_mixup: True
inter_mixup_layer: 0
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
......@@ -240,13 +240,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ $step_valid -eq 1 ]]; then
validate_interval=1
save_interval=1
keep_last_epochs=10
no_epoch_checkpoints=0
save_interval_updates=500
keep_interval_updates=10
else
validate_interval=1
keep_last_epochs=10
fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd
......@@ -260,10 +256,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cmd="${cmd}
--save-interval $save_interval "
fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then
cmd="${cmd}
--save-interval-updates $save_interval_updates"
......@@ -282,11 +274,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
log=${model_dir}/train.log
cmd="nohup ${cmd} >> ${log} 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
tail -n "$(wc -l ${log} | awk '{print $1+1}')" -f ${log}
fi
fi
wait
......@@ -319,7 +312,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ -z ${cer} && ${cer} -eq 1 ]]; then
suffix=${suffix}_cer
else
suffix=${suffix}_wer
fi
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ }
......@@ -351,8 +353,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi
done
cat ${result_file}
fi
train-subset: train
valid-subset: dev
valid-subset: valid
max-epoch: 50
max-update: 100000
......
......@@ -351,10 +351,14 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ })
test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do
cmd="python ${code_dir}/fairseq_cli/generate.py
${data_dir}
......@@ -385,6 +389,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi
done
cat ${result_file}
......
......@@ -16,4 +16,5 @@ no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
\ No newline at end of file
skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
\ No newline at end of file
......@@ -45,6 +45,6 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
#load-pretrained-asr-encoder-from:
#load-pretrained-mt-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
ctc-weight: 0.2
intermedia-ctc-weight: 0.1
intermedia-ctc-layers: 6,9
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
target-interleaved-ctc-weight: 0.1
target-interleaved-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
sae-adapter: league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
share-target-ctc-and-sae: False
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
ctc-self-distill-weight: 0
\ No newline at end of file
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
......@@ -369,7 +369,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
if [[ ${sacrebleu} -eq 1 ]]; then
suffix=${suffix}_sacrebleu
else
suffix=${suffix}_multibleu
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ }
......@@ -402,6 +411,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
fi
done
cat ${result_file}
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
......@@ -14,7 +14,7 @@ sacrebleu=0
n_average=5
beam_size=4
len_penalty=0.6
max_tokens=80000
max_tokens=4000
dec_model=checkpoint_best.pt
cmd="./run.sh
......
......@@ -12,6 +12,7 @@ extra_parameter=
#extra_parameter="${extra_parameter} "
exp_tag=baseline
config_list=(base)
config_list=(deep)
# exp full name
......
......@@ -484,7 +484,7 @@ def main():
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
choices=["word", "bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=8000, type=int)
parser.add_argument("--share", action="store_true",
......
......@@ -296,7 +296,8 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
if arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides)
state = _upgrade_state_dict(state)
if len(state.keys()) != 1:
state = _upgrade_state_dict(state)
return state
......
......@@ -43,17 +43,17 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of CTC entropy"},
)
intermedia_ctc_weight: float = field(
interleaved_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss"},
metadata={"help": "weight of interleaved CTC loss"},
)
target_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of CTC loss for target sentence"},
)
target_intermedia_ctc_weight: float = field(
target_interleaved_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"},
metadata={"help": "weight of interleaved CTC loss for target sentence"},
)
ctc_self_distill_weight: float = field(
default=0.0,
......@@ -127,13 +127,13 @@ class CtcCriterion(FairseqCriterion):
self.sentence_avg = cfg.sentence_avg
self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight
self.target_intermedia_ctc_weight = cfg.target_intermedia_ctc_weight
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + \
self.target_ctc_weight + self.target_intermedia_ctc_weight + \
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0:
......@@ -188,6 +188,7 @@ class CtcCriterion(FairseqCriterion):
pad_mask = (tokens != self.pad_idx) & (
tokens != self.eos_idx
)
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
......@@ -222,19 +223,20 @@ class CtcCriterion(FairseqCriterion):
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
# cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logits" in net_output:
intermedia_ctc_num = len(net_output["intermedia_ctc_logits"])
interleaved_ctc_num = 0
interleaved_ctc_loss = 0
if "interleaved_ctc_logits" in net_output:
interleaved_ctc_num = len(net_output["interleaved_ctc_logits"])
# calculate the intermedia CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = net_output["intermedia_ctc_logits"][i]
# calculate the interleaved CTC loss
if self.interleaved_ctc_weight > 0 and interleaved_ctc_num > 0:
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
......@@ -249,19 +251,19 @@ class CtcCriterion(FairseqCriterion):
inter_lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
intermedia_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
interleaved_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
interleaved_ctc_loss /= interleaved_ctc_num
logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
target_ctc_loss = 0
target_intermedia_ctc_loss = 0
target_interleaved_ctc_loss = 0
# calculate the target CTC loss
if self.target_ctc_weight > 0 or self.target_intermedia_ctc_weight:
if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
......@@ -292,12 +294,12 @@ class CtcCriterion(FairseqCriterion):
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
target_intermedia_ctc_num = 0
if "target_intermedia_ctc_logits" in net_output:
target_intermedia_ctc_num = len(net_output["target_intermedia_ctc_logits"])
target_interleaved_ctc_num = 0
if "target_interleaved_ctc_logits" in net_output:
target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
for i in range(target_intermedia_ctc_num):
out = net_output["target_intermedia_ctc_logits"][i]
for i in range(target_interleaved_ctc_num):
out = net_output["target_interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
......@@ -312,17 +314,17 @@ class CtcCriterion(FairseqCriterion):
tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_intermedia_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
target_intermedia_ctc_loss /= target_intermedia_ctc_num
logging_output["target_intermedia_ctc_loss"] = utils.item(target_intermedia_ctc_loss.data)
target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_interleaved_ctc_loss"] = utils.item(target_interleaved_ctc_loss.data)
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = net_output["intermedia_ctc_logits"][i]
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
......@@ -347,9 +349,9 @@ class CtcCriterion(FairseqCriterion):
loss = \
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_intermedia_ctc_weight * target_intermedia_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
......@@ -359,8 +361,8 @@ class CtcCriterion(FairseqCriterion):
logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss)
if self.intermedia_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % intermedia_ctc_loss)
if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
......@@ -448,13 +450,13 @@ class CtcCriterion(FairseqCriterion):
sum(log.get("ctc_entropy", 0) for log in logging_outputs)
)
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
sum(log.get("interleaved_ctc_loss", 0) for log in logging_outputs)
)
target_ctc_loss_sum = utils.item(
sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
)
target_intermedia_ctc_loss_sum = utils.item(
sum(log.get("target_intermedia_ctc_loss", 0) for log in logging_outputs)
target_interleaved_ctc_loss_sum = utils.item(
sum(log.get("target_interleaved_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
......@@ -505,7 +507,7 @@ class CtcCriterion(FairseqCriterion):
)
if inter_ctc_loss_sum > 0:
metrics.log_scalar(
"intermedia_ctc_loss",
"interleaved_ctc_loss",
inter_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
......@@ -517,10 +519,10 @@ class CtcCriterion(FairseqCriterion):
sample_size,
round=3,
)
if target_intermedia_ctc_loss_sum > 0:
if target_interleaved_ctc_loss_sum > 0:
metrics.log_scalar(
"target_intermedia_ctc_loss",
target_intermedia_ctc_loss_sum / sample_size / math.log(2),
"target_interleaved_ctc_loss",
target_interleaved_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
......
......@@ -89,6 +89,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.ctc_weight) * loss + ctc_loss
# if hasattr(model.encoder, "get_loss"):
# encoder_loss = model.encoder.get_loss()
# if encoder_loss != 0:
# loss += encoder_loss * sample_size
# logging_output["encoder_loss"] = utils.item(encoder_loss.data)
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
......@@ -103,6 +109,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
enc_loss_sum = utils.item(
sum(log.get("encoder_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
......@@ -121,6 +130,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
if enc_loss_sum != 0:
metrics.log_scalar("enc_loss", enc_loss_sum, sample_size, round=3)
if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]:
CtcCriterion.reduce_metrics(logging_outputs)
......
......@@ -13,7 +13,7 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel
from .s2t_transformer import S2TTransformerModel
from fairseq.modules.speech_to_text import CTC, Adapter
from fairseq.modules import (
......@@ -141,334 +141,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# subsampling
parser.add_argument(
"--subsampling-type",
type=str,
help="subsampling type, like conv1d and conv2d",
)
parser.add_argument(
"--subsampling-layers",
type=int,
help="subsampling layers",
)
parser.add_argument(
"--subsampling-filter",
type=int,
help="subsampling filter",
)
parser.add_argument(
"--subsampling-kernel",
type=int,
help="subsampling kernel",
)
parser.add_argument(
"--subsampling-stride",
type=int,
help="subsampling stride",
)
parser.add_argument(
"--subsampling-norm",
type=str,
default="none",
help="subsampling normalization type",
)
parser.add_argument(
"--subsampling-activation",
type=str,
default="none",
help="subsampling activation function type",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"rel_pos_legacy",
"rel_pos",
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
# transfer
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
# reduced attention
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
S2TTransformerModel.add_args(parser)
PDSS2TTransformerModel.add_specific_args(parser)
# local modeling
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# pds setting
@staticmethod
def add_specific_args(parser):
"""Add specific arguments to the parser."""
# PDS setting
parser.add_argument(
"--pds-stages",
type=int,
......@@ -561,69 +240,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help="use the ctc after each stage",
)
# intermedia ctc
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
# mixup
parser.add_argument(
"--inter-mixup",
action="store_true",
help="use mixup or not",
)
parser.add_argument(
"--inter-mixup-layer",
default=None,
type=int,
help="the layers for mixup",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float,
help="the coefficient beta for mixup",
)
parser.add_argument(
"--inter-mixup-prob",
default=1,
type=float,
help="the probability for mixup",
)
parser.add_argument(
"--inter-mixup-ratio",
default=1,
type=float,
help="the ratio for mixup",
)
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
......@@ -707,7 +323,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
args.pds_ctc = getattr(args, "pds_ctc", None)
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None
inter_ctc_module = None
inter_adapter = None
sae_adapter = None
for i in range(self.pds_stages):
num_layers = self.pds_layers[i]
......@@ -833,7 +449,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
logger.error("Unsupported fusion transform!")
# intermedia modules for each stage
# interleaved modules for each stage
if use_ctc:
if inter_ctc_module is None:
ctc = CTC(embed_dim,
......@@ -847,7 +463,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
ctc = inter_ctc_module
if i != self.pds_stages - 1:
if inter_adapter is None:
if sae_adapter is None:
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
......@@ -877,10 +493,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
# self.use_ctc = "sate" in args.arch or \
# (getattr(args, "criterion", "") == "ctc") or \
# (("ctc" in getattr(args, "criterion", "")) and
# (getattr(args, "ctc_weight", False) > 0))
self.use_ctc = "sate" in args.arch or (getattr(args, "ctc_weight", 0) > 0)
if self.use_ctc:
# self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
......@@ -890,7 +502,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
logger.info("Interleaved CTC loss in layer %d" % self.ctc_layer)
# embed_dim = self.pds_embed_dims[-1]
embed_dim = self.embed_dim
......@@ -1105,6 +717,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
seq_len = x.size(0)
for state in prev_state:
i += 1
# padding = prev_padding[i]
# if padding is not None:
# zero_padding = padding.transpose(0, 1).unsqueeze(2)
# state.masked_fill_(zero_padding, 0.0)
fusion_downsampling = getattr(self, f"fusion_downsampling{i + 1}")
fusion_pre_layer_norm = getattr(self, f"fusion_pre_layer_norm{i + 1}")
fusion_post_layer_norm = getattr(self, f"fusion_post_layer_norm{i + 1}")
......@@ -1144,6 +762,22 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
"src_lengths": [],
}
def get_loss(self):
if not self.pds_fusion:
return 0
weight = self.fusion_weight
loss = 0
for i in range(self.fusion_stages_num - 1):
sub = weight[i] - weight[i + 1]
if sub > 0:
loss += sub
if weight[i] < 0:
loss += weight[i]
loss += (0.5 * (weight.sum() - 1.0) ** 2).mean()
return loss
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
......@@ -1191,6 +825,7 @@ def base_architecture(args):
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
......@@ -1211,6 +846,10 @@ def base_architecture(args):
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
......@@ -1219,6 +858,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
......@@ -1227,14 +867,41 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
args.embed_linear = getattr(args, "embed_linear", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
......@@ -1254,23 +921,10 @@ def base_architecture(args):
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 0.5)
def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
......
......@@ -12,6 +12,8 @@ from fairseq.models import (
register_model_architecture,
)
from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
from torch import Tensor
......@@ -27,465 +29,8 @@ class S2TCTCModel(FairseqEncoderModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# subsampling
parser.add_argument(
"--subsampling-type",
type=str,
help="subsampling type, like conv1d and conv2d",
)
parser.add_argument(
"--subsampling-layers",
type=int,
help="subsampling layers",
)
parser.add_argument(
"--subsampling-filter",
type=int,
help="subsampling filter",
)
parser.add_argument(
"--subsampling-kernel",
type=int,
help="subsampling kernel",
)
parser.add_argument(
"--subsampling-stride",
type=int,
help="subsampling stride",
)
parser.add_argument(
"--subsampling-norm",
type=str,
default="none",
help="subsampling normalization type",
)
parser.add_argument(
"--subsampling-activation",
type=str,
default="none",
help="subsampling activation function type",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"rel_pos",
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
# interleaved dropout
parser.add_argument('--interleave-dropout', type=int,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout',
action="store_true",
default=False,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-epoch',
type=int,
default=None,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# pds setting
parser.add_argument(
"--pds-stages",
type=int,
help="the number of the stage",
)
parser.add_argument(
"--pds-layers",
type=str,
help="the number of the encoder layers in each stage",
)
parser.add_argument(
"--pds-ratios",
type=str,
help="the ratio of the down-sampling in each stage",
)
parser.add_argument(
"--pds-ds-method",
type=str,
choices=["glu", "conv", "proj", "fusion"],
help="the down-sampling method",
)
parser.add_argument(
"--pds-embed-dims",
type=str,
help="the embedding dimension in each stage",
)
parser.add_argument(
"--pds-kernel-sizes",
type=str,
help="the kernel size of the down-sampling module in each stage",
)
parser.add_argument(
"--pds-embed-norm",
action="store_true",
help="use layer norm in the down-sampling module",
)
parser.add_argument(
"--pds-position-embed",
type=str,
help="use the position embedding or not before each encoding",
)
parser.add_argument(
"--pds-attn-heads",
type=str,
help="the number of the attention heads in each stage",
)
parser.add_argument(
"--pds-attn-ds-ratios",
type=str,
help="the ratio of the down-sampling in the self attention module",
)
parser.add_argument(
"--pds-ffn-ratios",
type=str,
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-conv-strides",
type=str,
help="the strides of the convolutional module (conformer) in each stage",
)
parser.add_argument(
"--pds-attn-strides",
type=str,
help="the strides of the attention module (conformer) in each stage",
)
parser.add_argument(
"--pds-fusion",
action="store_true",
help="use the representation fusion method",
)
parser.add_argument(
"--pds-fusion-method",
type=str,
help="the fusion method",
)
parser.add_argument(
"--pds-dropout",
type=float,
help="dropout in each stage",
)
parser.add_argument(
"--pds-ctc",
type=str,
help="use the ctc after each stage",
)
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
S2TTransformerModel.add_args(parser)
PDSS2TTransformerModel.add_specific_args(parser)
# encoder
parser.add_argument(
......@@ -497,7 +42,7 @@ class S2TCTCModel(FairseqEncoderModel):
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
def build_encoder(cls, args, task=None):
encoder = S2TCTCEncoder(args, task)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
......@@ -561,10 +106,8 @@ class S2TCTCEncoder(FairseqEncoder):
setattr(args, "ctc_weight", 1.0)
encoder_type = getattr(args, "encoder_type", "transformer")
if encoder_type == "transformer":
from .s2t_transformer import S2TTransformerEncoder
self.encoder = S2TTransformerEncoder(args, task)
elif encoder_type == "pds":
from .pdss2t_transformer import PDSS2TTransformerEncoder
self.encoder = PDSS2TTransformerEncoder(args, task)
else:
logger.error("Unsupported architecture: %s." % encoder_type)
......@@ -701,9 +244,11 @@ def base_architecture(args):
args.ctc_layer = getattr(args, "ctc_layer", 0)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
......@@ -724,11 +269,23 @@ def base_architecture(args):
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout
args.interleave_dropout = getattr(args, "interleave_dropout", None)
args.cl_dropout = getattr(args, "cl_dropout", False)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
......@@ -737,26 +294,22 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
@register_model_architecture("s2t_ctc", "s2t_ctc_s")
......
......@@ -76,10 +76,10 @@ class S2TSATEModel(S2TTransformerModel):
help="share the projection weights of the ctc and adapter",
)
parser.add_argument(
"--temperature",
"--adapter-temperature",
default=1.0,
type=float,
help="temperature of the CTC softmax",
help="temperature of the CTC softmax in adapter",
)
parser.add_argument(
"--acoustic-encoder",
......@@ -103,14 +103,19 @@ class S2TSATEModel(S2TTransformerModel):
parser.add_argument(
"--target-ctc-layer",
default=None,
type=str,
type=int,
help="ctc layer for target sentence",
)
parser.add_argument(
"--target-intermedia-ctc-layers",
"--target-interleaved-ctc-layers",
default=None,
type=str,
help="intermedia ctc layers for target sentence",
help="interleaved ctc layers for target sentence",
)
parser.add_argument(
"--share-target-ctc-and-sae",
action="store_true",
help="share the weight of target ctc and sae",
)
# freeze
parser.add_argument(
......@@ -225,38 +230,42 @@ class TextEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = []
self.target_intermedia_ctc_layers = getattr(args, "target_intermedia_ctc_layers", None)
if self.target_intermedia_ctc_layers is not None:
target_intermedia_ctc_layers = self.target_intermedia_ctc_layers.split(",")
for layer_idx in target_intermedia_ctc_layers:
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
if self.target_interleaved_ctc_layers is not None:
target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
for layer_idx in target_interleaved_ctc_layers:
layer_idx = int(layer_idx)
assert layer_idx <= layer_num, (layer_idx, layer_num)
if layer_idx <= 0:
layer_idx += layer_num
self.intermedia_ctc_layers.append(layer_idx)
self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia target CTC loss in layer %d" % layer_idx)
logger.info("Interleaved target CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary),
dropout=args.dropout)
if embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
len(dictionary),
# embed_tokens=embed_tokens,
strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
strategy = {
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
len(dictionary),
strategy=strategy)
if args.share_target_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None):
......@@ -266,7 +275,7 @@ class TextEncoder(FairseqEncoder):
x = self.dropout_module(x)
target_ctc_logit = None
target_intermedia_ctc_logits = []
target_interleaved_ctc_logits = []
layer_idx = 0
for layer in self.layers:
if history is not None:
......@@ -277,18 +286,18 @@ class TextEncoder(FairseqEncoder):
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone())
if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
if p < self.interleaved_ctc_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
target_intermedia_ctc_logits.append(logit)
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
if history is not None:
history.push(x)
......@@ -302,7 +311,7 @@ class TextEncoder(FairseqEncoder):
if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x)
return x, target_ctc_logit, target_intermedia_ctc_logits
return x, target_ctc_logit, target_interleaved_ctc_logits
class S2TSATEEncoder(FairseqEncoder):
......@@ -322,13 +331,12 @@ class S2TSATEEncoder(FairseqEncoder):
logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
# adapter
self.temperature = args.temperature
strategy = None
if args.adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
elif args.adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter_temperature = args.adapter_temperature
strategy = {
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
......@@ -341,8 +349,7 @@ class S2TSATEEncoder(FairseqEncoder):
acoustic_encoder_attention_type = args.encoder_attention_type
args.encoder_attention_type = args.text_attention_type
# text encoder
# textual encoder
self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens)
args.encoder_attention_type = acoustic_encoder_attention_type
......@@ -369,10 +376,14 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_padding_mask = encoder_padding_mask
if "mixup" in encoder_out:
mixup = encoder_out["mixup"]
else:
mixup = None
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.temperature, dim=-1, dtype=torch.float32)
ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
else:
ctc_logit = None
ctc_prob = None
......@@ -392,18 +403,20 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
else:
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C
"ctc_logit": [ctc_logit], # T x B x C
"interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []), # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C
"target_intermedia_ctc_logits": target_intermedia_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"target_interleaved_ctc_logits": target_interleaved_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
......@@ -458,7 +471,7 @@ def base_architecture(args):
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# transformer
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
......@@ -480,6 +493,10 @@ def base_architecture(args):
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
......@@ -497,23 +514,58 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
args.embed_linear = getattr(args, "embed_linear", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.decoder_learnable = getattr(args, 'decoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# SATE
args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
args.adapter = getattr(args, "adapter", "league")
args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg")
args.temperature = getattr(args, "temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
# local modeling
args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_target_ctc_and_sae = getattr(args, "share_target_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
......@@ -539,10 +591,14 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# SATE
args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
args.adapter = getattr(args, "adapter", "league")
args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg")
args.adapter_temperature = getattr(args, "adapter_temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
@register_model_architecture("s2t_sate", "s2t_sate_s")
......
......@@ -2,6 +2,7 @@ import logging
import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -139,9 +140,35 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
"rel_pos",
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
# transfer
parser.add_argument(
"--relative-pos-enc",
action="store_true",
help="use relative position encoding for attention",
)
parser.add_argument(
"--linear-att",
action="store_true",
help="use linear attention",
)
# reduced attention
parser.add_argument(
"--attention-reduced-method",
type=str,
default="conv",
help="reduction method for attention",
)
parser.add_argument(
"--attention-reduced-q",
action="store_true",
help="use reduction for query or not"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
......@@ -286,6 +313,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=int,
help="the position of the ctc loss",
)
parser.add_argument(
"--share-ctc-and-embed",
action="store_true",
help="share the weight of ctc and embedding",
)
# local modeling
parser.add_argument(
......@@ -349,63 +381,97 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-norm",
default="batch_norm",
type=str,
help="normalization type of cnn module",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
"--embed-linear",
action="store_true",
help="Simultaneous speech translation or not",
)
# interleaved dropout
parser.add_argument('--interleave-dropout', type=int,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout',
action="store_true",
default=False,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-epoch',
type=int,
default=None,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
help="use linear transform after down-sampling",
)
# interleaved CTC layers
parser.add_argument(
"--interleaved-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
help="the position of interleaved ctc layers, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
"--interleaved-ctc-temperature",
default=1,
type=float,
help="temperature of the CTC probability in sae",
)
parser.add_argument(
"--interleaved-ctc-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
# Semantics-augmented Encoding (SAE)
parser.add_argument(
"--sae-adapter",
default="none",
type=str,
help="type of intermedia adapter",
help="adapter type of sae ",
)
parser.add_argument(
"--sae-drop-prob",
default=0,
type=float,
help="dropping one input in sae with a probability",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
"--sae-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
"--share-ctc-and-sae",
action="store_true",
help="share the weight of ctc and sae",
)
# Mixup
parser.add_argument(
"--inter-mixup",
action="store_true",
help="use mixup or not",
)
parser.add_argument(
"--inter-mixup-layer",
default=None,
type=int,
help="the layers to apply mixup",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float,
help="probability of dropping the followed layers",
help="the coefficient beta of mixup",
)
parser.add_argument(
"--inter-mixup-prob",
default=1,
type=float,
help="the probability of mixup",
)
parser.add_argument(
"--intermedia-temperature",
"--inter-mixup-ratio",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
help="the ratio of mixup",
)
pass
......@@ -513,10 +579,11 @@ class S2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1
self.subsample = subsampling(args)
# self.linear = nn.Linear(dim, dim)
self.embed_linear = getattr(args, "embed_linear", False)
if self.embed_linear:
self.linear = nn.Linear(dim, dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos":
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
......@@ -546,9 +613,6 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
self.history = None
# self.use_ctc = "sate" in args.arch or \
# (getattr(args, "criterion", "") == "ctc") or \
# (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
......@@ -560,47 +624,63 @@ class S2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary and \
if getattr(args, "share_ctc_and_embed", False) and \
task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_dropout = getattr(args, "interleave_dropout", None)
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
# self.gather_cos_sim = True
self.dis = 2
self.cos_sim = dict()
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
for layer_idx in interleaved_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
logger.info("Interleaved CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
if getattr(args, "share_ctc_and_embed", False) and \
task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter,
len(task.source_dictionary), strategy=strategy)
# embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
strategy = {
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae_adapter = Adapter(dim, args.sae_adapter,
len(task.source_dictionary),
strategy=strategy,
)
if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
# mixup
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_layer = int(args.inter_mixup_layer)
self.mixup_prob = float(args.inter_mixup_prob)
self.mixup_ratio = float(args.inter_mixup_ratio)
beta = float(args.inter_mixup_beta)
from torch.distributions import Beta
self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f." % (
self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio))
# gather cosine similarity
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.dis = 2
self.cos_sim = dict()
@staticmethod
def pooling_ratio():
......@@ -624,21 +704,67 @@ class S2TTransformerEncoder(FairseqEncoder):
self.cos_sim[idx] = []
self.cos_sim[idx].append(float(sim))
def apply_mixup(self, x, encoder_padding_mask):
batch = x.size(1)
indices = np.random.permutation(batch)
if self.mixup_ratio == 1:
if len(indices) % 2 != 0:
indices = np.append(indices, (indices[-1]))
idx1 = indices[0::2]
idx2 = indices[1::2]
else:
mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
mix_indices = indices[: mix_size]
idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
idx1 = torch.from_numpy(idx1).to(x.device)
idx2 = torch.from_numpy(idx2).to(x.device)
x1 = x[:, idx1]
x2 = x[:, idx2]
coef = self.beta.sample().to(x.device).type_as(x)
x = (coef * x1 + (1 - coef) * x2)
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 + pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
"coef": coef,
"index1": idx1,
"index2": idx2,
}
return x, encoder_padding_mask, input_lengths, mixup
def forward(self, src_tokens, src_lengths):
layer_idx = -1
mixup = None
if self.history is not None:
self.history.clean()
# (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1)
input_lengths = src_lengths
# gather cosine similarity
cos_sim_idx = -1
dis = self.dis
if self.gather_cos_sim:
self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
self.add_to_dict(x, dis, cos_sim_idx)
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
# down-sampling
# (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1)
x, input_lengths = self.subsample(x, src_lengths)
x, input_lengths = self.subsample(x, input_lengths)
# embedding scaling
x = self.embed_scale * x
......@@ -657,7 +783,8 @@ class S2TTransformerEncoder(FairseqEncoder):
x += positions
positions = None
# x = self.linear(x)
if self.embed_linear:
x = self.linear(x)
x = self.dropout_module(x)
# add emb into history
......@@ -670,43 +797,46 @@ class S2TTransformerEncoder(FairseqEncoder):
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
layer_idx = 0
layer_idx += 1
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
layer_idx += 1
interleaved_ctc_logits = []
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
if layer_idx != len(self.layers) \
and self.interleaved_dropout is not None \
and layer_idx % self.interleaved_dropout == 0:
x = self.dropout_module(x)
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
# interleave CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
# interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
if p < self.interleaved_ctc_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
# gather cosine similarity
if self.gather_cos_sim:
......@@ -728,8 +858,9 @@ class S2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
"interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
......@@ -872,14 +1003,18 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.embed_linear = getattr(args, "embed_linear", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
......@@ -902,16 +1037,23 @@ def base_architecture(args):
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout
args.interleave_dropout = getattr(args, "interleave_dropout", None)
args.cl_dropout = getattr(args, "cl_dropout", False)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
......
......@@ -286,9 +286,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
help="freeze the module of the decoder",
)
parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
help='interleaved dropout probability')
parser.add_argument(
"--squeeze-excitation",
default=False,
......
......@@ -286,50 +286,61 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="freeze the module of the decoder",
)
parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
help='interleaved dropout probability')
parser.add_argument(
"--squeeze-excitation",
default=False,
action='store_true',
help="use squeeze and excitation method",
)
# CTC
parser.add_argument(
"--ctc-layer",
type=int,
help="ctc layers for target sentence",
)
# interleaved CTC layers
parser.add_argument(
"--intermedia-ctc-layers",
"--interleaved-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
help="the position of interleaved ctc layers, separated by comma",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
"--interleaved-ctc-upsampling-ratio",
default=2,
type=int,
help="upsampling ratio of the representation for CTC calculation",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
"--interleaved-ctc-temperature",
default=1,
type=float,
help="temperature of the CTC probability in sae",
)
parser.add_argument(
"--intermedia-drop-prob",
"--interleaved-ctc-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
# Semantics-augmented Encoding (SAE)
parser.add_argument(
"--intermedia-temperature",
default=1,
"--sae-adapter",
default="none",
type=str,
help="adapter type of sae ",
)
parser.add_argument(
"--sae-drop-prob",
default=0,
type=float,
help="temperature of the intermedia ctc probability",
help="dropping one input in sae with a probability",
)
parser.add_argument(
"--sae-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--share-ctc-and-sae",
action="store_true",
help="share the weight of ctc and sae",
)
# fmt: on
......@@ -574,6 +585,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
assert decoder_embed_tokens is not None
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
......@@ -583,35 +595,41 @@ class TransformerCTCEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_upsampling_ratio = args.interleaved_ctc_upsampling_ratio
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
for layer_idx in interleaved_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
self.interleaved_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
logger.info("Interleaved CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
strategy = {
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
decoder_embed_tokens.num_embeddings, embed_tokens=decoder_embed_tokens, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
decoder_embed_tokens.num_embeddings,
strategy=strategy
)
if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -672,12 +690,13 @@ class TransformerCTCEncoder(FairseqEncoder):
return_all_hiddens,
token_embeddings)
def upsample(self, x, ratio=2):
def upsampling(self, x):
ratio = self.interleaved_ctc_upsampling_ratio
if ratio <= 1:
return x
seq_len, bsz, dim = x.size()
x = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)
x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
return x
# TorchScript doesn't support super() method so that the scriptable Subclass
......@@ -742,7 +761,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# encoder layers
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
interleaved_ctc_logits = []
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
......@@ -757,24 +776,28 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(self.upsample(x.clone()))
ctc_logit = self.ctc(self.upsampling(x.clone()))
# Intermedia CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
if layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
if p < self.interleaved_ctc_drop_prob:
break
norm_x = self.layer_norm(x)
up_x = self.upsample(norm_x)
up_x = self.upsampling(norm_x)
up_logit = self.ctc(up_x)
intermedia_ctc_logits.append(up_logit)
up_prob = utils.softmax(up_logit / self.intermedia_temperature, dim=-1)
interleaved_ctc_logits.append(up_logit)
up_prob = utils.softmax(up_logit / self.interleaved_ctc_temperature, dim=-1)
up_prob = up_prob.permute(1, 2, 0)
prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2)
prob = nn.functional.max_pool1d(up_prob,
kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
prob = prob.permute(2, 0, 1)
x, _ = self.adapter([x, prob])
if self.history is not None:
......@@ -787,12 +810,13 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(self.upsample(x))
ctc_logit = self.ctc(self.upsampling(x))
ctc_padding_mask = encoder_padding_mask
if ctc_logit is not None or len(intermedia_ctc_logits) != 0:
if ctc_logit is not None or len(interleaved_ctc_logits) != 0:
bsz = encoder_padding_mask.size(0)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).\
expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
......@@ -802,7 +826,7 @@ class TransformerCTCEncoder(FairseqEncoder):
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"ctc_padding_mask": [ctc_padding_mask],
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
......@@ -1457,9 +1481,17 @@ def base_architecture(args):
# CTC
args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
@register_model_architecture("transformer_ctc", "transformer_ctc_relative")
......
......@@ -2,21 +2,23 @@ import torch
from torch import nn
from fairseq.modules.activations import get_activation_class
from fairseq.modules.layer_norm import LayerNorm
class ConvolutionModule(nn.Module):
"""Convolution block used in the conformer block"""
def __init__(
self,
embed_dim,
expand_embed_dim,
depthwise_kernel_size,
dropout,
activation_fn="swish",
bias=False,
stride=1,
export=False,
self,
embed_dim,
expand_embed_dim,
depthwise_kernel_size,
dropout,
activation_fn="swish",
bias=False,
stride=1,
export=False,
norm_type="batch_norm"
):
"""
Args:
......@@ -30,8 +32,8 @@ class ConvolutionModule(nn.Module):
"""
super(ConvolutionModule, self).__init__()
assert (
depthwise_kernel_size - 1
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
depthwise_kernel_size - 1
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
self.pointwise_conv1 = torch.nn.Conv1d(
embed_dim,
2 * expand_embed_dim,
......@@ -50,7 +52,13 @@ class ConvolutionModule(nn.Module):
groups=expand_embed_dim,
bias=bias,
)
self.batch_norm = nn.BatchNorm1d(expand_embed_dim)
self.norm_type = norm_type
if norm_type == "batch_norm":
self.norm = nn.BatchNorm1d(expand_embed_dim)
elif norm_type == "layer_norm":
self.norm = LayerNorm(expand_embed_dim)
else:
assert False, "Unsupported normalization type in convolution module"
self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d(
expand_embed_dim,
......@@ -62,7 +70,7 @@ class ConvolutionModule(nn.Module):
)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, x):
def forward(self, x, mask_pad=None):
"""
Args:
x: Input of shape B X T X C
......@@ -72,23 +80,36 @@ class ConvolutionModule(nn.Module):
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# zero_mask_pad = mask_pad.unsqueeze(1)
# # mask batch padding
# if mask_pad is not None:
# x.masked_fill_(zero_mask_pad, 0.0)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*expand_embed_dim, dim)
x = self.glu(x) # (batch, expand_embed_dim, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.batch_norm(x)
if self.norm_type == "layer_norm":
x = x.transpose(1, 2)
x = self.norm(x)
x = self.activation(x)
if self.norm_type == "layer_norm":
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# # mask batch padding
# if zero_mask_pad is not None:
# x.masked_fill_(zero_mask_pad, 0.0)
x = x.transpose(1, 2)
x = self.dropout(x)
return x
# class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model."""
# def __init__(self,
......
......@@ -332,7 +332,7 @@ class PDSTransformerEncoderLayer(nn.Module):
if self.normalize_before:
x = self.conv_norm(x)
x = self.conv_module(x)
x = self.conv_module(x, encoder_padding_mask)
x = x.transpose(0, 1)
x = self.conv_res(residual) + x
......
......@@ -122,7 +122,9 @@ class S2TTransformerEncoderLayer(nn.Module):
self.embed_dim,
depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish'))
activation_fn=getattr(args, 'activation_fn', 'swish'),
norm_type=args.cnn_module_norm
)
self.final_norm = LayerNorm(embed_dim)
else:
self.conv_norm = None
......
......@@ -3,6 +3,7 @@ import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import groupby
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.modules import LayerNorm
......@@ -61,9 +62,12 @@ class Adapter(nn.Module):
super().__init__()
dim = dim
self.adapter_type = adapter_type
self.cal_linear = False
self.cal_context = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.cal_linear = True
self.linear_adapter = nn.Sequential(
nn.Linear(dim, dim),
LayerNorm(dim),
......@@ -71,14 +75,10 @@ class Adapter(nn.Module):
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight
# if embed_tokens is None:
# num_embeddings = len(dictionary)
# self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
# else:
# self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * dim, dim)
......@@ -86,14 +86,22 @@ class Adapter(nn.Module):
self.gate_linear1 = nn.Linear(dim, dim)
self.gate_linear2 = nn.Linear(dim, dim)
# additional strategy
if self.adapter_type == "shrink":
assert strategy is not None
self.ctc_compress = getattr(CTCCompressStrategy, strategy)
logger.info("CTC Compress Strategy: %s" % strategy)
elif self.adapter_type == "league":
self.distribution_cutoff = strategy
ctc_compress_strategy = getattr(strategy, "ctc_compress_strategy", "avg")
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if "league" in self.adapter_type:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
if self.distribution_cutoff is not None:
logger.info("Distribution cutoff: %d" % int(strategy))
self.distribution_cutoff = int(self.distribution_cutoff)
logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
def forward(self, x, padding=None):
......@@ -103,14 +111,11 @@ class Adapter(nn.Module):
org_distribution = distribution
distribution = distribution.contiguous().view(-1, distribution.size(-1))
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
elif self.adapter_type == "league":
linear_out = None
soft_out = None
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
......@@ -120,24 +125,33 @@ class Adapter(nn.Module):
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
if self.adapter_type == "linear":
out = linear_out
elif self.adapter_type == "context":
out = soft_out
elif self.adapter_type == "league":
if self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5:
out = linear_out
else:
out = soft_out
else:
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = representation + soft_out
elif self.adapter_type == "none":
out = representation
elif self.adapter_type == "shrink":
from itertools import groupby
lengths = (~padding).long().sum(-1)
with torch.no_grad():
......
......@@ -13,16 +13,14 @@ logger = logging.getLogger(__name__)
class CTC(nn.Module):
def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
super(CTC, self).__init__()
self.embed_dim = embed_dim
self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)
self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
nn.init.normal_(
self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5
)
# nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
......@@ -46,4 +44,3 @@ class CTC(nn.Module):
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
......@@ -191,7 +191,8 @@ class Conv2dSubsampling(nn.Module):
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2),
# padding=(kernel_size - 1) // 2
),
get_norm(norm,
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
transpose=True if norm == "layer" else False),
......@@ -214,6 +215,8 @@ class Conv2dSubsampling(nn.Module):
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
assert subsampled_length == max(x_len), "The lengths are mismatched."
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
x = self.linear(x)
......
......@@ -156,22 +156,22 @@ def main(cfg: FairseqConfig) -> None:
)
break
if getattr(cfg.model, "cl_dropout", False):
cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
dropout = getattr(cfg.model, "dropout", False)
assert cl_dropout_epoch > 0
curr_epoch = epoch_itr.epoch
if curr_epoch <= cl_dropout_epoch:
if curr_epoch == cl_dropout_epoch:
curr_dropout = dropout
else:
curr_dropout = curr_epoch / cl_dropout_epoch * dropout
logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))
for name, module in trainer.model.named_modules():
from fairseq.modules.fairseq_dropout import FairseqDropout
if isinstance(module, FairseqDropout):
module.p = curr_dropout
# if getattr(cfg.model, "cl_dropout", False):
# cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
# cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
# dropout = getattr(cfg.model, "dropout", False)
# assert cl_dropout_epoch > 0
# curr_epoch = epoch_itr.epoch
# if curr_epoch <= cl_dropout_epoch:
# if curr_epoch == cl_dropout_epoch:
# curr_dropout = dropout
# else:
# curr_dropout = curr_epoch / cl_dropout_epoch * dropout
# logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))
# for name, module in trainer.model.named_modules():
# from fairseq.modules.fairseq_dropout import FairseqDropout
# if isinstance(module, FairseqDropout):
# module.p = curr_dropout
# train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论