Commit d220c040 by xuchen

Big update!

I optimized the implementation of the speech-to-text tasks. As always, I updated the shell scripts and YAML configuration files for easy training.
There may still be some bugs, so a follow-up update is coming!
parent 99763132
@@ -40,5 +40,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
@@ -29,7 +29,7 @@ label_smoothing: 0.1
 conv-kernel-sizes: 5,5
 conv-channels: 1024
-dropout: 0.1
+dropout: 0.15
 activation-fn: relu
 encoder-embed-dim: 512
 encoder-ffn-embed-dim: 2048
@@ -39,6 +39,4 @@ encoder-attention-heads: 8
 decoder-embed-dim: 512
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 8
-attention-dropout: 0.1
-activation-dropout: 0.1
\ No newline at end of file
-arch: s2t_conformer_s
 macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
+#arch: pdss2t_transformer_s
+#arch: s2t_transformer_s
+arch: s2t_sate
+encoder-embed-dim: 256
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 2_2_6_2
+pyramid-ratios: 2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8
+pyramid-attn-heads: 4_4_4_4
+cl-dropout: True
+cl-dropout-epoch: 50
 train-subset: train-clean-100
 valid-subset: dev-clean
@@ -5,7 +26,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 10
+patience: 20
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -14,7 +35,6 @@ report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_transformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -28,11 +48,9 @@ criterion: label_smoothed_cross_entropy_with_ctc
 ctc-weight: 0.3
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
 conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
......
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+train-subset: train-clean-100
+valid-subset: dev-clean
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -15,6 +15,14 @@ report-accuracy: True
 #load-pretrained-acoustic-encoder-from:
 #load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1007_st_ctc_baseline/avg_10_checkpoint.pt
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1111_st_ctc_conformer_lr0.001/avg_10_checkpoint.pt
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1007_st_pyramid4_all256_3333_sr8_ctc/avg_10_checkpoint.pt
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1114_st_pyramid4_all256_ctc_fix/avg_10_checkpoint.pt
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1015_st_pyramid4_all256_conformer_baseline/avg_10_checkpoint.pt
+#load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1111_st_pyramid4_all256_conformer_ctc/avg_10_checkpoint.pt
 arch: s2t_sate
 share-decoder-input-output-embed: True
@@ -24,33 +32,37 @@ lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
 lr: 2e-3
-#adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy
-label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: transformer
+#macaron-style: True
+#use-cnn-module: True
+#cnn-module-kernel: 31
+#acoustic-encoder: transformer
+#acoustic-encoder: conformer
+acoustic-encoder: pyramid
 adapter: league
+#adapter: none
+#adapter: context
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+encoder-embed-dim: 256
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 3_3_3_3
+pyramid-sr-ratios: 2_2_1_2
+pyramid-embed-dims: 256_256_256_256
+pyramid-fuse: True
+pyramid-reduced-embed: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8
+pyramid-heads: 4_4_4_4
\ No newline at end of file
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_8
-encoder-embed-dim: 256
-pyramid-stages: 3
-pyramid-layers: 3_6_3
-pyramid-fuse-way: all_conv
-pyramid-fuse: True
-pyramid-sr-ratios: 2_2_2
-pyramid-embed-dims: 256_256_256
-pyramid-reduced-embed: conv
-pyramid-embed-norm: True
-pyramid-position-embed: 1_1_1
-pyramid-kernel-sizes: 5_5_5
-pyramid-ffn-ratios: 8_8_8
-pyramid-heads: 4_4_4
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -20,7 +7,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -41,7 +28,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -52,5 +38,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_16
 encoder-embed-dim: 256
 pyramid-stages: 4
 #pyramid-dropout: 0
 pyramid-layers: 2_2_6_2
-pyramid-sr-ratios: 2_2_2_2
+pyramid-ratios: 2_2_2_2
-pyramid-fuse: True
+pyramid-fusion: True
-pyramid-fuse-way: all_conv
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
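Across these fragments the pyramid options are renamed (pyramid-sr-ratios to pyramid-ratios, pyramid-fuse to pyramid-fusion, pyramid-fuse-way to pyramid-fusion-method, pyramid-reduced-embed to pyramid-ds-method, pyramid-heads to pyramid-attn-heads) and the arch moves from pys2t_transformer_s to the pdss2t_transformer_* presets. Below is a sketch of a migration helper for configs that still use the old names; the script itself is hypothetical, only the name pairs come from this diff, and the target arch must be chosen per config:

    #!/usr/bin/env bash
    # Hypothetical helper: rewrite old-style pyramid keys in-place.
    # Usage: ./migrate_pds_conf.sh path/to/conf.yaml pdss2t_transformer_s_8
    conf=$1
    new_arch=${2:-pdss2t_transformer_s_8}   # downsampling preset: s_8 / s_16 / s_32 ...
    sed -i \
        -e 's/^pyramid-sr-ratios:/pyramid-ratios:/' \
        -e 's/^pyramid-fuse-way:/pyramid-fusion-method:/' \
        -e 's/^pyramid-fuse:/pyramid-fusion:/' \
        -e 's/^pyramid-reduced-embed:/pyramid-ds-method:/' \
        -e 's/^pyramid-heads:/pyramid-attn-heads:/' \
        -e "s/^arch: pys2t_transformer_s$/arch: ${new_arch}/" \
        "${conf}"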
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -21,7 +22,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -42,7 +43,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -53,5 +53,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
-max-update: 100000
+arch: pdss2t_transformer_s_32
+encoder-embed-dim: 256
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 2_2_3_3_2
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8_8
+pyramid-attn-heads: 4_4_4_4_4
+train-subset: train-clean-100,train-clean-360,train-other-500
+valid-subset: dev-clean
+max-epoch: 100
+max-update: 300000
 num-workers: 8
 patience: 10
@@ -12,11 +29,8 @@ seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,31 +40,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy
+criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: transformer
-adapter: league
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pyramid-stages: 4
+#pyramid-dropout: 0
 pyramid-layers: 3_3_3_3
-pyramid-sr-ratios: 2_2_1_2
+pyramid-ratios: 2_2_1_2
-pyramid-fuse: True
+pyramid-fusion: True
-pyramid-fuse-way: all_conv
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -20,7 +22,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -41,7 +43,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -52,5 +53,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
-max-update: 100000
+arch: pdss2t_transformer_m_8
+#arch: pdss2t_transformer_m_16
+#arch: pdss2t_transformer_m_32
+train-subset: train-clean-100,train-clean-360,train-other-500
+valid-subset: dev-clean
+max-epoch: 100
+max-update: 300000
 num-workers: 8
 patience: 10
@@ -12,11 +16,8 @@ seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,32 +27,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: transformer
-adapter: league
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 512
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 8
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_m_16
 encoder-embed-dim: 512
 pyramid-stages: 4
+#pyramid-dropout: 0
 pyramid-layers: 2_2_6_2
-#pyramid-layers: 3_3_3_3
-pyramid-sr-ratios: 2_2_2_2
-pyramid-fuse: True
-pyramid-fuse-way: all_conv
+pyramid-ratios: 2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 512_512_512_512
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
-pyramid-ffn-ratios: 8_8_8_8
+pyramid-ffn-ratios: 4_4_4_4
-pyramid-heads: 8_8_8_8
+pyramid-attn-heads: 8_8_8_8
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -21,7 +22,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -42,16 +43,13 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
-encoder-attention-heads: 4
+encoder-attention-heads: 8
-decoder-embed-dim: 256
+decoder-embed-dim: 512
 decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
+decoder-attention-heads: 8
-attention-dropout: 0.1
-activation-dropout: 0.1
-train-subset: train
-valid-subset: valid
-max-epoch: 50
-max-update: 100000
+arch: pdss2t_transformer_m_32
+encoder-embed-dim: 512
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 2_2_3_3_2
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 512_512_512_512_512
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 4_4_4_4_4
+pyramid-attn-heads: 8_8_8_8_8
+train-subset: train-clean-100,train-clean-360,train-other-500
+valid-subset: dev-clean
+max-epoch: 100
+max-update: 300000
 num-workers: 8
 patience: 10
@@ -10,40 +27,29 @@ no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
-skip-invalid-size-inputs-valid-test: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: dlcl_transformer
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
-warmup-updates: 8000
+warmup-updates: 10000
-lr: 1e-3
+lr: 2e-3
-adam_betas: (0.9,0.997)
+#adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy
+criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 dropout: 0.1
-attention-dropout: 0.1
-activation-dropout: 0.1
 activation-fn: relu
-encoder-normalize-before: True
-decoder-normalize-before: True
-encoder-embed-dim: 512
 encoder-ffn-embed-dim: 2048
-encoder-layers: 6
+encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 8
 decoder-embed-dim: 512
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 8
-use-enc-dlcl: True
-use-dec-dlcl: True
\ No newline at end of file
-train-subset: train
-valid-subset: valid
-max-epoch: 50
-max-update: 100000
+arch: pdss2t_transformer_m_8
+encoder-embed-dim: 512
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 3_3_3_3
+pyramid-ratios: 2_2_1_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 512_512_512_512
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 4_4_4_4
+pyramid-attn-heads: 8_8_8_8
+train-subset: train-clean-100,train-clean-360,train-other-500
+valid-subset: dev-clean
+max-epoch: 100
+max-update: 300000
 num-workers: 8
 patience: 10
@@ -10,42 +27,29 @@ no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
-skip-invalid-size-inputs-valid-test: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: transformer
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
-warmup-updates: 8000
+warmup-updates: 10000
-lr: 1e-3
+lr: 2e-3
-adam_betas: (0.9,0.997)
+#adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy
+criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 dropout: 0.1
-attention-dropout: 0.1
-activation-dropout: 0.1
 activation-fn: relu
-encoder-normalize-before: True
-decoder-normalize-before: True
-encoder-embed-dim: 512
 encoder-ffn-embed-dim: 2048
-encoder-layers: 6
+encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 8
 decoder-embed-dim: 512
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 8
-encoder-attention-type: relative
-decoder-attention-type: relative
-max-encoder-relative-length: 20
-max-decoder-relative-length: 20
-#train-subset: train-clean-100,train-clean-360,train-other-500
-train-subset: train-clean-100
+arch: pdss2t_transformer_sd_8
+#arch: pdss2t_transformer_sd_16
+#arch: pdss2t_transformer_sd_32
+train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
 max-epoch: 100
 max-update: 300000
-num-workers: 0
+num-workers: 8
 patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
-arch: s2t_transformer_s
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -22,26 +27,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
-encoder-layers: 3
+encoder-layers: 30
-decoder-layers: 3
+decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_sd_16
 encoder-embed-dim: 256
 pyramid-stages: 4
-pyramid-layers: 3_3_8_4
-pyramid-sr-ratios: 2_2_2_2
-pyramid-fuse: True
-pyramid-fuse-way: all_conv
+#pyramid-dropout: 0
+pyramid-layers: 5_5_12_8
+pyramid-ratios: 2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -20,7 +22,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -41,16 +43,13 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
-encoder-layers: 12
+encoder-layers: 30
 decoder-layers: 6
 encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
+arch: pdss2t_transformer_sd_32
+encoder-embed-dim: 256
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 5_5_7_7_6
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8_8
+pyramid-attn-heads: 4_4_4_4_4
+train-subset: train-clean-100,train-clean-360,train-other-500
+valid-subset: dev-clean
+max-epoch: 100
+max-update: 300000
+num-workers: 8
+patience: 10
+no-progress-bar: True
+log-interval: 100
+seed: 1
+report-accuracy: True
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
+share-decoder-input-output-embed: True
+optimizer: adam
+clip-norm: 10.0
+lr-scheduler: inverse_sqrt
+warmup-init-lr: 1e-7
+warmup-updates: 10000
+lr: 2e-3
+#adam_betas: (0.9,0.98)
+criterion: label_smoothed_cross_entropy_with_ctc
+label_smoothing: 0.1
+dropout: 0.1
+activation-fn: relu
+encoder-ffn-embed-dim: 2048
+encoder-layers: 30
+decoder-layers: 6
+encoder-attention-heads: 4
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
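A pattern worth noting: the numeric suffix of the pdss2t_transformer presets appears to equal the product of pyramid-ratios (the total temporal downsampling), and the entries of pyramid-layers sum to encoder-layers. Both identities hold for every config in this commit; this is an observation from the diffs, not documented behavior. For the pdss2t_transformer_sd_32 config above:

    # product of pyramid-ratios 2_2_2_2_2 gives the _32 suffix
    echo $((2 * 2 * 2 * 2 * 2))   # 32
    # stage depths in pyramid-layers 5_5_7_7_6 sum to the encoder depth
    echo $((5 + 5 + 7 + 7 + 6))   # 30, matching encoder-layers: 30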
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_sd_8
 encoder-embed-dim: 256
 pyramid-stages: 4
-pyramid-layers: 5_5_15_5
-pyramid-sr-ratios: 2_2_2_2
-pyramid-fuse: True
-pyramid-fuse-way: all_conv
+#pyramid-dropout: 0
+pyramid-layers: 7_7_7_9
+pyramid-ratios: 2_2_1_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train-clean-100,train-clean-360,train-other-500
 valid-subset: dev-clean
@@ -20,7 +22,7 @@ max-epoch: 100
 max-update: 300000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
@@ -41,16 +43,13 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
-encoder-layers: 12
+encoder-layers: 30
 decoder-layers: 6
 encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
@@ -6,23 +6,24 @@ gpu_num=8
 update_freq=1
 max_tokens=100000
-extra_tag=
-extra_parameter=
-#extra_tag="${extra_tag}"
-#extra_parameter="${extra_parameter} "
 #exp_tag=
 #config_list=(base)
 #config_list=(ctc)
 #config_list=(ctc conformer rpr)
 config_list=(base conformer rpr)
-#config_list=(pyramid4_all256)
-#config_list=(pyramid5_all256)
+#config_list=(pds_base)
+#config_list=(pds_big)
+#config_list=(pds_deep)
 # exp full name
 exp_name=
+extra_tag=
+extra_parameter=
+#extra_tag="${extra_tag}"
+#extra_parameter="${extra_parameter} "
 train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
 cmd="./run.sh
......
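For orientation, train.sh joins the selected YAML configs into a comma-separated list and hands it to run.sh. A minimal sketch of that flow with the values above; the run.sh flag names are assumed from this pattern rather than verified against the full script:

    #!/usr/bin/env bash
    # Sketch: merge config YAMLs in order and launch training via run.sh.
    gpu_num=8
    update_freq=1
    max_tokens=100000
    config_list=(base conformer rpr)                        # presumably later files override earlier ones
    train_config=$(echo ${config_list[*]} | sed 's/ /,/g')  # -> "base,conformer,rpr"

    cmd="./run.sh
        --gpu_num ${gpu_num}
        --update_freq ${update_freq}
        --max_tokens ${max_tokens}
        --train_config ${train_config}"
    echo -e "\033[34mRun command: \n${cmd} \033[0m"
    eval ${cmd}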
@@ -2,22 +2,22 @@ set -e
 eval=1
+lcrm=0
+tokenizer=0
 root_dir=~/st/Fairseq-S2T
-data_dir=/home/xuchen/st/data/test
+data_dir=~/st/data/test
-vocab_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de
+vocab_dir=~/st/data/mustc/st/en-de
 asr_vocab_prefix=spm_unigram10000_st_share
 src_lang=en
 tgt_lang=de
-splits=(2019)
+subsets=(2019)
-source ~/tools/audio/bin/activate
-splits=`echo ${splits[*]} | sed 's/ /,/g'`
 cp -r ${vocab_dir}/${asr_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang}
 rm -rf ${data_dir}/${src_lang}-${tgt_lang}/fbank80.zip
+splits=$(echo ${subsets[*]} | sed 's/ /,/g')
 cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
     --data-root ${data_dir}
     --output-root ${data_dir}
@@ -42,4 +42,3 @@ cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
 echo -e "\033[34mRun command: \n${cmd} \033[0m"
 [[ $eval -eq 1 ]] && eval ${cmd}
-deactivate
-arch: s2t_conformer_s
 macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_8
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,10 +13,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_conformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -24,26 +26,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_16
 encoder-embed-dim: 256
 pyramid-stages: 4
 #pyramid-dropout: 0
 pyramid-layers: 2_2_6_2
-pyramid-sr-ratios: 2_2_2_2
+pyramid-ratios: 2_2_2_2
-pyramid-fuse: True
+pyramid-fusion: True
-pyramid-fuse-way: all_conv
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train_asr
 valid-subset: dev_asr
@@ -21,12 +22,13 @@ max-epoch: 100
 max-update: 100000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
@@ -42,7 +44,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -53,5 +54,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_32
+encoder-embed-dim: 256
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 2_2_3_3_2
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8_8
+pyramid-attn-heads: 4_4_4_4_4
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,12 +28,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,32 +41,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: transformer
-adapter: league
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
 pyramid-stages: 4
 #pyramid-dropout: 0
 pyramid-layers: 3_3_3_3
-pyramid-sr-ratios: 2_2_1_2
+pyramid-ratios: 2_2_1_2
-pyramid-fuse: True
+pyramid-fusion: True
-pyramid-fuse-way: all_conv
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
-pyramid-reduced-embed: conv
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train_asr
 valid-subset: dev_asr
@@ -21,12 +22,13 @@ max-epoch: 100
 max-update: 100000
 num-workers: 8
-patience: 20
+patience: 10
 no-progress-bar: True
 log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
@@ -42,7 +44,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -53,5 +54,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
@@ -6,23 +6,24 @@ gpu_num=8
 update_freq=1
 max_tokens=40000
-extra_tag=
-extra_parameter=
-#extra_tag="${extra_tag}"
-#extra_parameter="${extra_parameter} "
 exp_tag=
 #config_list=(base)
 #config_list=(ctc)
 #config_list=(base conformer)
-#config_list=(pyramid4_base)
+#config_list=(pds_base_16)
-config_list=(pyramid4_base conformer rpr)
+config_list=(pds_base_16 conformer rpr)
 # exp full name
 exp_name=
+extra_tag=
+extra_parameter=
+#extra_tag="${extra_tag}"
+#extra_parameter="${extra_parameter} "
 train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
 cmd="./run.sh
......
 set -e
 eval=1
+lcrm=0
 root_dir=~/st/Fairseq-S2T
 data_dir=/home/xuchen/st/data/wmt/test
......
-arch: s2t_conformer_s
 macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
+arch: pdss2t_transformer_s_8
 train-subset: train_st
 valid-subset: dev_st
-max-epoch: 50
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -14,7 +16,6 @@ report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_transformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -24,14 +25,11 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy
+criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
@@ -40,5 +38,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
+arch: pdss2t_transformer_s_16
+encoder-embed-dim: 256
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 2_2_6_2
+pyramid-ratios: 2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8
+pyramid-attn-heads: 4_4_4_4
 train-subset: train_st
 valid-subset: dev_st
-max-epoch: 50
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -14,7 +31,6 @@ report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_transformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -24,27 +40,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 4
-encoder-attention-type: relative
-decoder-attention-type: relative
-max-encoder-relative-length: 100
-max-decoder-relative-length: 20
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
+arch: pdss2t_transformer_s_32
+encoder-embed-dim: 256
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 2_2_3_3_2
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8_8
+pyramid-attn-heads: 4_4_4_4_4
 train-subset: train_st
 valid-subset: dev_st
-max-epoch: 50
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -14,7 +31,6 @@ report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_conformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -24,31 +40,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-encoder-attention-type: relative
-decoder-attention-type: relative
-max-encoder-relative-length: 100
-max-decoder-relative-length: 20
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-arch: pys2t_transformer_s
+arch: pdss2t_transformer_s_8
 encoder-embed-dim: 256
-#pyramid-dropout: 0
 pyramid-stages: 4
+#pyramid-dropout: 0
 pyramid-layers: 3_3_3_3
-pyramid-sr-ratios: 2_2_1_2
-pyramid-fuse: True
-pyramid-reduced-embed: conv
+pyramid-ratios: 2_2_1_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 train-subset: train_st
 valid-subset: dev_st
@@ -26,10 +28,8 @@ log-interval: 100
 seed: 1
 report-accuracy: True
-#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1002_pyramid4_all256_3333_sr8/avg_10_checkpoint.pt
-#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1002_pyramid4_all256_3333_sr8/checkpoint_best.pt
-load-pretrained-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/1007_st_pyramid4_all256_3333_sr8_ctc/avg_10_checkpoint.pt
-load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/st_1003_2349_train_s_baseline/avg_10_checkpoint.pt
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
 share-decoder-input-output-embed: True
 optimizer: adam
@@ -43,7 +43,6 @@ lr: 2e-3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
 encoder-ffn-embed-dim: 2048
@@ -54,5 +53,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
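The load-pretrained-*-from entries reference averaged checkpoints (avg_10_checkpoint.pt, matching n_average=10 in the decoding script). A sketch of producing such a checkpoint and wiring it into an ST config; the paths are placeholders, and the averaging flags are those of fairseq's standard script, assumed unchanged in this fork:

    #!/usr/bin/env bash
    # Average the last 10 epoch checkpoints of an ASR run, then reference the
    # result from the ST YAML instead of the raw checkpoint_best.pt.
    root_dir=~/st/Fairseq-S2T
    asr_ckpt_dir=~/st/checkpoints/mustc/asr/my_asr_exp   # placeholder experiment dir

    python ${root_dir}/scripts/average_checkpoints.py \
        --inputs ${asr_ckpt_dir} \
        --num-epoch-checkpoints 10 \
        --output ${asr_ckpt_dir}/avg_10_checkpoint.pt

    # Then, in the ST config:
    # load-pretrained-encoder-from: <asr_ckpt_dir>/avg_10_checkpoint.pt
    # load-pretrained-decoder-from: <mt_ckpt_dir>/avg_10_checkpoint.pt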
+arch: s2t_sate_s
+acoustic-encoder: transformer
+adapter: league
+#load-pretrained-encoder-from:
+#load-pretrained-acoustic-encoder-from:
+#load-pretrained-text-encoder-from:
+#load-pretrained-decoder-from:
@@ -43,10 +43,11 @@ text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
+#macaron-style: True
+#use-cnn-module: True
+#cnn-module-kernel: 31
+#acoustic-encoder: pds
 acoustic-encoder: transformer
 adapter: league
@@ -54,18 +55,17 @@ encoder-embed-dim: 256
 pyramid-stages: 4
 #pyramid-dropout: 0
 pyramid-layers: 3_3_3_3
-pyramid-sr-ratios: 2_2_1_2
-pyramid-fuse: True
-pyramid-reduced-embed: conv
+pyramid-ratios: 2_2_1_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
 pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
 pyramid-embed-norm: True
 pyramid-position-embed: 1_1_1_1
 pyramid-kernel-sizes: 5_5_5_5
 pyramid-ffn-ratios: 8_8_8_8
-pyramid-heads: 4_4_4_4
+pyramid-attn-heads: 4_4_4_4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
@@ -6,6 +6,11 @@ gpu_num=8
 update_freq=1
 max_tokens=40000
+extra_tag=
+extra_parameter=
+#extra_tag="${extra_tag}"
+#extra_parameter="${extra_parameter} "
 exp_tag=
 #config_list=(base)
@@ -14,17 +19,12 @@ config_list=(ctc)
 #config_list=(ctc conformer rpr)
 #config_list=(base sate)
-#config_list=(pyramid4_base_sr8)
-#config_list=(pyramid4_base_sr8 conformer)
+#config_list=(pds_base)
+#config_list=(pds_base conformer)
 # exp full name
 exp_name=
-extra_tag=
-extra_parameter=
-#extra_tag="${extra_tag}"
-#extra_parameter="${extra_parameter} "
 train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
 cmd="./run.sh
......
+set -e
+eval=1
+lcrm=0
+tokenizer=0
+root_dir=~/st/Fairseq-S2T
+data_dir=~/st/data/test
+vocab_dir=~/st/data/mustc/st/en-de
+asr_vocab_prefix=spm_unigram10000_st_share
+src_lang=en
+tgt_lang=de
+subsets=(2019)
+cp -r ${vocab_dir}/${asr_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang}
+rm -rf ${data_dir}/${src_lang}-${tgt_lang}/fbank80.zip
+splits=$(echo ${subsets[*]} | sed 's/ /,/g')
+cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
+    --data-root ${data_dir}
+    --output-root ${data_dir}
+    --splits ${splits}
+    --task asr
+    --src-lang ${src_lang}
+    --tgt-lang ${tgt_lang}
+    --add-src
+    --share
+    --asr-prefix ${asr_vocab_prefix}
+    --cmvn-type utterance"
+if [[ ${lcrm} -eq 1 ]]; then
+    cmd="$cmd
+    --lowercase-src
+    --rm-punc-src"
+fi
+if [[ ${tokenizer} -eq 1 ]]; then
+    cmd="$cmd
+    --tokenizer"
+fi
+echo -e "\033[34mRun command: \n${cmd} \033[0m"
+[[ $eval -eq 1 ]] && eval ${cmd}
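With the defaults above (lcrm=0, tokenizer=0, eval=1), the command the script assembles and evals expands to roughly the following, shown for orientation with the variable values set in the script:

    python ~/st/Fairseq-S2T/examples/speech_to_text/prep_st_data.py \
        --data-root ~/st/data/test \
        --output-root ~/st/data/test \
        --splits 2019 \
        --task asr \
        --src-lang en \
        --tgt-lang de \
        --add-src \
        --share \
        --asr-prefix spm_unigram10000_st_share \
        --cmvn-type utterance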
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -24,7 +24,6 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
-ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
......
+macaron-style: True
+use-cnn-module: True
+cnn-module-kernel: 31
+use-enc-dlcl: True
+use-dec-dlcl: True
+encoder-attention-type: local
+hard-mask-window: 0
+gauss-mask-sigma: 3
+init-mask-weight: 0
\ No newline at end of file
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_8
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,10 +13,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_conformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -24,26 +26,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_16
+encoder-embed-dim: 256
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 2_2_6_2
+pyramid-ratios: 2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8
+pyramid-attn-heads: 4_4_4_4
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,12 +28,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,32 +41,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: conformer
-adapter: league
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_32
+encoder-embed-dim: 256
+pyramid-stages: 5
+#pyramid-dropout: 0
+pyramid-layers: 2_2_3_3_2
+pyramid-ratios: 2_2_2_2_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8_8
+pyramid-attn-heads: 4_4_4_4_4
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,12 +28,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,32 +41,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: conformer
-adapter: league
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
-train-subset: train_st
-valid-subset: dev_st
-max-epoch: 50
+arch: pdss2t_transformer_s_8
+encoder-embed-dim: 256
+pyramid-stages: 4
+#pyramid-dropout: 0
+pyramid-layers: 3_3_3_3
+pyramid-ratios: 2_2_1_2
+pyramid-fusion: True
+pyramid-fusion-method: all_conv
+pyramid-embed-dims: 256_256_256_256
+pyramid-ds-method: conv
+pyramid-embed-norm: True
+pyramid-position-embed: 1_1_1_1
+pyramid-kernel-sizes: 5_5_5_5
+pyramid-ffn-ratios: 8_8_8_8
+pyramid-attn-heads: 4_4_4_4
+train-subset: train_asr
+valid-subset: dev_asr
+max-epoch: 100
 max-update: 100000
 num-workers: 8
@@ -11,12 +28,10 @@ log-interval: 100
 seed: 1
 report-accuracy: True
 #load-pretrained-encoder-from:
-#load-pretrained-acoustic-encoder-from:
-#load-pretrained-text-encoder-from:
 #load-pretrained-decoder-from:
-arch: s2t_sate
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,37 +41,16 @@ warmup-updates: 10000
 lr: 2e-3
 #adam_betas: (0.9,0.98)
+ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-encoder-normalize-before: True
-decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
-text-encoder-layers: 6
 decoder-layers: 6
 encoder-attention-heads: 4
-macaron-style: True
-use-cnn-module: True
-cnn-module-kernel: 31
-acoustic-encoder: transformer
-adapter: league
-encoder-attention-type: relative
-decoder-attention-type: relative
-max-encoder-relative-length: 100
-max-decoder-relative-length: 20
-#decoder-embed-dim: 256
-#decoder-ffn-embed-dim: 2048
-#decoder-attention-heads: 4
-#attention-dropout: 0.1
-#activation-dropout: 0.1
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
encoder-attention-type: rel_selfattn
#encoder-attention-type: relative
#max-encoder-relative-length: 100
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
gpu_num=1 gpu_num=1
data_dir= data_dir=
test_subset=(test-clean test-other) test_subset=(test)
exp_name= exp_name=
if [ "$#" -eq 1 ]; then if [ "$#" -eq 1 ]; then
...@@ -13,7 +13,7 @@ fi ...@@ -13,7 +13,7 @@ fi
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=10000 max_tokens=80000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
...@@ -31,9 +31,9 @@ cmd="./run.sh ...@@ -31,9 +31,9 @@ cmd="./run.sh
if [[ -n ${data_dir} ]]; then if [[ -n ${data_dir} ]]; then
cmd="$cmd --data_dir ${data_dir}" cmd="$cmd --data_dir ${data_dir}"
fi fi
if [[ -n ${test_subset} ]]; then if [[ ${#test_subset[@]} -ne 0 ]]; then
test_subset=`echo ${test_subset[*]} | sed 's/ /,/g'` subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
cmd="$cmd --test_subset ${test_subset}" cmd="$cmd --test_subset ${subsets}"
fi fi
echo $cmd echo $cmd
......
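The subset handling above first checks that the array is non-empty, then joins it into a comma-separated list for run.sh; a minimal sketch with illustrative subset names:

# Minimal sketch of the subset join (names illustrative).
test_subset=(dev-clean dev-other test-clean test-other)
subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
echo ${subsets}    # dev-clean,dev-other,test-clean,test-other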
gpu_num=1 gpu_num=4
cmd="sh train.sh"
while : while :
do do
all_devices=$(seq 0 `gpustat | sed '1,2d' | wc -l`); record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`gpustat -p | head -n $line | tail -1 | cut -d '|' -f4 | wc -w` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -eq 0 ]]; then
if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
...@@ -5,17 +5,18 @@ get_devices(){ ...@@ -5,17 +5,18 @@ get_devices(){
device=() device=()
while : while :
do do
record=`mktemp -t temp.record.XXXXXX` record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record gpustat > $record
all_devices=$(seq 0 `cat $record | sed '1,2d' | wc -l`); all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`cat $record | head -n $line | tail -1 | cut -d '|' -f3 | cut -d '/' -f1` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
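The rewritten selection logic snapshots gpustat once into a temp file and reads the used-memory column, treating a GPU as free when it uses less than 100 MB. A sketch of the assumed per-GPU line format and the extraction (the sample line is illustrative):

# Assumed gpustat per-GPU line (after the header lines), e.g.:
line="[0] GeForce RTX 2080 Ti | 45'C,  0 % |    11 / 11019 MB |"
use=$(echo "$line" | cut -d '|' -f3 | cut -d '/' -f1)    # used memory, here "11"
[[ $use -lt 100 ]] && echo "GPU 0 treated as idle"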
#! /bin/bash #! /bin/bash
# Processing LibriSpeech Datasets # Processing ASR Datasets
# Copyright 2021 Natural Language Processing Laboratory # Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com) # Xu Chen (xuchenneu@163.com)
...@@ -20,7 +20,7 @@ stop_stage=0 ...@@ -20,7 +20,7 @@ stop_stage=0
######## hardware ######## ######## hardware ########
# devices # devices
device=() #device=()
gpu_num=8 gpu_num=8
update_freq=1 update_freq=1
...@@ -31,40 +31,40 @@ pwd_dir=$PWD ...@@ -31,40 +31,40 @@ pwd_dir=$PWD
src_lang=en src_lang=en
lang=${src_lang} lang=${src_lang}
dataset= dataset=asr
task=speech_to_text task=speech_to_text
vocab_type=unigram vocab_type=unigram
vocab_size=10000 vocab_size=5000
speed_perturb=0 speed_perturb=0
lcrm=1 lcrm=0
tokenizer=0 tokenizer=0
use_specific_dict=0 use_specific_dict=0
specific_prefix=valid specific_prefix=st
specific_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de specific_dir=/home/xuchen/st/data/mustc/st/en-de
asr_vocab_prefix=spm_unigram10000_st_share asr_vocab_prefix=spm_unigram10000_st_share
org_data_dir=/media/data/${dataset} org_data_dir=~/st/data/${dataset}
data_dir=~/st/data/${dataset} data_dir=~/st/data/${dataset}/asr
train_split=train train_split=train
valid_split=valid valid_split=valid
test_split=test test_split=test
test_subset=dev-clean,dev-other,test-clean,test-other test_subset=test
# exp # exp
exp_prefix=${time} exp_prefix=$(date "+%m%d")
extra_tag= extra_tag=
extra_parameter= extra_parameter=
exp_tag=baseline exp_tag=baseline
exp_name= exp_name=
# config # config
train_config=train_ctc.yaml train_config=ctc
data_config=config.yaml data_config=config_asr.yaml
# training setting # training setting
fp16=1 fp16=1
max_tokens=20000 max_tokens=40000
step_valid=0 step_valid=0
# decoding setting # decoding setting
...@@ -77,17 +77,24 @@ if [[ ${speed_perturb} -eq 1 ]]; then ...@@ -77,17 +77,24 @@ if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp data_dir=${data_dir}_sp
exp_prefix=${exp_prefix}_sp exp_prefix=${exp_prefix}_sp
fi fi
if [[ ${lcrm} -eq 1 ]]; then
data_dir=${data_dir}_lcrm
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${use_specific_dict} -eq 1 ]]; then if [[ ${use_specific_dict} -eq 1 ]]; then
data_dir=${data_dir}_${specific_prefix} data_dir=${data_dir}_${specific_prefix}
exp_prefix=${exp_prefix}_${specific_prefix} exp_prefix=${exp_prefix}_${specific_prefix}
fi fi
if [[ ${tokenizer} -eq 1 ]]; then
data_dir=${data_dir}_tok
exp_prefix=${exp_prefix}_tok
fi
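These suffix blocks stack, so the data directory (and experiment prefix) encodes the preprocessing choices; an illustrative trace with all three switches assumed on:

# Illustration: speed_perturb=1, lcrm=1, tokenizer=1 accumulate suffixes.
data_dir=~/st/data/asr
data_dir=${data_dir}_sp
data_dir=${data_dir}_lcrm
data_dir=${data_dir}_tok
echo ${data_dir}    # .../st/data/asr_sp_lcrm_tok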
. ./local/parse_options.sh || exit 1; . ./local/parse_options.sh || exit 1;
# full path
train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then if [[ -z ${exp_name} ]]; then
exp_name=${exp_prefix}_$(basename ${train_config%.*})_${exp_tag} config_string=${train_config//,/_}
exp_name=${exp_prefix}_${config_string}_${exp_tag}
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
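With the comma-separated config list, the derived experiment name simply swaps commas for underscores; for example (prefix and tag values assumed):

# Assumed values: exp_prefix=0405 (from date "+%m%d"), exp_tag=baseline
train_config=pds_base_16,conformer,rpr
config_string=${train_config//,/_}
echo 0405_${config_string}_baseline    # 0405_pds_base_16_conformer_rpr_baseline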
...@@ -102,12 +109,10 @@ fi ...@@ -102,12 +109,10 @@ fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
### Task dependent. You have to do the data preparation in the following part by yourself. ### Task dependent. You have to do the data preparation in the following part by yourself.
### But you can utilize Kaldi recipes in most cases ### But you can utilize Kaldi recipes in most cases
echo "stage 0: Data Preparation" echo "stage 0: ASR Data Preparation"
if [[ ! -e ${data_dir} ]]; then if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir} mkdir -p ${data_dir}
fi fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
...@@ -127,7 +132,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -127,7 +132,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd cmd="$cmd
--speed-perturb" --speed-perturb"
fi fi
if [[ ${lcrm} -eq 1 ]]; then if [[ ${lcrm} -eq 1 ]]; then
cmd="$cmd cmd="$cmd
--lowercase-src --lowercase-src
--rm-punc-src" --rm-punc-src"
...@@ -136,17 +141,20 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -136,17 +141,20 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
cmd="$cmd cmd="$cmd
--tokenizer" --tokenizer"
fi fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd [[ $eval -eq 1 ]] && eval ${cmd}
fi fi
data_dir=${data_dir}/${lang}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: ASR Network Training" echo "stage 1: ASR Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1; [[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1;
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -163,12 +171,31 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -163,12 +171,31 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cp ${BASH_SOURCE[0]} ${model_dir} cp ${BASH_SOURCE[0]} ${model_dir}
cp ${PWD}/train.sh ${model_dir} cp ${PWD}/train.sh ${model_dir}
cp ${train_config} ${model_dir}
config_list="${train_config//,/ }"
idx=0
for config in ${config_list[@]}
do
config_path=$pwd_dir/conf/${config}.yaml
if [[ ! -f ${config_path} ]]; then
echo "No config file ${config_path}"
exit 1
fi
cp ${config_path} ${model_dir}
if [[ ${idx} -eq 0 ]]; then
extra_parameter="${extra_parameter}
--train-config ${config_path}"
else
extra_parameter="${extra_parameter}
--train-config${idx} ${config_path}"
fi
idx=$((idx + 1))
done
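End to end, this loop expands a comma-separated train_config into numbered --train-config flags, presumably letting later YAML files override keys from earlier ones; a standalone sketch of the same expansion (config names illustrative):

# Standalone sketch of the --train-config expansion above.
pwd_dir=$PWD
train_config="pds_base_16,conformer,rpr"
config_list="${train_config//,/ }"
extra_parameter=
idx=0
for config in ${config_list[@]}; do
    config_path=$pwd_dir/conf/${config}.yaml
    if [[ ${idx} -eq 0 ]]; then
        extra_parameter="${extra_parameter} --train-config ${config_path}"
    else
        extra_parameter="${extra_parameter} --train-config${idx} ${config_path}"
    fi
    idx=$((idx + 1))
done
echo ${extra_parameter}
# --train-config .../conf/pds_base_16.yaml --train-config1 .../conf/conformer.yaml --train-config2 .../conf/rpr.yaml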
cmd="python3 -u ${root_dir}/fairseq_cli/train.py cmd="python3 -u ${root_dir}/fairseq_cli/train.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
--train-config ${train_config}
--task ${task} --task ${task}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--skip-invalid-size-inputs-valid-test --skip-invalid-size-inputs-valid-test
...@@ -177,7 +204,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -177,7 +204,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--save-dir ${model_dir} --save-dir ${model_dir}
--tensorboard-logdir ${model_dir}" --tensorboard-logdir ${model_dir}"
if [[ -n ${extra_parameter} ]]; then if [[ -n ${extra_parameter} ]]; then
cmd="${cmd} cmd="${cmd}
${extra_parameter}" ${extra_parameter}"
fi fi
...@@ -230,8 +257,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -230,8 +257,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# save info # save info
log=./history.log log=./history.log
echo "${time} | ${device} | ${data_dir} | ${model_dir} " >> $log echo "${time} | ${device} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
cat $log | tail -n 50 > tmp.log tail -n 50 ${log} > tmp.log
mv tmp.log $log mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
...@@ -239,7 +266,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -239,7 +266,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
sleep 2s sleep 2s
tail -n `wc -l ${model_dir}/train.log | awk '{print $1+1}'` -f ${model_dir}/train.log tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
fi fi
fi fi
wait wait
...@@ -262,7 +289,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -262,7 +289,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -270,14 +297,12 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -270,14 +297,12 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
#tmp_file=$(mktemp ${model_dir}/tmp-XXXXX)
#trap 'rm -rf ${tmp_file}' EXIT
result_file=${model_dir}/decode_result result_file=${model_dir}/decode_result
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ }) test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset} subset=${subset}_asr
cmd="python ${root_dir}/fairseq_cli/generate.py cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
...@@ -288,7 +313,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -288,7 +313,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--beam ${beam_size} --beam ${beam_size}
--lenpen ${len_penalty} --lenpen ${len_penalty}
--scoring wer" --scoring wer
--wer-tokenizer 13a
--wer-lowercase
--wer-remove-punct
"
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
......
...@@ -3,17 +3,28 @@ ...@@ -3,17 +3,28 @@
# training the model # training the model
gpu_num=8 gpu_num=8
update_freq=2 update_freq=1
max_tokens=20000 max_tokens=40000
extra_tag= extra_tag=
extra_parameter= extra_parameter=
#extra_tag="${extra_tag}" #extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} " #extra_parameter="${extra_parameter} "
exp_tag= exp_tag=
train_config=train_ctc.yaml
#config_list=(base)
#config_list=(ctc)
#config_list=(base conformer)
#config_list=(pds_base_16)
config_list=(pds_base_16 conformer rpr)
# exp full name
exp_name=
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh cmd="./run.sh
--stage 1 --stage 1
...@@ -24,6 +35,9 @@ cmd="./run.sh ...@@ -24,6 +35,9 @@ cmd="./run.sh
--max_tokens ${max_tokens} --max_tokens ${max_tokens}
" "
if [[ -n ${exp_name} ]]; then
cmd="$cmd --exp_name ${exp_name}"
fi
if [[ -n ${exp_tag} ]]; then if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}" cmd="$cmd --exp_tag ${exp_tag}"
fi fi
...@@ -34,5 +48,5 @@ if [[ -n ${extra_parameter} ]]; then ...@@ -34,5 +48,5 @@ if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\"" cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi fi
echo $cmd echo ${cmd}
eval $cmd eval ${cmd}
use-enc-dlcl: True
use-dec-dlcl: True
#encoder-attention-type: rel_selfattn
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 20
max-decoder-relative-length: 20
\ No newline at end of file
...@@ -13,7 +13,7 @@ fi ...@@ -13,7 +13,7 @@ fi
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=10000 max_tokens=80000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
......
gpu_num=1 gpu_num=4
cmd="sh train.sh"
while : while :
do do
all_devices=$(seq 0 `gpustat | sed '1,2d' | wc -l`); record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`gpustat -p | head -n $line | tail -1 | cut -d '|' -f4 | wc -w` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -eq 0 ]]; then
if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
...@@ -5,17 +5,18 @@ get_devices(){ ...@@ -5,17 +5,18 @@ get_devices(){
device=() device=()
while : while :
do do
record=`mktemp -t temp.record.XXXXXX` record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record gpustat > $record
all_devices=$(seq 0 `cat $record | sed '1,2d' | wc -l`); all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`cat $record | head -n $line | tail -1 | cut -d '|' -f3 | cut -d '/' -f1` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
...@@ -20,7 +20,7 @@ stop_stage=0 ...@@ -20,7 +20,7 @@ stop_stage=0
######## hardware ######## ######## hardware ########
# devices # devices
#device=() device=()
gpu_num=8 gpu_num=8
update_freq=1 update_freq=1
...@@ -32,21 +32,21 @@ src_lang=en ...@@ -32,21 +32,21 @@ src_lang=en
tgt_lang=de tgt_lang=de
lang=${src_lang}-${tgt_lang} lang=${src_lang}-${tgt_lang}
dataset= dataset=mt
task=translation task=translation
vocab_type=unigram vocab_type=unigram
vocab_size=10000 vocab_size=10000
share_dict=1 share_dict=1
lcrm=1 lcrm=0
tokenizer=1 tokenizer=0
use_specific_dict=0 use_specific_dict=0
specific_prefix=wmt_share32k specific_prefix=st
specific_dir=/home/xuchen/st/data/wmt/mt_lcrm/en-de/unigram32000_share specific_dir=/home/xuchen/st/data/mustc/st/en-de/
src_vocab_prefix=spm_unigram32000_share src_vocab_prefix=spm_unigram10000_st_share
tgt_vocab_prefix=spm_unigram32000_share tgt_vocab_prefix=spm_unigram10000_st_share
org_data_dir=/media/data/${dataset} org_data_dir=~/st/data/${dataset}
data_dir=~/st/data/${dataset}/mt/${lang} data_dir=~/st/data/${dataset}/mt/${lang}
train_subset=train train_subset=train
valid_subset=dev valid_subset=dev
...@@ -61,7 +61,7 @@ exp_tag=baseline ...@@ -61,7 +61,7 @@ exp_tag=baseline
exp_name= exp_name=
# config # config
train_config=train.yaml train_config=base_s
# training setting # training setting
fp16=1 fp16=1
...@@ -104,9 +104,9 @@ fi ...@@ -104,9 +104,9 @@ fi
. ./local/parse_options.sh || exit 1; . ./local/parse_options.sh || exit 1;
# full path # full path
train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then if [[ -z ${exp_name} ]]; then
exp_name=${exp_prefix}_$(basename ${train_config%.*})_${exp_tag} config_string=${train_config//,/_}
exp_name=${exp_prefix}_${config_string}_${exp_tag}
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
...@@ -150,7 +150,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -150,7 +150,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
mkdir -p ${data_dir}/data mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${trans_subset}; do for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{ {
cmd="cat ${org_data_dir}/${lang}/data/${split}.${src_lang}" cmd="cat ${org_data_dir}/${lang}/data/${split}/txt/${split}.${src_lang}"
if [[ ${lcrm} -eq 1 ]]; then if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${org_data_dir}/${lang}/data/${split}.${src_lang}" cmd="python local/lower_rm.py ${org_data_dir}/${lang}/data/${split}.${src_lang}"
fi fi
...@@ -178,7 +178,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -178,7 +178,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--source-lang ${src_lang} --target-lang ${tgt_lang} --source-lang ${src_lang} --target-lang ${tgt_lang}
--trainpref ${data_dir}/data/${train_subset} --trainpref ${data_dir}/data/${train_subset}
--validpref ${data_dir}/data/${valid_subset} --validpref ${data_dir}/data/${valid_subset}
--testpref ${data_dir}/data/${test_subset} --testpref ${data_dir}/data/${trans_subset}
--destdir ${data_dir}/data-bin --destdir ${data_dir}/data-bin
--srcdict ${data_dir}/${src_vocab_prefix}.txt --srcdict ${data_dir}/${src_vocab_prefix}.txt
--tgtdict ${data_dir}/${tgt_vocab_prefix}.txt --tgtdict ${data_dir}/${tgt_vocab_prefix}.txt
...@@ -196,7 +196,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -196,7 +196,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -213,13 +213,32 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -213,13 +213,32 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cp ${BASH_SOURCE[0]} ${model_dir} cp ${BASH_SOURCE[0]} ${model_dir}
cp ${PWD}/train.sh ${model_dir} cp ${PWD}/train.sh ${model_dir}
cp ${train_config} ${model_dir}
config_list="${train_config//,/ }"
idx=0
for config in ${config_list[@]}
do
config_path=$pwd_dir/conf/${config}.yaml
if [[ ! -f ${config_path} ]]; then
echo "No config file ${config_path}"
exit 1
fi
cp ${config_path} ${model_dir}
if [[ ${idx} -eq 0 ]]; then
extra_parameter="${extra_parameter}
--train-config ${config_path}"
else
extra_parameter="${extra_parameter}
--train-config${idx} ${config_path}"
fi
idx=$((idx + 1))
done
cmd="python3 -u ${root_dir}/fairseq_cli/train.py cmd="python3 -u ${root_dir}/fairseq_cli/train.py
${data_dir} ${data_dir}
--source-lang ${src_lang} --source-lang ${src_lang}
--target-lang ${tgt_lang} --target-lang ${tgt_lang}
--train-config ${train_config}
--task ${task} --task ${task}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--skip-invalid-size-inputs-valid-test --skip-invalid-size-inputs-valid-test
...@@ -228,7 +247,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -228,7 +247,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--save-dir ${model_dir} --save-dir ${model_dir}
--tensorboard-logdir ${model_dir}" --tensorboard-logdir ${model_dir}"
if [[ -n ${extra_parameter} ]]; then if [[ -n ${extra_parameter} ]]; then
cmd="${cmd} cmd="${cmd}
${extra_parameter}" ${extra_parameter}"
fi fi
...@@ -246,7 +265,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -246,7 +265,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
save_interval=1 save_interval=1
keep_last_epochs=10 keep_last_epochs=10
no_epoch_checkpoints=0 no_epoch_checkpoints=0
save_interval_updates=10000 save_interval_updates=500
keep_interval_updates=10 keep_interval_updates=10
else else
validate_interval=1 validate_interval=1
...@@ -290,8 +309,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -290,8 +309,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# save info # save info
log=./history.log log=./history.log
echo "${time} | ${device} | ${data_dir} | ${model_dir} " >> $log echo "${time} | ${device} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
cat $log | tail -n 50 > tmp.log tail -n 50 ${log} > tmp.log
mv tmp.log $log mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
...@@ -299,7 +318,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -299,7 +318,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
sleep 2s sleep 2s
tail -n `wc -l ${model_dir}/train.log | awk '{print $1+1}'` -f ${model_dir}/train.log tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
fi fi
fi fi
wait wait
...@@ -322,7 +341,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -322,7 +341,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -335,7 +354,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -335,7 +354,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
test_subset=(${test_subset//,/ }) test_subset=(${test_subset//,/ })
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset}_st
cmd="python ${root_dir}/fairseq_cli/generate.py cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
--source-lang ${src_lang} --source-lang ${src_lang}
......
...@@ -4,16 +4,20 @@ ...@@ -4,16 +4,20 @@
gpu_num=1 gpu_num=1
update_freq=1 update_freq=1
max_tokens=4096 max_tokens=8192
exp_tag=baseline
config_list=(base)
# exp full name
exp_name=
extra_tag= extra_tag=
extra_parameter= extra_parameter=
#extra_tag="${extra_tag}" #extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} " #extra_parameter="${extra_parameter} "
exp_tag=baseline train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
train_config=train.yaml
cmd="./run.sh cmd="./run.sh
--stage 1 --stage 1
...@@ -24,6 +28,9 @@ cmd="./run.sh ...@@ -24,6 +28,9 @@ cmd="./run.sh
--max_tokens ${max_tokens} --max_tokens ${max_tokens}
" "
if [[ -n ${exp_name} ]]; then
cmd="$cmd --exp_name ${exp_name}"
fi
if [[ -n ${exp_tag} ]]; then if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}" cmd="$cmd --exp_tag ${exp_tag}"
fi fi
...@@ -34,5 +41,5 @@ if [[ -n ${extra_parameter} ]]; then ...@@ -34,5 +41,5 @@ if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\"" cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi fi
echo $cmd echo ${cmd}
eval $cmd eval ${cmd}
...@@ -15,9 +15,7 @@ src_lang=en ...@@ -15,9 +15,7 @@ src_lang=en
tgt_lang=de tgt_lang=de
splits=(2019) splits=(2019)
source ~/tools/audio/bin/activate splits=$(echo ${splits[*]} | sed 's/ /_/g')
splits=`echo ${splits[*]} | sed 's/ /,/g'`
cp -r ${vocab_dir}/${asr_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang} cp -r ${vocab_dir}/${asr_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang}
cp -r ${vocab_dir}/${st_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang} cp -r ${vocab_dir}/${st_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang}
...@@ -48,4 +46,3 @@ cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py ...@@ -48,4 +46,3 @@ cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd} [[ $eval -eq 1 ]] && eval ${cmd}
deactivate
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -24,7 +24,6 @@ warmup-updates: 10000 ...@@ -24,7 +24,6 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
......
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
ctc-weight: 0.3
\ No newline at end of file
use-enc-dlcl: True
use-dec-dlcl: True
encoder-attention-type: local
hard-mask-window: 0
gauss-mask-sigma: 3
init-mask-weight: 0
\ No newline at end of file
arch: pdss2t_transformer_s_8
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -14,7 +16,6 @@ report-accuracy: True ...@@ -14,7 +16,6 @@ report-accuracy: True
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-decoder-from: #load-pretrained-decoder-from:
arch: s2t_transformer_s
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -24,14 +25,11 @@ warmup-updates: 10000 ...@@ -24,14 +25,11 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
decoder-layers: 6 decoder-layers: 6
...@@ -40,5 +38,3 @@ encoder-attention-heads: 4 ...@@ -40,5 +38,3 @@ encoder-attention-heads: 4
decoder-embed-dim: 256 decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048 decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pyramid-stages: 4
#pyramid-dropout: 0
pyramid-layers: 2_2_6_2
pyramid-ratios: 2_2_2_2
pyramid-fusion: True
pyramid-fusion-method: all_conv
pyramid-embed-dims: 256_256_256_256
pyramid-ds-method: conv
pyramid-embed-norm: True
pyramid-position-embed: 1_1_1_1
pyramid-kernel-sizes: 5_5_5_5
pyramid-ffn-ratios: 8_8_8_8
pyramid-attn-heads: 4_4_4_4
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -14,7 +31,6 @@ report-accuracy: True ...@@ -14,7 +31,6 @@ report-accuracy: True
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-decoder-from: #load-pretrained-decoder-from:
arch: s2t_transformer_s
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -24,27 +40,16 @@ warmup-updates: 10000 ...@@ -24,27 +40,16 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
decoder-layers: 6 decoder-layers: 6
encoder-attention-heads: 4 encoder-attention-heads: 4
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
decoder-embed-dim: 256 decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048 decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
arch: pdss2t_transformer_s_32
encoder-embed-dim: 256
pyramid-stages: 5
#pyramid-dropout: 0
pyramid-layers: 2_2_3_3_2
pyramid-ratios: 2_2_2_2_2
pyramid-fusion: True
pyramid-fusion-method: all_conv
pyramid-embed-dims: 256_256_256_256_256
pyramid-ds-method: conv
pyramid-embed-norm: True
pyramid-position-embed: 1_1_1_1_1
pyramid-kernel-sizes: 5_5_5_5_5
pyramid-ffn-ratios: 8_8_8_8_8
pyramid-attn-heads: 4_4_4_4_4
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -14,7 +31,6 @@ report-accuracy: True ...@@ -14,7 +31,6 @@ report-accuracy: True
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-decoder-from: #load-pretrained-decoder-from:
arch: s2t_conformer_s
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -24,31 +40,16 @@ warmup-updates: 10000 ...@@ -24,31 +40,16 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
decoder-layers: 6 decoder-layers: 6
encoder-attention-heads: 4 encoder-attention-heads: 4
macaron-style: True decoder-embed-dim: 256
use-cnn-module: True decoder-ffn-embed-dim: 2048
cnn-module-kernel: 31 decoder-attention-heads: 4
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pyramid-stages: 4
#pyramid-dropout: 0
pyramid-layers: 3_3_3_3
pyramid-ratios: 2_2_1_2
pyramid-fusion: True
pyramid-fusion-method: all_conv
pyramid-embed-dims: 256_256_256_256
pyramid-ds-method: conv
pyramid-embed-norm: True
pyramid-position-embed: 1_1_1_1
pyramid-kernel-sizes: 5_5_5_5
pyramid-ffn-ratios: 8_8_8_8
pyramid-attn-heads: 4_4_4_4
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -12,11 +29,8 @@ seed: 1 ...@@ -12,11 +29,8 @@ seed: 1
report-accuracy: True report-accuracy: True
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from: #load-pretrained-decoder-from:
arch: s2t_sate
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -26,37 +40,16 @@ warmup-updates: 10000 ...@@ -26,37 +40,16 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6 decoder-layers: 6
encoder-attention-heads: 4 encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
acoustic-encoder: conformer
adapter: league
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
decoder-embed-dim: 256 decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048 decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
encoder-attention-type: rel_selfattn
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
#max-decoder-relative-length: 20
train-subset: train_st train-subset: train_st
valid-subset: dev_st valid-subset: dev_st
max-epoch: 50 max-epoch: 100
max-update: 100000 max-update: 100000
num-workers: 8 num-workers: 8
...@@ -43,20 +43,29 @@ text-encoder-layers: 6 ...@@ -43,20 +43,29 @@ text-encoder-layers: 6
decoder-layers: 6 decoder-layers: 6
encoder-attention-heads: 4 encoder-attention-heads: 4
macaron-style: True #macaron-style: True
use-cnn-module: True #use-cnn-module: True
cnn-module-kernel: 31 #cnn-module-kernel: 31
acoustic-encoder: conformer #acoustic-encoder: pds
acoustic-encoder: transformer
adapter: league adapter: league
encoder-attention-type: relative encoder-embed-dim: 256
decoder-attention-type: relative pyramid-stages: 4
max-encoder-relative-length: 100 #pyramid-dropout: 0
max-decoder-relative-length: 20 pyramid-layers: 3_3_3_3
pyramid-ratios: 2_2_1_2
pyramid-fusion: True
pyramid-fusion-method: all_conv
pyramid-embed-dims: 256_256_256_256
pyramid-ds-method: conv
pyramid-embed-norm: True
pyramid-position-embed: 1_1_1_1
pyramid-kernel-sizes: 5_5_5_5
pyramid-ffn-ratios: 8_8_8_8
pyramid-attn-heads: 4_4_4_4
decoder-embed-dim: 256 decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048 decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
...@@ -13,7 +13,7 @@ fi ...@@ -13,7 +13,7 @@ fi
n_average=10 n_average=10
beam_size=5 beam_size=5
len_penalty=1.0 len_penalty=1.0
max_tokens=10000 max_tokens=80000
dec_model=checkpoint_best.pt dec_model=checkpoint_best.pt
cmd="./run.sh cmd="./run.sh
...@@ -31,9 +31,9 @@ cmd="./run.sh ...@@ -31,9 +31,9 @@ cmd="./run.sh
if [[ -n ${data_dir} ]]; then if [[ -n ${data_dir} ]]; then
cmd="$cmd --data_dir ${data_dir}" cmd="$cmd --data_dir ${data_dir}"
fi fi
if [[ -n ${test_subset} ]]; then if [[ ${#test_subset[@]} -ne 0 ]]; then
test_subset=`echo ${test_subset[*]} | sed 's/ /,/g'` subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
cmd="$cmd --test_subset ${test_subset}" cmd="$cmd --test_subset ${subsets}"
fi fi
echo $cmd echo $cmd
......
gpu_num=1 gpu_num=4
cmd="sh train.sh"
while : while :
do do
all_devices=$(seq 0 `gpustat | sed '1,2d' | wc -l`); record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`gpustat -p | head -n $line | tail -1 | cut -d '|' -f4 | wc -w` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -eq 0 ]]; then
if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
...@@ -5,17 +5,18 @@ get_devices(){ ...@@ -5,17 +5,18 @@ get_devices(){
device=() device=()
while : while :
do do
record=`mktemp -t temp.record.XXXXXX` record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record gpustat > $record
all_devices=$(seq 0 `cat $record | sed '1,2d' | wc -l`); all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0 count=0
for dev in ${all_devices[@]} for dev in ${all_devices[@]}
do do
line=`expr $dev + 2` line=$((dev + 2))
use=`cat $record | head -n $line | tail -1 | cut -d '|' -f3 | cut -d '/' -f1` use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then if [[ $use -lt 100 ]]; then
device[$count]=$dev device[$count]=$dev
count=`expr $count + 1` count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then if [[ $count -eq $gpu_num ]]; then
break break
fi fi
......
...@@ -32,14 +32,14 @@ src_lang=en ...@@ -32,14 +32,14 @@ src_lang=en
tgt_lang=de tgt_lang=de
lang=${src_lang}-${tgt_lang} lang=${src_lang}-${tgt_lang}
dataset=mustc-v2 dataset=st
task=speech_to_text task=speech_to_text
vocab_type=unigram vocab_type=unigram
asr_vocab_size=5000 asr_vocab_size=5000
vocab_size=10000 vocab_size=10000
share_dict=1 share_dict=1
speed_perturb=0 speed_perturb=0
lcrm=1 lcrm=0
tokenizer=0 tokenizer=0
use_specific_dict=0 use_specific_dict=0
...@@ -48,19 +48,19 @@ specific_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de ...@@ -48,19 +48,19 @@ specific_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de
asr_vocab_prefix=spm_unigram10000_st_share asr_vocab_prefix=spm_unigram10000_st_share
st_vocab_prefix=spm_unigram10000_st_share st_vocab_prefix=spm_unigram10000_st_share
org_data_dir=/media/data/${dataset} org_data_dir=~/st/data/${dataset}
data_dir=~/st/data/${dataset}/st data_dir=~/st/data/${dataset}/st
test_subset=tst-COMMON test_subset=tst-COMMON
# exp # exp
exp_prefix=${time} exp_prefix=$(date "+%m%d")
extra_tag= extra_tag=
extra_parameter= extra_parameter=
exp_tag=baseline exp_tag=baseline
exp_name= exp_name=
# config # config
train_config=train_ctc.yaml train_config=ctc
# training setting # training setting
fp16=1 fp16=1
...@@ -98,10 +98,9 @@ fi ...@@ -98,10 +98,9 @@ fi
. ./local/parse_options.sh || exit 1; . ./local/parse_options.sh || exit 1;
# full path
train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then if [[ -z ${exp_name} ]]; then
exp_name=${exp_prefix}_$(basename ${train_config%.*})_${exp_tag} config_string=${train_config//,/_}
exp_name=${exp_prefix}_${config_string}_${exp_tag}
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
...@@ -120,7 +119,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -120,7 +119,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir}/${lang} ]]; then if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang} mkdir -p ${data_dir}/${lang}
fi fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
...@@ -183,7 +181,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -183,7 +181,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd} [[ $eval -eq 1 ]] && eval ${cmd}
deactivate
fi fi
data_dir=${data_dir}/${lang} data_dir=${data_dir}/${lang}
...@@ -194,7 +191,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -194,7 +191,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -211,12 +208,31 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -211,12 +208,31 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cp ${BASH_SOURCE[0]} ${model_dir} cp ${BASH_SOURCE[0]} ${model_dir}
cp ${PWD}/train.sh ${model_dir} cp ${PWD}/train.sh ${model_dir}
cp ${train_config} ${model_dir}
config_list="${train_config//,/ }"
idx=0
for config in ${config_list[@]}
do
config_path=$pwd_dir/conf/${config}.yaml
if [[ ! -f ${config_path} ]]; then
echo "No config file ${config_path}"
exit 1
fi
cp ${config_path} ${model_dir}
if [[ ${idx} -eq 0 ]]; then
extra_parameter="${extra_parameter}
--train-config ${config_path}"
else
extra_parameter="${extra_parameter}
--train-config${idx} ${config_path}"
fi
idx=$((idx + 1))
done
cmd="python3 -u ${root_dir}/fairseq_cli/train.py cmd="python3 -u ${root_dir}/fairseq_cli/train.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
--train-config ${train_config}
--task ${task} --task ${task}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--skip-invalid-size-inputs-valid-test --skip-invalid-size-inputs-valid-test
...@@ -225,7 +241,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -225,7 +241,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--save-dir ${model_dir} --save-dir ${model_dir}
--tensorboard-logdir ${model_dir}" --tensorboard-logdir ${model_dir}"
if [[ -n ${extra_parameter} ]]; then if [[ -n ${extra_parameter} ]]; then
cmd="${cmd} cmd="${cmd}
${extra_parameter}" ${extra_parameter}"
fi fi
...@@ -287,8 +303,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -287,8 +303,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# save info # save info
log=./history.log log=./history.log
echo "${time} | ${device} | ${data_dir} | ${model_dir} " >> $log echo "${time} | ${device} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
cat $log | tail -n 50 > tmp.log tail -n 50 ${log} > tmp.log
mv tmp.log $log mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device} export CUDA_VISIBLE_DEVICES=${device}
...@@ -296,7 +312,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -296,7 +312,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
eval $cmd eval $cmd
sleep 2s sleep 2s
tail -n `wc -l ${model_dir}/train.log | awk '{print $1+1}'` -f ${model_dir}/train.log tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
fi fi
fi fi
wait wait
...@@ -319,7 +335,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -319,7 +335,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then if [[ ${gpu_num} -eq 0 ]]; then
device=() device=""
else else
source ./local/utils.sh source ./local/utils.sh
device=$(get_devices $gpu_num 0) device=$(get_devices $gpu_num 0)
...@@ -330,8 +346,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -330,8 +346,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
result_file=${model_dir}/decode_result result_file=${model_dir}/decode_result
[[ -f ${result_file} ]] && rm ${result_file} [[ -f ${result_file} ]] && rm ${result_file}
test_subset=(${test_subset//,/ }) test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset}_st subset=${subset}_st
cmd="python ${root_dir}/fairseq_cli/generate.py cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
......
...@@ -3,30 +3,29 @@ ...@@ -3,30 +3,29 @@
# training the model # training the model
gpu_num=8 gpu_num=8
update_freq=2 update_freq=1
max_tokens=20000 max_tokens=40000
exp_name=
extra_tag= extra_tag=
extra_parameter= extra_parameter=
#extra_tag="${extra_tag}" #extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} " #extra_parameter="${extra_parameter} "
#extra_tag="${extra_tag}_encdlcl" exp_tag=
#extra_parameter="${extra_parameter} --use-enc-dlcl"
#config_list=(base)
config_list=(ctc)
#config_list=(sate_ctc)
#config_list=(ctc conformer rpr)
#config_list=(base sate)
#extra_tag="${extra_tag}_decdlcl" #config_list=(pds_base)
#extra_parameter="${extra_parameter} --use-dec-dlcl" #config_list=(pds_base conformer)
# exp full name
exp_name=
exp_tag=baseline train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
train_config=train_ctc.yaml
#train_config=train_ctc_conformer.yaml
#train_config=train_ctc_conformer_rpr.yaml
#train_config=train_ctc_sate.yaml
#train_config=train_ctc_sate_rpr.yaml
#train_config=train_ctc_sate_conformer.yaml
#train_config=train_ctc_sate_conformer_rpr.yaml
cmd="./run.sh cmd="./run.sh
--stage 1 --stage 1
...@@ -50,5 +49,5 @@ if [[ -n ${extra_parameter} ]]; then ...@@ -50,5 +49,5 @@ if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\"" cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi fi
echo $cmd echo ${cmd}
eval $cmd eval ${cmd}
train-subset: train_st
valid-subset: dev_st
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_conformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
#adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
#dropout: 0.1
#activation-fn: relu
#encoder-embed-dim: 256
#encoder-ffn-embed-dim: 2048
#encoder-layers: 12
#decoder-layers: 6
#encoder-attention-heads: 4
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
# conformer
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 31
# relative position encoding
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
#max-decoder-relative-length: 20
train-subset: train_st,train_covost
valid-subset: dev_st
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
acoustic-encoder: transformer
adapter: league
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
MAIN_ROOT=$PWD/../../..
KALDI_ROOT=$MAIN_ROOT/tools/kaldi
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/src/lib
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
. "${MAIN_ROOT}"/tools/activate_python.sh && . "${MAIN_ROOT}"/tools/extra_path.sh
export PATH=$MAIN_ROOT/utils:$MAIN_ROOT/espnet/bin:$PATH
export OMP_NUM_THREADS=1
# check extra module installation
if ! which tokenizer.perl > /dev/null; then
echo "Error: it seems that moses is not installed." >&2
echo "Error: please install moses as follows." >&2
echo "Error: cd ${MAIN_ROOT}/tools && make moses.done" >&2
return 1
fi
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
train-subset: train
valid-subset: valid
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: dlcl_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 20
max-decoder-relative-length: 20
use-enc-dlcl: True
use-dec-dlcl: True
MAIN_ROOT=$PWD/../../..
KALDI_ROOT=$MAIN_ROOT/tools/kaldi
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/src/lib
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
. "${MAIN_ROOT}"/tools/activate_python.sh && . "${MAIN_ROOT}"/tools/extra_path.sh
export PATH=$MAIN_ROOT/utils:$MAIN_ROOT/espnet/bin:$PATH
export OMP_NUM_THREADS=1
# check extra module installation
if ! which tokenizer.perl > /dev/null; then
echo "Error: it seems that moses is not installed." >&2
echo "Error: please install moses as follows." >&2
echo "Error: cd ${MAIN_ROOT}/tools && make moses.done" >&2
return 1
fi
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
train-subset: train_st
valid-subset: dev_st
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
acoustic-encoder: transformer
adapter: league
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
train-subset: train_st
valid-subset: dev_st
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_conformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
#adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
#dropout: 0.1
#activation-fn: relu
#encoder-embed-dim: 256
#encoder-ffn-embed-dim: 2048
#encoder-layers: 12
#decoder-layers: 6
#encoder-attention-heads: 4
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
# conformer
#macaron-style: True
#use-cnn-module: True
#cnn-module-kernel: 31
# relative position encoding
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
#max-decoder-relative-length: 20
train-subset: train_st,train_covost
valid-subset: dev_st
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
acoustic-encoder: transformer
adapter: league
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
MAIN_ROOT=$PWD/../../..
KALDI_ROOT=$MAIN_ROOT/tools/kaldi
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/src/lib
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
. "${MAIN_ROOT}"/tools/activate_python.sh && . "${MAIN_ROOT}"/tools/extra_path.sh
export PATH=$MAIN_ROOT/utils:$MAIN_ROOT/espnet/bin:$PATH
export OMP_NUM_THREADS=1
# check extra module installation
if ! which tokenizer.perl > /dev/null; then
echo "Error: it seems that moses is not installed." >&2
echo "Error: please install moses as follows." >&2
echo "Error: cd ${MAIN_ROOT}/tools && make moses.done" >&2
return 1
fi
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
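For driving these recipes from Python rather than bash, a hedged equivalent of the checks in path.sh (hypothetical helper, not in this commit; paths mirror the script above):

import os
import shutil
import sys

def check_environment(main_root):
    # Mirrors path.sh: require Kaldi's common_path.sh and the moses tokenizer.
    kaldi_common = os.path.join(main_root, "tools/kaldi/tools/config/common_path.sh")
    if not os.path.isfile(kaldi_common):
        sys.exit(f"The standard file {kaldi_common} is not present -> Exit!")
    if shutil.which("tokenizer.perl") is None:
        sys.exit(f"it seems that moses is not installed; run: cd {main_root}/tools && make moses.done")

check_environment(os.path.join(os.getcwd(), "../../.."))  # MAIN_ROOT, as in path.sh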
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import checkpoint_utils
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
S2TTransformerModel,
s2t_transformer_s,
)
@register_model("s2ttransformer_simul_trans")
class SimulS2TTransformerModel(S2TTransformerModel):
"""
Implementation of the paper:
SimulMT to SimulST: Adapting Simultaneous Text Translation to
End-to-End Simultaneous Speech Translation
https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
"""
@staticmethod
def add_args(parser):
super(SimulS2TTransformerModel, SimulS2TTransformerModel).add_args(parser)
parser.add_argument(
"--train-monotonic-only",
action="store_true",
default=False,
help="Only train monotonic attention",
)
# @classmethod
# def build_decoder(cls, args, task, embed_tokens):
# tgt_dict = task.tgt_dict
#
# from examples.simultaneous_translation.models.transformer_monotonic_attention import (
# TransformerMonotonicDecoder,
# )
#
# decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
#
# if getattr(args, "load_pretrained_decoder_from", None):
# decoder = checkpoint_utils.load_pretrained_component_from_model(
# component=decoder, checkpoint=args.load_pretrained_decoder_from
# )
# return decoder
@register_model_architecture(
"s2ttransformer_simul_trans", "s2ttransformer_simul_trans_base"
)
def s2ttransformer_simul_trans_base(args):
s2t_transformer_s(args)
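A quick sanity check (illustrative only; assumes this module has already been imported so the decorators above have run, and uses fairseq's public model registries):

from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_REGISTRY

assert "s2ttransformer_simul_trans" in MODEL_REGISTRY
assert "s2ttransformer_simul_trans_base" in ARCH_MODEL_REGISTRY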
...@@ -38,7 +38,7 @@ log = logging.getLogger(__name__) ...@@ -38,7 +38,7 @@ log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class ASR_Dataset(Dataset): class ASRDataset(Dataset):
""" """
Create a Dataset for MuST-C. Each item is a tuple of the form: Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id, waveform, sample_rate, source utterance, target utterance, speaker_id,
...@@ -70,10 +70,7 @@ class ASR_Dataset(Dataset): ...@@ -70,10 +70,7 @@ class ASR_Dataset(Dataset):
self.data = [] self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename wav_path = wav_root / wav_filename
try: sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate) offset = int(float(segment["offset"]) * sample_rate)
...@@ -185,7 +182,7 @@ def process(args): ...@@ -185,7 +182,7 @@ def process(args):
for split in splits: for split in splits:
print(f"Fetching split {split}...") print(f"Fetching split {split}...")
dataset = ASR_Dataset(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer) dataset = ASRDataset(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer)
is_train_split = split.startswith("train") is_train_split = split.startswith("train")
print("Extracting log mel filter bank features...") print("Extracting log mel filter bank features...")
if is_train_split and args.cmvn_type == "global": if is_train_split and args.cmvn_type == "global":
...@@ -246,7 +243,7 @@ def process(args): ...@@ -246,7 +243,7 @@ def process(args):
if args.task == "st" and args.add_src: if args.task == "st" and args.add_src:
manifest["src_text"] = [] manifest["src_text"] = []
dataset = ASR_Dataset(args.data_root, lang, split, args.speed_perturb, args.tokenizer) dataset = ASRDataset(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
for idx in range(len(dataset)): for idx in range(len(dataset)):
items = dataset.get_fast(idx) items = dataset.get_fast(idx)
for item in items: for item in items:
......
...@@ -25,7 +25,7 @@ log = logging.getLogger(__name__) ...@@ -25,7 +25,7 @@ log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["src_text", "tgt_text"] MANIFEST_COLUMNS = ["src_text", "tgt_text"]
class MTData(Dataset): class MTDataset(Dataset):
""" """
Create a Dataset for MuST-C. Each item is a tuple of the form: Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id, waveform, sample_rate, source utterance, target utterance, speaker_id,
...@@ -72,7 +72,7 @@ def process(args): ...@@ -72,7 +72,7 @@ def process(args):
is_train_split = split.startswith("train") is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS} manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = MTData(args.data_root, src_lang, tgt_lang, split) dataset = MTDataset(args.data_root, src_lang, tgt_lang, split)
for src_text, tgt_text in tqdm(dataset): for src_text, tgt_text in tqdm(dataset):
if args.lowercase_src: if args.lowercase_src:
src_text = src_text.lower() src_text = src_text.lower()
......
...@@ -75,10 +75,7 @@ class MUSTC(Dataset): ...@@ -75,10 +75,7 @@ class MUSTC(Dataset):
self.data = [] self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename wav_path = wav_root / wav_filename
try: sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate) offset = int(float(segment["offset"]) * sample_rate)
......
...@@ -38,14 +38,15 @@ log = logging.getLogger(__name__) ...@@ -38,14 +38,15 @@ log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class ST_Dataset(Dataset): class STDataset(Dataset):
""" """
Create a Dataset for MuST-C. Each item is a tuple of the form: Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id, waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id utterance_id
""" """
def __init__(self, root: str, src_lang, tgt_lang: str, split: str, speed_perturb: bool = False, tokenizer: bool = False) -> None: def __init__(self, root: str, src_lang, tgt_lang: str, split: str,
speed_perturb: bool = False, tokenizer: bool = False) -> None:
_root = Path(root) / f"{src_lang}-{tgt_lang}" / split _root = Path(root) / f"{src_lang}-{tgt_lang}" / split
wav_root, txt_root = _root / "wav", _root / "txt" wav_root, txt_root = _root / "wav", _root / "txt"
if tokenizer: if tokenizer:
...@@ -71,10 +72,7 @@ class ST_Dataset(Dataset): ...@@ -71,10 +72,7 @@ class ST_Dataset(Dataset):
self.data = [] self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename wav_path = wav_root / wav_filename
try: sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
except TypeError:
sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate) offset = int(float(segment["offset"]) * sample_rate)
...@@ -194,7 +192,7 @@ def process(args): ...@@ -194,7 +192,7 @@ def process(args):
for split in splits: for split in splits:
print(f"Fetching split {split}...") print(f"Fetching split {split}...")
dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer) dataset = STDataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer)
is_train_split = split.startswith("train") is_train_split = split.startswith("train")
print("Extracting log mel filter bank features...") print("Extracting log mel filter bank features...")
if is_train_split and args.cmvn_type == "global": if is_train_split and args.cmvn_type == "global":
...@@ -255,7 +253,7 @@ def process(args): ...@@ -255,7 +253,7 @@ def process(args):
if args.task == "st" and args.add_src: if args.task == "st" and args.add_src:
manifest["src_text"] = [] manifest["src_text"] = []
dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer) dataset = STDataset(args.data_root, src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer)
for idx in range(len(dataset)): for idx in range(len(dataset)):
items = dataset.get_fast(idx) items = dataset.get_fast(idx)
for item in items: for item in items:
......
...@@ -10,6 +10,7 @@ import torch.nn.functional as F ...@@ -10,6 +10,7 @@ import torch.nn.functional as F
from fairseq import metrics, utils from fairseq import metrics, utils
from fairseq.criterions import register_criterion from fairseq.criterions import register_criterion
from fairseq.data.data_utils import post_process from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
...@@ -23,14 +24,15 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -23,14 +24,15 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0 self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad() self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos() self.eos_idx = task.target_dictionary.eos()
self.report_accuracy = True self.report_accuracy = True
assert 0 <= ctc_weight <= 1 assert 0 <= ctc_weight
self.ctc_weight = ctc_weight self.ctc_weight = ctc_weight
if self.ctc_weight > 0: if self.ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC needs a source dictionary." assert getattr(task, "src_dict", None) is not None, "CTC needs a source dictionary."
self.zero_infinity = True
self.post_process = post_process self.post_process = post_process
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
...@@ -54,7 +56,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -54,7 +56,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
default="letter", default="letter",
type=str, type=str,
help="how to post process predictions into words. can be letter, " help="how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. " "word-piece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options", "See fairseq.data.data_utils.post_process() for full list of options",
) )
...@@ -72,7 +74,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -72,7 +74,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
) )
# net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = ( sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"] sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
...@@ -100,10 +101,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -100,10 +101,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
def compute_ctc_loss(self, model, sample, encoder_out): def compute_ctc_loss(self, model, sample, encoder_out):
transcript = sample["transcript"] transcript = sample["transcript"]
ctc_logit = model.encoder.compute_ctc_logit(encoder_out) ctc_logit = encoder_out["ctc_logit"][0]
lprobs = model.get_normalized_probs( lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True [ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder ).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
non_padding_mask = ~encoder_out["encoder_padding_mask"][0] non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1) input_lengths = non_padding_mask.long().sum(-1)
...@@ -114,14 +117,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -114,14 +117,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
transcript_lengths = pad_mask.sum(-1) transcript_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False): with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss( loss = self.ctc_loss(
lprobs, lprobs,
targets_flat, targets_flat,
input_lengths, input_lengths,
transcript_lengths, transcript_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
) )
logging_output = { logging_output = {
...@@ -141,9 +141,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -141,9 +141,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
wv_errs = 0 wv_errs = 0
for lp, t, inp_l in zip( for lp, t, inp_l in zip(
lprobs_t, lprobs_t,
sample["target_label"] sample["target_label"] if "target_label" in sample else sample["target"],
if "target_label" in sample
else sample["target"],
input_lengths, input_lengths,
): ):
lp = lp[:inp_l].unsqueeze(0) lp = lp[:inp_l].unsqueeze(0)
...@@ -239,6 +237,44 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -239,6 +237,44 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
else float("nan"), else float("nan"),
) )
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
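# Worked numbers for the derived metrics above (illustrative): with
# _c_errors = 5 and _c_total = 200 summed across logging outputs,
# uer = 5 * 100.0 / 200 = 2.5; wer and raw_wer follow the same pattern
# using _w_errors / _wv_errors over _w_total.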
@staticmethod @staticmethod
def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed() -> bool:
""" """
......
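To make the shapes consumed by compute_ctc_loss above concrete, a minimal standalone sketch (illustrative numbers only) of the torch.nn.CTCLoss call:

import torch

T, B, V = 50, 4, 100                               # time, batch, vocab (blank = 0)
lprobs = torch.randn(T, B, V).log_softmax(-1)      # (T, B, C) from the encoder
input_lengths = torch.full((B,), T, dtype=torch.long)
targets_flat = torch.randint(1, V, (B * 10,))      # non-blank targets, flattened
transcript_lengths = torch.full((B,), 10, dtype=torch.long)

ctc = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
loss = ctc(lprobs, targets_flat, input_lengths, transcript_lengths)
print(loss.item())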
...@@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module): ...@@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module):
self.onnx_trace = False self.onnx_trace = False
self.adaptive_softmax = None self.adaptive_softmax = None
def forward(self, prev_output_tokens, encoder_out=None, **kwargs): def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
""" """
Args: Args:
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .berard import * # noqa from .berard import * # noqa
from .ctc import * # noqa
from .convtransformer import * # noqa from .convtransformer import * # noqa
from .s2t_transformer import * # noqa from .s2t_transformer import * # noqa
from .s2t_conformer import * # noqa from .s2t_conformer import * # noqa
from .pys2t_transformer import * # noqa from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa from .s2t_sate import * # noqa
...@@ -7,8 +7,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -7,8 +7,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq import checkpoint_utils, utils from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
FairseqEncoderDecoderModel, FairseqEncoderDecoderModel,
...@@ -28,6 +28,7 @@ class ConvTransformerModel(FairseqEncoderDecoderModel): ...@@ -28,6 +28,7 @@ class ConvTransformerModel(FairseqEncoderDecoderModel):
Transformer-based Speech translation model from ESPNet-ST Transformer-based Speech translation model from ESPNet-ST
https://arxiv.org/abs/2004.10234 https://arxiv.org/abs/2004.10234
""" """
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):
super().__init__(encoder, decoder) super().__init__(encoder, decoder)
...@@ -303,11 +304,11 @@ class ConvTransformerEncoder(FairseqEncoder): ...@@ -303,11 +304,11 @@ class ConvTransformerEncoder(FairseqEncoder):
x = self.embed_scale * x x = self.embed_scale * x
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long()
input_lengths = torch.min( input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to(
(src_lengths.float() / subsampling_factor).ceil().long(), input_len_0.device
x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long()
) )
input_lengths = torch.min(input_len_0, input_len_1)
encoder_padding_mask = lengths_to_padding_mask(input_lengths) encoder_padding_mask = lengths_to_padding_mask(input_lengths)
......
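To make the length arithmetic in the ConvTransformer hunk above concrete (illustrative numbers):

import math

max_seq_len, output_seq_len = 100, 25
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)  # 4
src_length = 90
out_len = min(math.ceil(src_length / subsampling_factor),           # 23
              output_seq_len)                                       # clipped to x.size(0), as above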
#!/usr/bin/env python3
import logging
import torch
import torch.nn as nn
from fairseq.modules import (
FairseqDropout,
LayerNorm,
)
logger = logging.getLogger(__name__)
class CTC(nn.Module):
def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
super(CTC, self).__init__()
self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)
nn.init.normal_(
self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5
)
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
)
self.need_layernorm = need_layernorm
if self.need_layernorm:
self.LayerNorm = LayerNorm(embed_dim)
def forward(self, x):
if self.need_layernorm:
x = self.LayerNorm(x)
x = self.ctc_projection(self.ctc_dropout_module(x))
return x
@staticmethod
def softmax(ctc_logit, temperature=1.0):
return torch.nn.functional.softmax(ctc_logit / temperature, dim=-1)
@staticmethod
def log_softmax(ctc_logit, temperature=1.0):
return torch.nn.functional.log_softmax(ctc_logit / temperature, dim=-1)
@staticmethod
def argmax(ctc_logit):
return torch.argmax(ctc_logit, dim=-1)
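A quick usage sketch for the CTC head above (shapes are made up):

import torch

head = CTC(embed_dim=256, dictionary_size=1000, dropout=0.1, need_layernorm=True)
x = torch.randn(50, 4, 256)        # T x B x C encoder states
logit = head(x)                    # T x B x V logits for the CTC loss
lprobs = CTC.log_softmax(logit)    # log-probabilities (temperature = 1.0)
best_path = CTC.argmax(logit)      # greedy path, T x B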
...@@ -1812,7 +1812,8 @@ def emformer_encoder(klass): ...@@ -1812,7 +1812,8 @@ def emformer_encoder(klass):
def forward(self, src_tokens, src_lengths): def forward(self, src_tokens, src_lengths):
encoder_out = super().forward(src_tokens, src_lengths) encoder_out = super().forward(src_tokens, src_lengths)
(output, encoder_padding_masks, [], _) = encoder_out["encoder_out"][0] output = encoder_out["encoder_out"][0]
encoder_padding_masks = encoder_out["encoder_padding_mask"][0]
# This is because that in the original implementation # This is because that in the original implementation
# the output didn't consider the last segment as right context. # the output didn't consider the last segment as right context.
......
...@@ -7,19 +7,18 @@ from functools import reduce ...@@ -7,19 +7,18 @@ from functools import reduce
import torch.nn as nn import torch.nn as nn
from fairseq import checkpoint_utils, utils from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.models.speech_to_text import S2TTransformerModel from fairseq.models.speech_to_text import CTC, S2TTransformerModel
from fairseq.modules import ( from fairseq.modules import (
FairseqDropout, FairseqDropout,
LayerNorm, LayerNorm,
PositionalEmbedding, PositionalEmbedding,
PyramidTransformerEncoderLayer, PDSTransformerEncoderLayer,
MultiheadAttention, MultiheadAttention,
DownSampleConvolutionModule DownSampleConvolutionModule
) )
...@@ -34,20 +33,22 @@ def lengths_to_padding_mask_with_maxlen(lens, max_length): ...@@ -34,20 +33,22 @@ def lengths_to_padding_mask_with_maxlen(lens, max_length):
return mask return mask
class Permute_120(nn.Module): class Permute120(nn.Module):
def forward(self, x): @staticmethod
def forward(x):
return x.permute(1, 2, 0) return x.permute(1, 2, 0)
class Permute_201(nn.Module): class Permute201(nn.Module):
def forward(self, x): @staticmethod
def forward(x):
return x.permute(2, 0, 1) return x.permute(2, 0, 1)
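# Shape note: Permute120 maps (T, B, C) -> (B, C, T) for nn.Conv1d, and
# Permute201 maps (B, C, T) back to (T, B, C) for the attention layers.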
class ReducedEmbed(nn.Module): class Downsampling(nn.Module):
# Reduced embedding for Pyramid Transformer # down-sampling module
def __init__( def __init__(
self, self,
reduced_way: str, reduced_way: str,
...@@ -62,6 +63,7 @@ class ReducedEmbed(nn.Module): ...@@ -62,6 +63,7 @@ class ReducedEmbed(nn.Module):
self.stride = stride self.stride = stride
self.reduced_way = reduced_way self.reduced_way = reduced_way
# default conv
if self.reduced_way == "conv": if self.reduced_way == "conv":
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding), nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
...@@ -76,25 +78,16 @@ class ReducedEmbed(nn.Module): ...@@ -76,25 +78,16 @@ class ReducedEmbed(nn.Module):
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding), nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
nn.ReLU() nn.ReLU()
) )
elif self.reduced_way == "fuse":
self.conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
)
self.pool_conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=1),
)
else: else:
logger.error("Unsupported reduced way!") logger.error("Unsupported reduced way!")
self.embed_norm = embed_norm self.embed_norm = embed_norm
if self.embed_norm: if self.embed_norm:
# if self.reduced_way == "fuse":
# self.in_norm = LayerNorm(in_channels)
self.norm = LayerNorm(out_channels) self.norm = LayerNorm(out_channels)
def forward(self, x, lengths): def forward(self, x, lengths):
seq_len, bsz, dim = x.size() seq_len, bsz, dim = x.size()
# assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride) assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
# mask batch padding # mask batch padding
if not torch.all(lengths == seq_len): if not torch.all(lengths == seq_len):
...@@ -113,22 +106,14 @@ class ReducedEmbed(nn.Module): ...@@ -113,22 +106,14 @@ class ReducedEmbed(nn.Module):
x = self.conv(self.in_norm(x)) x = self.conv(self.in_norm(x))
x = x.permute(2, 0, 1) # seq_len, bsz, dim x = x.permute(2, 0, 1) # seq_len, bsz, dim
else: else:
# if self.reduced_way == "fuse":
# x = self.in_norm(x)
x = x.permute(1, 2, 0) # B * D * T x = x.permute(1, 2, 0) # B * D * T
origin_x = x
x = self.conv(x) x = self.conv(x)
if self.reduced_way == "glu": if self.reduced_way == "glu":
x = self.glu(x) x = self.glu(x)
if self.reduced_way == "fuse":
x2 = nn.functional.adaptive_avg_pool1d(origin_x, x.size(-1))
x2 = self.pool_conv(x2)
x = x + x2
x = x.permute(2, 0, 1) # T * B * D x = x.permute(2, 0, 1) # T * B * D
if self.embed_norm: if self.embed_norm:
x = self.norm(x) x = self.norm(x)
# assert max(lengths) == x.size(0)
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0)) padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
# mask batch padding # mask batch padding
...@@ -142,43 +127,9 @@ class ReducedEmbed(nn.Module): ...@@ -142,43 +127,9 @@ class ReducedEmbed(nn.Module):
return x, lengths, padding_mask return x, lengths, padding_mask
class BlockFuse(nn.Module): @register_model("pdss2t_transformer")
class PDSS2TTransformerModel(S2TTransformerModel):
def __init__(self, embed_dim, prev_embed_dim, num_head, dropout): """Progressive down-sampling for acoustic encoding."""
super().__init__()
self.pre_layer_norm = LayerNorm(prev_embed_dim)
self.out_layer_norm = LayerNorm(embed_dim)
self.attn = MultiheadAttention(
embed_dim,
num_head,
kdim=prev_embed_dim,
vdim=prev_embed_dim,
dropout=dropout,
encoder_decoder_attention=True,
)
def forward(self, x, state, padding):
state = self.pre_layer_norm(state)
state, attn = self.attn(
query=x,
key=state,
value=state,
key_padding_mask=padding,
static_kv=True,
)
state = self.out_layer_norm(state)
return state
@register_model("pys2t_transformer")
class PYS2TTransformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension as well as downsample input
sequence for computational efficiency."""
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):
super().__init__(encoder, decoder) super().__init__(encoder, decoder)
...@@ -377,6 +328,7 @@ class PYS2TTransformerModel(S2TTransformerModel): ...@@ -377,6 +328,7 @@ class PYS2TTransformerModel(S2TTransformerModel):
help='decoder layer history type' help='decoder layer history type'
) )
# local modeling
parser.add_argument( parser.add_argument(
'--hard-mask-window', '--hard-mask-window',
type=float, type=float,
...@@ -437,93 +389,86 @@ class PYS2TTransformerModel(S2TTransformerModel): ...@@ -437,93 +389,86 @@ class PYS2TTransformerModel(S2TTransformerModel):
help="Kernel size of convolution module.", help="Kernel size of convolution module.",
) )
# pyramid setting # pds setting
parser.add_argument( parser.add_argument(
"--pyramid-stages", "--pds-stages",
type=int, type=int,
help="the number of the stage", help="the number of the stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-layers", "--pds-layers",
type=str, type=str,
help="the number of the encoder layers", help="the number of the encoder layers in each stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-sr-ratios", "--pds-ratios",
type=str, type=str,
help="the ratio of the subsampling", help="the ratio of the down-sampling in each stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-attn-sample-ratio", "--pds-ds-method",
type=str, type=str,
help="the ratio of the subsampling in the self attention module", choices=["glu", "conv", "proj", "fusion"],
help="the down-sampling method",
) )
parser.add_argument( parser.add_argument(
"--pyramid-reduced-embed", "--pds-embed-dims",
type=str, type=str,
choices=["glu", "conv", "proj", "fuse"], help="the embedding dimension in each stage",
help="the reduced way of the embedding",
) )
parser.add_argument( parser.add_argument(
"--pyramid-embed-norm", "--pds-kernel-sizes",
action="store_true", type=str,
help="use layer norm in reduced embedding", help="the kernel size of the down-sampling module in each stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-block-attn", "--pds-embed-norm",
action="store_true", action="store_true",
help="use block attention", help="use layer norm in the down-sampling module",
) )
parser.add_argument( parser.add_argument(
"--pyramid-fuse-way", "--pds-position-embed",
type=str, type=str,
help="fused way for block attention", help="use the position embedding or not before each encoding",
) )
parser.add_argument( parser.add_argument(
"--pyramid-position-embed", "--pds-attn-heads",
type=str, type=str,
help="use the position embedding or not", help="the number of the attention heads in each stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-embed-dims", "--pds-attn-ds-ratio",
type=str, type=str,
help="the embedding dimension", help="the ratio of the down-sampling in the self attention module",
) )
parser.add_argument( parser.add_argument(
"--pyramid-kernel-sizes", "--pds-ffn-ratios",
type=str, type=str,
help="the kernel size of the reduced embedding", help="the ratio of the ffn in each stage",
) )
parser.add_argument( parser.add_argument(
"--pyramid-ffn-ratios", "--pds-fusion",
type=str, action="store_true",
help="the ratio of the ffn", help="use the representation fusion method",
) )
parser.add_argument( parser.add_argument(
"--pyramid-heads", "--pds-fusion-method",
type=str, type=str,
help="the number of the attention heads", help="the fusion method",
) )
parser.add_argument( parser.add_argument(
"--pyramid-fuse", "--pds-dropout",
action="store_true",
help="fuse the features in multiple stages",
)
parser.add_argument(
"--pyramid-dropout",
type=float, type=float,
help="dropout of the pyramid transformer", help="dropout in each stage",
)
parser.add_argument(
"--ctc-layer",
type=int,
help="the position of the ctc loss",
) )
pass pass
@classmethod @classmethod
def build_encoder(cls, args, task=None, embed_tokens=None): def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = PyS2TTransformerEncoder(args, task, embed_tokens) encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None): if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model( encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
...@@ -535,8 +480,8 @@ class PYS2TTransformerModel(S2TTransformerModel): ...@@ -535,8 +480,8 @@ class PYS2TTransformerModel(S2TTransformerModel):
return encoder return encoder
class PyS2TTransformerEncoder(FairseqEncoder): class PDSS2TTransformerEncoder(FairseqEncoder):
"""Speech-to-text Pyramid Transformer encoder""" """Progressive Down-sampling for Acoustic Encoding"""
def __init__(self, args, task=None, embed_tokens=None): def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None) super().__init__(None)
...@@ -548,157 +493,144 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -548,157 +493,144 @@ class PyS2TTransformerEncoder(FairseqEncoder):
self.dropout = FairseqDropout( self.dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__ p=args.dropout, module_name=self.__class__.__name__
) )
self.pyramid_dropout = FairseqDropout( self.pds_dropout = FairseqDropout(
p=getattr(args, "pyramid_dropout", args.dropout), module_name=self.__class__.__name__ p=getattr(args, "pds_dropout", args.dropout), module_name=self.__class__.__name__
) )
self.pyramid_stages = getattr(args, "pyramid_stages", 4) self.pds_stages = getattr(args, "pds_stages", 4)
self.pds_layers = [int(n) for n in args.pds_layers.split("_")]
self.pds_ratios = [int(n) for n in args.pds_ratios.split("_")]
self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")] # down-sampling module
self.pyramid_sr_ratios = [int(n) for n in args.pyramid_sr_ratios.split("_")] self.pds_ds_method = args.pds_ds_method
self.pds_embed_dims = [int(n) for n in args.pds_embed_dims.split("_")]
self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
self.pds_embed_norm = args.pds_embed_norm
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
if self.attn_type == "reduced": if self.attn_type == "reduced":
self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")] self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
else: else:
self.pyramid_attn_sample_ratios = None self.pds_attn_ds_ratios = None
self.pyramid_embed_dims = [int(n) for n in args.pyramid_embed_dims.split("_")]
self.pyramid_position_embed = [int(n) for n in args.pyramid_position_embed.split("_")] self.fusion = getattr(args, "pds_fusion", False)
self.pyramid_kernel_sizes = [int(n) for n in args.pyramid_kernel_sizes.split("_")] self.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
self.pyramid_ffn_ratios = [int(n) for n in args.pyramid_ffn_ratios.split("_")] self.pds_fusion_transform = "conv"
self.pyramid_heads = [int(n) for n in args.pyramid_heads.split("_")] if len(self.pds_fusion_method.split("_")) == 2:
self.pyramid_reduced_embed = args.pyramid_reduced_embed items = self.pds_fusion_method.split("_")
self.pyramid_embed_norm = args.pyramid_embed_norm self.pds_fusion_method = items[0]
self.pds_fusion_transform = items[1]
self.pyramid_block_attn = getattr(args, "pyramid_block_attn", False)
self.fuse = getattr(args, "pyramid_fuse", False) fusion_stages_num = 0
self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "all_conv") if self.fusion:
self.pyramid_fuse_transform = "conv" if self.pds_fusion_method == "all":
if len(self.pyramid_fuse_way.split("_")) == 2: fusion_stages_num = self.pds_stages
items = self.pyramid_fuse_way.split("_") elif self.pds_fusion_method == "same":
self.pyramid_fuse_way = items[0] for dim in self.pds_embed_dims:
self.pyramid_fuse_transform = items[1]
fuse_stages_num = 0
if self.fuse:
if self.pyramid_fuse_way == "all":
fuse_stages_num = self.pyramid_stages
elif self.pyramid_fuse_way == "same":
for dim in self.pyramid_embed_dims:
if dim == self.embed_dim: if dim == self.embed_dim:
fuse_stages_num += 1 fusion_stages_num += 1
else: else:
logger.error("Unsupported fusion!") logger.error("Unsupported fusion!")
if fuse_stages_num == 1: if fusion_stages_num == 1:
fuse_stages_num = 0 fusion_stages_num = 0
self.fuse_stages_num = fuse_stages_num self.fusion_stages_num = fusion_stages_num
for i in range(self.pyramid_stages): for i in range(self.pds_stages):
num_layers = self.pyramid_layers[i] num_layers = self.pds_layers[i]
sr_ratio = self.pyramid_sr_ratios[i] ds_ratio = self.pds_ratios[i]
attn_sample_ratio = self.pyramid_attn_sample_ratios[i] if self.attn_type == "reduced" else -1
embed_dim = self.pyramid_embed_dims[i] embed_dim = self.pds_embed_dims[i]
kernel_size = self.pyramid_kernel_sizes[i] kernel_size = self.pds_kernel_sizes[i]
ffn_ratio = self.pyramid_ffn_ratios[i] use_pos_embed = self.pds_position_embed[i]
num_head = self.pyramid_heads[i]
use_pos_embed = self.pyramid_position_embed[i] num_head = self.pds_attn_heads[i]
logger.info("The stage {}: layer {}, sample ratio {}, attention sample ratio {}, embed dim {}, " attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
"kernel size {}, ffn ratio {}, num head {}, position embed {}, " ffn_ratio = self.pds_ffn_ratios[i]
"fuse {}, fuse way {}, transformer {}.".
format(i, num_layers, sr_ratio, attn_sample_ratio, logger.info("The stage {}: layer {}, down-sample ratio {}, embed dim {}, "
embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed, "kernel size {}, position embed {}, ffn ratio {}, num head {}, "
self.fuse, self.pyramid_fuse_way, self.pyramid_fuse_transform)) "fusion {}, fusion method {}, fusion transform {}.".
format(i, num_layers, ds_ratio, embed_dim,
kernel_size, use_pos_embed, ffn_ratio, num_head,
self.fusion, self.pds_fusion_method, self.pds_fusion_transform))
if i == 0: if i == 0:
self.embed_scale = math.sqrt(embed_dim) self.embed_scale = math.sqrt(embed_dim)
if args.no_scale_embedding: if args.no_scale_embedding:
self.embed_scale = 1.0 self.embed_scale = 1.0
reduced_embed = ReducedEmbed( downsampling = Downsampling(
self.pyramid_reduced_embed, self.pds_ds_method,
self.pyramid_embed_norm, self.pds_embed_norm,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i-1], args.input_feat_per_channel * args.input_channels if i == 0 else self.pds_embed_dims[i-1],
embed_dim, embed_dim,
kernel_sizes=kernel_size, kernel_sizes=kernel_size,
stride=sr_ratio, stride=ds_ratio,
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
) )
if use_pos_embed: if use_pos_embed:
pos_embed = PositionalEmbedding( pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx)
args.max_source_positions, embed_dim,
self.padding_idx, pos_emb_type=self.attn_type
)
else: else:
pos_embed = None pos_embed = None
block = nn.ModuleList([ stage = nn.ModuleList([
PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio) PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
for _ in range(num_layers)]) for _ in range(num_layers)])
block_fuse = None fusion_pre_layer_norm = None
if self.pyramid_block_attn: fusion_post_layer_norm = None
if i != self.pyramid_stages - 1: fusion_downsampling = None
block_fuse = BlockFuse(self.embed_dim, embed_dim, if fusion_stages_num != 0:
self.pyramid_heads[-1], dropout=args.dropout if self.pds_fusion_method == "all" or (
) self.pds_fusion_method == "same" and self.embed_dim == embed_dim
fuse_pre_layer_norm = None
fuse_post_layer_norm = None
down_sample = None
if fuse_stages_num != 0:
if self.pyramid_fuse_way == "all" or (
self.pyramid_fuse_way == "same" and self.embed_dim == embed_dim
): ):
if i != self.pyramid_stages - 1: if i != self.pds_stages - 1:
shrink_size = reduce(lambda a, b: a * b, self.pyramid_sr_ratios[i + 1:]) ratio = reduce(lambda a, b: a * b, self.pds_ratios[i + 1:])
else: else:
shrink_size = 1 ratio = 1
fuse_pre_layer_norm = LayerNorm(embed_dim) fusion_pre_layer_norm = LayerNorm(embed_dim)
fuse_post_layer_norm = LayerNorm(self.embed_dim) fusion_post_layer_norm = LayerNorm(self.embed_dim)
if self.pyramid_fuse_transform == "conv": # default conv
down_sample = nn.Sequential( if self.pds_fusion_transform == "conv":
Permute_120(), fusion_downsampling = nn.Sequential(
Permute120(),
nn.Conv1d(embed_dim, self.embed_dim, nn.Conv1d(embed_dim, self.embed_dim,
kernel_size=shrink_size, kernel_size=ratio,
stride=shrink_size), stride=ratio),
nn.BatchNorm1d(self.embed_dim), nn.BatchNorm1d(self.embed_dim),
nn.ReLU(), nn.ReLU(),
Permute_201(), Permute201(),
) )
elif self.pyramid_fuse_transform == "pool": elif self.pds_fusion_transform == "pool":
down_sample = nn.Sequential( fusion_downsampling = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1), nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
nn.BatchNorm1d(self.embed_dim), nn.BatchNorm1d(self.embed_dim),
nn.ReLU(), nn.ReLU(),
Permute_201(), Permute201(),
) )
elif self.pyramid_fuse_transform == "conv2": elif self.pds_fusion_transform == "conv2":
down_sample = DownSampleConvolutionModule( fusion_downsampling = DownSampleConvolutionModule(
self.embed_dim, self.embed_dim,
kernel_size=shrink_size, kernel_size=ratio,
stride=shrink_size, stride=ratio,
) )
else: else:
logger.error("Unsupported fusion transform!") logger.error("Unsupported fusion transform!")
setattr(self, f"reduced_embed{i + 1}", reduced_embed) setattr(self, f"downsampling{i + 1}", downsampling)
setattr(self, f"pos_embed{i + 1}", pos_embed) setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"block{i + 1}", block) setattr(self, f"stage{i + 1}", stage)
setattr(self, f"block_fuse{i + 1}", block_fuse) setattr(self, f"fusion_downsampling{i + 1}", fusion_downsampling)
setattr(self, f"down_sample{i + 1}", down_sample) setattr(self, f"fusion_pre_layer_norm{i + 1}", fusion_pre_layer_norm)
setattr(self, f"fuse_pre_layer_norm{i + 1}", fuse_pre_layer_norm) setattr(self, f"fusion_post_layer_norm{i + 1}", fusion_post_layer_norm)
setattr(self, f"fuse_post_layer_norm{i + 1}", fuse_post_layer_norm)
if self.pyramid_block_attn:
self.block_layer_norm = LayerNorm(self.embed_dim)
if args.encoder_normalize_before: if self.fusion_stages_num != 0:
self.layer_norm = LayerNorm(self.embed_dim) self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
else: self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
self.layer_norm = None
if self.fuse_stages_num != 0 or self.pyramid_block_attn:
self.fuse_weight = nn.Parameter(torch.Tensor(fuse_stages_num).fill_(1.0))
self.fuse_weight.data = self.fuse_weight.data / self.fuse_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \ self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and (("ctc" in getattr(args, "criterion", False)) and
...@@ -706,65 +638,95 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -706,65 +638,95 @@ class PyS2TTransformerEncoder(FairseqEncoder):
if self.use_ctc: if self.use_ctc:
self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
embed_dim = self.pds_embed_dims[-1]
if self.inter_ctc:
ctc_layer = self.ctc_layer
for i in range(self.pds_stages):
ctc_layer -= self.pds_layers[i]
if ctc_layer <= 0:
embed_dim = self.pds_embed_dims[i]
break
self.ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary:
self.ctc.ctc_projection.weight = embed_tokens.weight
if task.source_dictionary == task.target_dictionary and getattr(args, "share_all_embeddings", False): if args.encoder_normalize_before:
self.ctc_projection = nn.Linear( self.layer_norm = LayerNorm(self.embed_dim)
embed_tokens.weight.shape[1], else:
embed_tokens.weight.shape[0], self.layer_norm = None
bias=False,
) self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.ctc_projection.weight = embed_tokens.weight self.dis = 2
else: self.cos_sim = dict()
embed_dim = self.pyramid_embed_dims[-1]
def add_to_dict(self, x, dis, idx):
if self.inter_ctc: sim = 0
ctc_layer = self.ctc_layer seq_len = x.size(0)
for i in range(self.pyramid_stages): cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
ctc_layer -= self.pyramid_layers[i] for i in range(dis, seq_len - dis):
if ctc_layer <= 0: a = x[i, :, :]
embed_dim = self.pyramid_embed_dims[i] for j in range(-dis, dis + 1):
break if j == 0:
self.ctc_layer_norm = LayerNorm(embed_dim) continue
b = x[i + j, :, :]
self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False) sim_j = cos(a, b).mean()
nn.init.normal_( sim += sim_j
self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5 sim = sim / 2 / dis / (seq_len - 2 * dis)
)
self.ctc_dropout_module = FairseqDropout( if idx not in self.cos_sim:
p=args.dropout, module_name=self.__class__.__name__ self.cos_sim[idx] = []
) self.cos_sim[idx].append(float(sim))
self.softmax = nn.Softmax(dim=-1)
def forward(self, src_tokens, src_lengths): def forward(self, src_tokens, src_lengths):
batch = src_tokens.size(0) batch = src_tokens.size(0)
x = src_tokens.transpose(0, 1) x = src_tokens.transpose(0, 1)
input_lengths = src_lengths input_lengths = src_lengths
# pad the length to a multiple of the total down-sampling ratio # pad the length to a multiple of the total down-sampling ratio
max_len = x.size(0) max_len = x.size(0)
length = reduce(lambda a, b: a*b, self.pyramid_sr_ratios) length = reduce(lambda a, b: a*b, self.pds_ratios)
padding_to_len = (length - max_len % length) padding_to_len = (length - max_len % length)
if padding_to_len > 0: if padding_to_len > 0:
padding_for_pyramid = x.new_zeros((padding_to_len, batch, x.size(2))) padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
x = torch.cat([x, padding_for_pyramid], dim=0) x = torch.cat([x, padding_for_pds], dim=0)
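# e.g. with pds_ratios = 2_2_2_2 the total ratio is 16, so a 70-frame batch
# is padded to 80 frames and halves cleanly: 80 -> 40 -> 20 -> 10 -> 5.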
# gather cosine similarity
cos_sim_idx = -1
dis = self.dis
if self.gather_cos_sim:
self.add_to_dict(x, dis, cos_sim_idx)
layer_idx = 0 layer_idx = 0
ctc_logit = None ctc_logit = None
prev_state = [] prev_state = []
prev_padding = [] prev_padding = []
for i in range(self.pyramid_stages): for i in range(self.pds_stages):
reduced_embed = getattr(self, f"reduced_embed{i + 1}") downsampling = getattr(self, f"downsampling{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}") pos_embed = getattr(self, f"pos_embed{i + 1}")
block = getattr(self, f"block{i + 1}") stage = getattr(self, f"stage{i + 1}")
x, input_lengths, encoder_padding_mask = downsampling(x, input_lengths)
x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths) # gather cosine similarity
cos_sim_idx += 10
cos_sim_idx = cos_sim_idx // 10 * 10 - 1
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
# add the position encoding and dropout # add the position encoding and dropout
if pos_embed: if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1) positions = pos_embed(encoder_padding_mask).transpose(0, 1)
#if self.attn_type != "rel_selfattn":
# x += positions
x += positions x += positions
positions = self.dropout(positions) positions = self.dropout(positions)
else: else:
...@@ -773,82 +735,77 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -773,82 +735,77 @@ class PyS2TTransformerEncoder(FairseqEncoder):
if i == 0: if i == 0:
x = self.dropout(x) x = self.dropout(x)
else: else:
x = self.pyramid_dropout(x) x = self.pds_dropout(x)
for layer in block: for layer in stage:
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1 layer_idx += 1
# gather cosine similarity
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx: if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x) ctc_logit = self.ctc(x)
prev_state.append(x) prev_state.append(x)
prev_padding.append(encoder_padding_mask) prev_padding.append(encoder_padding_mask)
if self.fuse_stages_num != 0: if self.fusion_stages_num != 0:
fuse_state = [] fusion_state = []
i = -1 i = -1
seq_len = x.size(0) seq_len = x.size(0)
for state in prev_state: for state in prev_state:
i += 1 i += 1
down_sample = getattr(self, f"down_sample{i + 1}") fusion_downsampling = getattr(self, f"fusion_downsampling{i + 1}")
fuse_pre_layer_norm = getattr(self, f"fuse_pre_layer_norm{i + 1}") fusion_pre_layer_norm = getattr(self, f"fusion_pre_layer_norm{i + 1}")
fuse_post_layer_norm = getattr(self, f"fuse_post_layer_norm{i + 1}") fusion_post_layer_norm = getattr(self, f"fusion_post_layer_norm{i + 1}")
if fuse_pre_layer_norm is not None or fuse_pre_layer_norm is not None: if fusion_pre_layer_norm is not None or fusion_post_layer_norm is not None:
state = fuse_pre_layer_norm(state) state = fusion_pre_layer_norm(state)
if self.pyramid_fuse_transform == "pool": if self.pds_fusion_transform == "conv":
state = fusion_downsampling(state)
elif self.pds_fusion_transform == "conv2":
state = fusion_downsampling(state, prev_padding[i])
elif self.pds_fusion_transform == "pool":
state = state.permute(1, 2, 0) # bsz, dim, seq_len state = state.permute(1, 2, 0) # bsz, dim, seq_len
if i != self.pyramid_stages - 1: if i != self.pds_stages - 1:
state = nn.functional.adaptive_max_pool1d(state, seq_len) state = nn.functional.adaptive_max_pool1d(state, seq_len)
state = down_sample(state) state = fusion_downsampling(state)
elif self.pyramid_fuse_transform == "conv":
state = down_sample(state) state = fusion_post_layer_norm(state)
elif self.pyramid_fuse_transform == "conv2": fusion_state.append(state)
state = down_sample(state, prev_padding[i]) x = (torch.stack(fusion_state, dim=0) * self.fusion_weight.view(-1, 1, 1, 1)).sum(0)
state = fuse_post_layer_norm(state)
fuse_state.append(state)
fuse_weight = self.fuse_weight
x = (torch.stack(fuse_state, dim=0) * fuse_weight.view(-1, 1, 1, 1)).sum(0)
if self.layer_norm is not None: if self.layer_norm is not None:
x = self.layer_norm(x) x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x)
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
"ctc_logit": [ctc_logit if ctc_logit is not None else x],
"src_tokens": [], "src_tokens": [],
"src_lengths": [], "src_lengths": [],
} }
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
if isinstance(encoder_out, dict) and "ctc_logit" in encoder_out:
encoder_state = encoder_out["ctc_logit"][0]
else:
encoder_state = encoder_out
ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
return ctc_logit
def compute_ctc_prob(self, encoder_out, temperature=1.0):
assert self.use_ctc, "CTC is not available!"
ctc_logit = self.compute_ctc_logit(encoder_out) / temperature
return self.softmax(ctc_logit)
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = ( new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0 [] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
) )
new_ctc_logit = (
[] if len(encoder_out["ctc_logit"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"]]
)
new_encoder_padding_mask = ( new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0 [] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]] else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
...@@ -866,6 +823,7 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -866,6 +823,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
return { return {
"encoder_out": new_encoder_out, # T x B x C "encoder_out": new_encoder_out, # T x B x C
"ctc_logit": [new_ctc_logit], # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C "encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C] "encoder_states": encoder_states, # List[T x B x C]
...@@ -874,7 +832,7 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -874,7 +832,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
} }
@register_model_architecture(model_name="pys2t_transformer", arch_name="pys2t_transformer") @register_model_architecture(model_name="pdss2t_transformer", arch_name="pdss2t_transformer")
def base_architecture(args): def base_architecture(args):
# Convolutional subsampler # Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "") args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
...@@ -926,96 +884,252 @@ def base_architecture(args): ...@@ -926,96 +884,252 @@ def base_architecture(args):
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid # PDS
args.pyramid_stages = getattr(args, "pyramid_stages", None) args.pds_stages = getattr(args, "pds_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None) args.pds_layers = getattr(args, "pds_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None) args.pds_ratios = getattr(args, "pds_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None) args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None) args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None) args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pyramid_heads = getattr(args, "pyramid_heads", None) args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv") args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False) args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.pyramid_dropout = getattr(args, "pyramid_dropout", args.dropout) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s") args.fusion = getattr(args, "fusion", False)
def pys2t_transformer_s(args): args.fusion_method = getattr(args, "fusion_method", "all_conv")
def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_1_2")
args.pds_layers = getattr(args, "pds_layers", "3_3_3_3")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1")
def set_pds_base_16(args):
args.pds_stages = getattr(args, "pds_stages", 4)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_2_2")
args.pds_layers = getattr(args, "pds_layers", "2_2_6_2")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1")
def set_pds_base_32(args):
args.pds_stages = getattr(args, "pds_stages", 5)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_2_2_2")
args.pds_layers = getattr(args, "pds_layers", "2_2_3_3_2")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1_1")
def set_pds_deep_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_1_2")
args.pds_layers = getattr(args, "pds_layers", "7_7_7_9")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1")
def set_pds_deep_16(args):
args.pds_stages = getattr(args, "pds_stages", 4)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_2_2")
args.pds_layers = getattr(args, "pds_layers", "5_5_12_8")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1")
def set_pds_deep_32(args):
args.pds_stages = getattr(args, "pds_stages", 5)
args.pds_ratios = getattr(args, "pds_ratios", "2_2_2_2_2")
args.pds_layers = getattr(args, "pds_layers", "5_5_7_7_6")
args.pds_kernel_sizes = getattr(args, "pds_kernel_sizes", "5_5_5_5_5")
args.pds_position_embed = getattr(args, "pds_position_embed", "1_1_1_1_1")
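# Summary of the presets above: "_8" reaches 8x total down-sampling via ratios
# 2_2_1_2, "_16" reaches 16x via 2_2_2_2 (12 layers split 2_2_6_2), and the
# "_32" variants add a fifth stage of ratio 2 for 32x.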
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_s")
def pdss2t_transformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)

    # PDS
    set_pds_base_16(args)
    args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
    args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4")
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8")

    base_architecture(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s_relative") @register_model_architecture("pdss2t_transformer", "pdss2t_transformer_s_8")
def pys2t_transformer_s_relative(args): def pdss2t_transformer_s_8(args):
args.max_encoder_relative_length = 100 args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.max_decoder_relative_length = 20 args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.k_only = True args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
pys2t_transformer_s(args) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
# PDS
set_pds_base_8(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_s_16")
def pdss2t_transformer_s_16(args):
pdss2t_transformer_s(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_xs")
def pys2t_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 3)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
pys2t_transformer_s(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_s_32")
def pdss2t_transformer_s_32(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_sp") # PDS
def pys2t_transformer_sp(args): set_pds_base_32(args)
args.encoder_layers = getattr(args, "encoder_layers", 16) args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
pys2t_transformer_s(args) args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4_4")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8_8")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_sd")
def pdss2t_transformer_sd(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
# PDS
set_pds_deep_16(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_sd_8")
def pdss2t_transformer_sd_8(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
# PDS
set_pds_deep_8(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8")
@register_model_architecture("pys2t_transformer", "pys2t_transformer_m") base_architecture(args)
def pys2t_transformer_m(args):
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_sd_16")
def pdss2t_transformer_sd_16(args):
pdss2t_transformer_sd(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_sd_32")
def pdss2t_transformer_sd_32(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
# PDS
set_pds_deep_32(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "256_256_256_256")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "4_4_4_4")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "8_8_8_8")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_m")
def pdss2t_transformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
# PDS
set_pds_base_16(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "512_512_512_512")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "8_8_8_8")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "4_4_4_4")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_m_8")
def pdss2t_transformer_m_8(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15) args.dropout = getattr(args, "dropout", 0.15)
# PDS
set_pds_base_8(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "512_512_512_512")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "8_8_8_8")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "4_4_4_4")
base_architecture(args) base_architecture(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_mp") @register_model_architecture("pdss2t_transformer", "pdss2t_transformer_m_16")
def pys2t_transformer_mp(args): def pdss2t_transformer_m_16(args):
args.encoder_layers = getattr(args, "encoder_layers", 16) pdss2t_transformer_m(args)
pys2t_transformer_m(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_l") @register_model_architecture("pdss2t_transformer", "pdss2t_transformer_m_32")
def pys2t_transformer_l(args): def pdss2t_transformer_m_32(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
# PDS
set_pds_base_32(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "512_512_512_512")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "8_8_8_8")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "4_4_4_4")
base_architecture(args)
@register_model_architecture("pdss2t_transformer", "pdss2t_transformer_l")
def pdss2t_transformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2) args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
# PDS
set_pds_base_16(args)
args.pds_embed_dims = getattr(args, "pds_embed_dims", "1024_1024_1024_1024")
args.pds_attn_heads = getattr(args, "pds_attn_heads", "16_16_16_16")
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", "4_4_4_4")
@register_model_architecture("pys2t_transformer", "pys2t_transformer_lp") base_architecture(args)
def pys2t_transformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
pys2t_transformer_l(args)
...@@ -314,7 +314,22 @@ class S2TConformerEncoder(S2TTransformerEncoder):
        if self.history is not None:
            self.history.clean()

        cos_sim_idx = -1
        dis = self.dis
        if self.gather_cos_sim:
            x = src_tokens
            x = x.transpose(0, 1)
            self.add_to_dict(x, dis, cos_sim_idx)

        x, input_lengths = self.subsample(src_tokens, src_lengths)
        if type(x) == list:
            inner_x = x
            if self.gather_cos_sim:
                for x in inner_x:
                    cos_sim_idx += 1
                    self.add_to_dict(x, dis, cos_sim_idx)
            x = inner_x[-1]

        x = self.embed_scale * x
        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
...@@ -328,6 +343,11 @@ class S2TConformerEncoder(S2TTransformerEncoder):
        if self.history is not None:
            self.history.add(x)

        cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
        if self.gather_cos_sim:
            cos_sim_idx += 1
            self.add_to_dict(x, dis, cos_sim_idx)

        for layer in self.layers:
            if self.history is not None:
                x = self.history.pop()
...@@ -335,6 +355,10 @@ class S2TConformerEncoder(S2TTransformerEncoder):
            if self.history is not None:
                self.history.add(x)

            if self.gather_cos_sim:
                cos_sim_idx += 1
                self.add_to_dict(x, dis, cos_sim_idx)

        if self.history is not None:
            x = self.history.pop()
...@@ -357,8 +381,8 @@ def base_architecture(args):
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)

    # Conformer
    args.macaron_style = getattr(args, "macaron_style", False)
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
......
...@@ -17,9 +17,8 @@ from fairseq.models.speech_to_text import (
    S2TTransformerModel,
    S2TTransformerEncoder,
    S2TConformerEncoder,
    PDSS2TTransformerModel,
    PDSS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
...@@ -47,9 +46,9 @@ class S2TSATEModel(S2TTransformerModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        PDSS2TTransformerModel.add_args(parser)

        # SATE setting
        parser.add_argument(
            "--text-encoder-layers",
            default=6,
...@@ -57,12 +56,24 @@ class S2TSATEModel(S2TTransformerModel):
            help="layers of the text encoder",
        )
        parser.add_argument(
            "--text-attention-type",
            default="selfattn",
            type=str,
            help="attention type of the textual encoder",
        )
        parser.add_argument(
            "--adapter",
            default="league",
            type=str,
            help="adapter type",
        )
        parser.add_argument(
            "--share-ctc-and-adapter",
            default=False,
            action="store_true",
            help="share the projection weights of the ctc and adapter",
        )
        parser.add_argument(
            "--temperature",
            default=1.0,
            type=float,
...@@ -119,6 +130,9 @@ class S2TSATEModel(S2TTransformerModel):
            component=encoder.text_encoder, checkpoint=args.load_pretrained_text_encoder_from, strict=False
        )

        if args.share_ctc_and_adapter and hasattr(encoder.adapter, "linear_adapter"):
            encoder.acoustic_encoder.ctc_projection.weight = encoder.adapter.linear_adapter[0].weight

        return encoder
...@@ -126,45 +140,45 @@ class Adapter(nn.Module):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()

        embed_dim = args.encoder_embed_dim
        self.adapter_type = args.adapter

        if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
            self.linear_adapter = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                LayerNorm(args.encoder_embed_dim),
                nn.ReLU(),
            )
        elif self.adapter_type == "linear2":
            self.linear_adapter = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
            )
        elif self.adapter_type == "subsample":
            self.subsample_adaptor = Conv1dSubsampler(
                embed_dim,
                args.conv_channels,
                embed_dim,
                [int(k) for k in args.conv_kernel_sizes.split(",")],
            )

        if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
            if embed_tokens is None:
                num_embeddings = len(dictionary)
                self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
            else:
                self.embed_adapter = embed_tokens

        if self.adapter_type == "gated_league":
            self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
        elif self.adapter_type == "gated_league2":
            self.gate_linear1 = nn.Linear(embed_dim, embed_dim)
            self.gate_linear2 = nn.Linear(embed_dim, embed_dim)

        self.out_layernorm = LayerNorm(embed_dim)
        # self.out_layernorm = nn.Identity()
    def forward(self, x, padding):
        representation, distribution = x
        batch, seq_len, embed_dim = representation.size()
...@@ -174,25 +188,30 @@ class Adapter(nn.Module):
        if self.adapter_type == "linear":
            out = self.linear_adapter(representation)
            out = self.out_layernorm(out)

        elif self.adapter_type == "context":
            out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
            out = self.out_layernorm(out)

        elif self.adapter_type == "subsample":
            representation = representation.transpose(0, 1)
            out, input_lengths = self.subsample_adaptor(representation, lengths)
            padding = lengths_to_padding_mask(input_lengths)
            out = self.out_layernorm(out)

        elif self.adapter_type == "league":
            linear_out = self.linear_adapter(representation)
            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
            out = linear_out + soft_out
            out = self.out_layernorm(out)

        elif self.adapter_type == "gated_league":
            linear_out = self.linear_adapter(representation)
            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
            coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
            out = coef * linear_out + (1 - coef) * soft_out
            out = self.out_layernorm(out)

        elif self.adapter_type == "none":
            out = representation
...@@ -204,15 +223,15 @@ class Adapter(nn.Module):
        return out, padding
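For intuition, the league adapter above fuses two views of the acoustic encoder output: linear_out, the representation passed through Linear -> LayerNorm -> ReLU, and soft_out, the CTC distribution mapped back into embedding space as a soft (weighted) embedding lookup. A shape-level sketch of the soft branch, with made-up sizes:

import torch

batch, seq_len, embed_dim, vocab = 8, 50, 256, 1000
representation = torch.randn(batch, seq_len, embed_dim)
# CTC posteriors, flattened to (batch * seq_len) x vocab as in forward() above
distribution = torch.softmax(torch.randn(batch * seq_len, vocab), dim=-1)
embed_weight = torch.randn(vocab, embed_dim)           # embed_adapter.weight

soft_out = torch.mm(distribution, embed_weight).view(batch, seq_len, -1)
assert soft_out.shape == representation.shape          # league: out = linear_out + soft_out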
class TextEncoder(nn.Module):
    def __init__(self, args, dictionary):
        super().__init__()

        self.embed_tokens = None

        embed_dim = args.encoder_embed_dim
        self.embed_scale = math.sqrt(embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = dictionary.pad_index
...@@ -233,7 +252,7 @@ class TextEncoder(FairseqEncoder):
        else:
            self.layer_norm = None

    def forward(self, x, encoder_padding_mask=None, history=None):

        x = self.embed_scale * x
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
...@@ -264,20 +283,21 @@ class S2TSATEEncoder(FairseqEncoder):
        super().__init__(None)

        # acoustic encoder
        acoustic_encoder_type = args.acoustic_encoder
        if acoustic_encoder_type == "transformer":
            self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
        elif acoustic_encoder_type == "pds":
            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
        else:
            logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))

        # adapter
        self.temperature = args.temperature
        self.adapter = Adapter(args, task.source_dictionary, embed_tokens)

        if args.share_ctc_and_adapter and hasattr(self.adapter, "linear_adapter"):
            self.acoustic_encoder.ctc_projection.weight = self.adapter.linear_adapter[0].weight

        # self.length_adapter = Conv1dSubsampler(
        #     args.encoder_embed_dim,
        #     args.conv_channels,
...@@ -286,9 +306,7 @@ class S2TSATEEncoder(FairseqEncoder):
        # )

        acoustic_encoder_attention_type = args.encoder_attention_type
        args.encoder_attention_type = args.text_attention_type

        # text encoder
        self.text_encoder = TextEncoder(args, task.source_dictionary)
...@@ -296,9 +314,10 @@ class S2TSATEEncoder(FairseqEncoder):
        args.encoder_attention_type = acoustic_encoder_attention_type

        if getattr(args, "use_enc_dlcl", False):
            layer_num = args.encoder_layers + args.text_encoder_layers + 1
            self.history = LearnableDenseLayerHistory(
                args.encoder_normalize_before, layer_num, args.encoder_embed_dim, True
            )
        else:
            self.history = None
...@@ -311,9 +330,9 @@ class S2TSATEEncoder(FairseqEncoder):
        encoder_out = acoustic_encoder_out["encoder_out"][0]
        encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]

        if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
            ctc_logit = acoustic_encoder_out["ctc_logit"][0]
            ctc_prob = self.acoustic_encoder.ctc.softmax(ctc_logit, self.temperature)
        else:
            ctc_logit = None
            ctc_prob = None
...@@ -340,8 +359,8 @@ class S2TSATEEncoder(FairseqEncoder):
        x = self.text_encoder(x, encoder_padding_mask, self.history)

        return {
            "encoder_out": [x],  # T x B x C
            "ctc_logit": [ctc_logit],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
...@@ -349,9 +368,6 @@ class S2TSATEEncoder(FairseqEncoder):
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        new_ctc_logit = (
            [] if len(encoder_out["ctc_logit"]) == 0
...@@ -401,9 +417,6 @@ def base_architecture(args):
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
...@@ -444,20 +457,33 @@ def base_architecture(args):
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

    # SATE
    args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
    args.adapter = getattr(args, "adapter", "league")
    args.temperature = getattr(args, "temperature", 1.0)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.text_attention_type = getattr(args, "text_attention_type", "selfattn")

    # PDS
    args.pds_stages = getattr(args, "pds_stages", None)
    args.pds_layers = getattr(args, "pds_layers", None)
    args.pds_ratios = getattr(args, "pds_ratios", None)

    args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
    args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
    args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
    args.pds_position_embed = getattr(args, "pds_position_embed", None)

    args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)

    args.ctc_layer = getattr(args, "ctc_layer", 0)
    args.pds_dropout = getattr(args, "pds_dropout", args.dropout)

    args.fusion = getattr(args, "fusion", False)
    args.fusion_method = getattr(args, "fusion_method", "all_conv")

@register_model_architecture("s2t_sate", "s2t_sate_s")
def s2t_sate_s(args):
......
...@@ -13,6 +13,7 @@ from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (
...@@ -69,12 +70,17 @@ class Conv1dSubsampler(nn.Module):
    def forward(self, src_tokens, src_lengths):
        bsz, in_seq_len, _ = src_tokens.size()  # B x T x (C x D)
        x = src_tokens.transpose(1, 2).contiguous()  # -> B x (C x D) x T
        inner_x = []
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
            inner_x.append(x)
        _, _, out_seq_len = x.size()
        # x = x.transpose(1, 2).transpose(0, 1).contiguous()  # -> T x B x (C x D)
        out_inner_x = []
        for x in inner_x:
            out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
        return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
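Each convolution in the subsampler uses stride 2, so the default two-layer stack reduces the frame rate by a factor of 4. A quick sketch of the length arithmetic, assuming get_out_seq_lens_tensor keeps the standard fairseq per-layer formula:

def subsampled_length(n_frames, n_layers=2):
    # per stride-2 conv layer: out = floor((in - 1) / 2) + 1
    out = n_frames
    for _ in range(n_layers):
        out = (out - 1) // 2 + 1
    return out

print(subsampled_length(480))   # 120: two stride-2 layers give a 4x reduction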
@register_model("s2t_transformer") @register_model("s2t_transformer")
...@@ -281,7 +287,15 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            default="learnable_dense",
            help='decoder layer history type'
        )
        # CTC
        parser.add_argument(
            "--ctc-layer",
            default=0,
            type=int,
            help="the position of the ctc loss",
        )
        # local modeling
        parser.add_argument(
            '--hard-mask-window',
            type=float,
...@@ -349,6 +363,21 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            action="store_true",
            help="Simultaneous speech translation or not",
        )
        # interleaved dropout
        parser.add_argument('--interleave-dropout', type=int,
                            help='apply additional dropout after every N encoder layers')
        parser.add_argument('--cl-dropout',
                            action="store_true",
                            default=False,
                            help='enable curriculum-style dropout that ramps up during training')
        parser.add_argument('--cl-dropout-epoch',
                            type=int,
                            default=None,
                            help='epoch at which the dropout ratio reaches its configured value')
        parser.add_argument('--cl-dropout-strategy',
                            type=str,
                            help='schedule used to ramp up the dropout ratio')
        pass

    @classmethod
...@@ -462,24 +491,13 @@ class S2TTransformerEncoder(FairseqEncoder):
        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )

        self.layers = nn.ModuleList(
            [ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
...@@ -491,34 +509,77 @@ class S2TTransformerEncoder(FairseqEncoder):
            self.history = None

        self.use_ctc = "sate" in args.arch or \
                       (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
        if self.use_ctc:
            self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
            self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
            if self.inter_ctc:
                logger.info("Intermediate CTC loss in layer %d" % self.ctc_layer)
            self.ctc = CTC(args.encoder_embed_dim,
                           dictionary_size=len(task.source_dictionary),
                           dropout=args.dropout,
                           need_layernorm=True if self.inter_ctc else False)

            if task.source_dictionary == task.target_dictionary:
                self.ctc.ctc_projection.weight = embed_tokens.weight

        self.interleaved_dropout = getattr(args, "interleave_dropout", None)

        self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
        # self.gather_cos_sim = True
        self.dis = 2
        self.cos_sim = dict()
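The CTC module imported at the top of this file is new in this commit and its definition is not part of this diff. From the call sites alone (the constructor arguments here, ctc(x) below, ctc.softmax(logit, temperature) in the SATE encoder, and the tied ctc_projection.weight), it plausibly looks like the following sketch; treat it as an assumption, not the actual implementation:

import torch.nn as nn

class CTC(nn.Module):
    # Hypothetical reconstruction; only the signature and the methods
    # used elsewhere in this diff are grounded in the source.
    def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
        super().__init__()
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)
        self.ctc_dropout_module = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(embed_dim) if need_layernorm else None

    def forward(self, x):
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        return self.ctc_projection(self.ctc_dropout_module(x))

    def softmax(self, logit, temperature=1.0):
        # temperature-scaled distribution over the source vocabulary
        return nn.functional.softmax(logit / temperature, dim=-1)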
    @staticmethod
    def pooling_ratio():
        return 4

    def add_to_dict(self, x, dis, idx):
        sim = 0
        seq_len = x.size(0)
        cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
        for i in range(dis, seq_len - dis):
            a = x[i, :, :]
            for j in range(-dis, dis + 1):
                if j == 0:
                    continue
                b = x[i + j, :, :]
                sim_j = cos(a, b).mean()
                sim += sim_j
        sim = sim / 2 / dis / (seq_len - 2 * dis)

        if idx not in self.cos_sim:
            self.cos_sim[idx] = []
        self.cos_sim[idx].append(float(sim))
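add_to_dict records, per logging index, the average cosine similarity between each frame and its neighbours within dis steps; high values mean adjacent frames are largely redundant, which is the empirical motivation for down-sampling. A simplified vectorized variant (it averages over all frame pairs rather than only interior frames, so boundary values differ slightly):

import torch

def avg_neighbour_cos_sim(x, dis=2):
    # x: T x B x C, as in the encoder forward pass
    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim = 0.0
    for offset in range(1, dis + 1):
        a, b = x[:-offset], x[offset:]        # all pairs at distance `offset`
        sim += 2 * cos(a, b).mean().item()    # counts both +offset and -offset
    return sim / (2 * dis)

print(avg_neighbour_cos_sim(torch.randn(100, 4, 256)))   # near 0 for random input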
    def forward(self, src_tokens, src_lengths):

        ctc_input = None

        if self.history is not None:
            self.history.clean()

        # gather cosine similarity
        cos_sim_idx = -1
        dis = self.dis
        if self.gather_cos_sim:
            self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)

        # down-sampling
        x, input_lengths = self.subsample(src_tokens, src_lengths)
        if type(x) == list:
            inner_x = x
            # gather cosine similarity
            if self.gather_cos_sim:
                for x in inner_x:
                    cos_sim_idx += 1
                    self.add_to_dict(x, dis, cos_sim_idx)
            x = inner_x[-1]

        # embedding scaling
        x = self.embed_scale * x

        # padding and position embedding
        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        if self.attn_type != "rel_selfattn":
...@@ -530,18 +591,32 @@ class S2TTransformerEncoder(FairseqEncoder):
        if self.history is not None:
            self.history.add(x)

        # gather cosine similarity
        cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
        if self.gather_cos_sim:
            cos_sim_idx += 1
            self.add_to_dict(x, dis, cos_sim_idx)

        layer_index = 0
        ctc_logit = None
        for layer in self.layers:
            layer_index += 1

            if self.history is not None:
                x = self.history.pop()

            # encoder layer
            x = layer(x, encoder_padding_mask, pos_emb=positions)

            if layer_index != len(self.layers) \
                    and self.interleaved_dropout is not None \
                    and layer_index % self.interleaved_dropout == 0:
                x = self.dropout_module(x)

            # gather cosine similarity
            if self.gather_cos_sim:
                cos_sim_idx += 1
                self.add_to_dict(x, dis, cos_sim_idx)

            if self.history is not None:
                self.history.add(x)
...@@ -552,8 +627,12 @@ class S2TTransformerEncoder(FairseqEncoder):
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if self.use_ctc and ctc_logit is None:
            ctc_logit = self.ctc(x)

        return {
            "encoder_out": [x],  # T x B x C
            "ctc_logit": [ctc_logit],  # B x T x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
...@@ -561,30 +640,17 @@ class S2TTransformerEncoder(FairseqEncoder):
            "src_lengths": [],
        }
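The cos_sim_idx bookkeeping above keeps the two logging phases apart: subsampler outputs occupy low indices, and (cos_sim_idx + 10) // 10 * 10 - 1 then jumps to just before the next multiple of ten, so encoder-layer entries start at 10 (or 20, and so on) regardless of how many subsampler outputs were logged. The arithmetic, worked through:

for cos_sim_idx in (-1, 0, 2, 9, 10):
    start = (cos_sim_idx + 10) // 10 * 10 - 1
    print(cos_sim_idx, "->", start + 1)   # index of the next entry after += 1
# -1 -> 0   (nothing logged during subsampling)
# 0 -> 10
# 2 -> 10
# 9 -> 10
# 10 -> 20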
    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = (
            [] if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )
        new_ctc_logit = (
            [] if len(encoder_out["ctc_logit"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"]]
        )
        new_encoder_padding_mask = (
            [] if len(encoder_out["encoder_padding_mask"]) == 0
            else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
...@@ -602,6 +668,7 @@ class S2TTransformerEncoder(FairseqEncoder):
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "ctc_logit": new_ctc_logit,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
...@@ -661,11 +728,6 @@ def base_architecture(args):
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)

    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
...@@ -708,14 +770,30 @@ def base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

    # CTC
    args.ctc_layer = getattr(args, "ctc_layer", 0)

    # Conformer
    args.macaron_style = getattr(args, "macaron_style", False)
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

    # Relative position encoding
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)

    # local modeling
    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)

    # interleaved dropout
    args.interleave_dropout = getattr(args, "interleave_dropout", None)
    args.cl_dropout = getattr(args, "cl_dropout", False)
    args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
    args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")

@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
......
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
......
...@@ -249,6 +249,8 @@ class TransformerModel(FairseqEncoderDecoderModel):
            metavar="STR",
            help="freeze the module of the decoder",
        )
        parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
                            help='interleaved dropout probability')
        # fmt: on

    @classmethod
...@@ -819,6 +821,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )

        self.gather_attn_weight = getattr(args, "gather_attn_weight", False)
        # self.gather_attn_weight = True
        self.attn_weights = dict()

    def build_decoder_layer(self, args, no_encoder_attn=False):
        layer = TransformerDecoderLayer(args, no_encoder_attn)
        if getattr(args, "checkpoint_activations", False):
...@@ -968,6 +974,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        avg_attn = None
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
...@@ -993,8 +1000,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer) or self.gather_attn_weight),
                need_head_weights=bool((idx == alignment_layer) or self.gather_attn_weight),
                pos_emb=positions
            )
            inner_states.append(x)
...@@ -1002,8 +1009,35 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                attn = layer_attn.float().to(x)

            if self.history is not None:
                self.history.add(x)

            if self.gather_attn_weight:
                if avg_attn is None:
                    avg_attn = layer_attn
                else:
                    avg_attn += layer_attn

        if self.gather_attn_weight:
            avg_attn = avg_attn / len(self.layers)
            attn = avg_attn.mean(0).sum(-2)
            attn = torch.reshape(attn, [attn.numel()])
            attn = attn // 0.001
            attn = attn.int().cpu()

            if len(encoder_out["encoder_padding_mask"]) > 0:
                mask = encoder_out["encoder_padding_mask"][0]
                mask = torch.reshape(mask, [mask.numel()])
            else:
                mask = None

            i = -1
            for item in attn:
                i += 1
                if mask is not None and mask[i]:
                    continue
                idx = int(item) * 0.001
                if idx not in self.attn_weights:
                    self.attn_weights[idx] = 0
                self.attn_weights[idx] += 1
        elif attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]
......
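As a worked example of the histogram logic in TransformerDecoder.forward above: the flattened attention scores are floor-divided into 0.001-wide bins, and attn_weights counts how many non-padding positions fall into each bin. The same bucketing on plain floats:

from collections import defaultdict

attn_weights = defaultdict(int)
for w in (0.0734, 0.0736, 0.2501):
    bucket = int(w // 0.001) * 0.001     # mirrors attn // 0.001, then * 0.001
    attn_weights[round(bucket, 3)] += 1

print(dict(attn_weights))   # {0.073: 2, 0.25: 1}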
...@@ -43,7 +43,7 @@ from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
from .conformer_layer import ConformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer

__all__ = [
    "AdaptiveInput",
...@@ -78,7 +78,7 @@ __all__ = [
    "LocalMultiheadAttention",
    "MultiheadAttention",
    "PositionalEmbedding",
    "PDSTransformerEncoderLayer",
    "ReducedMultiheadAttention",
    "RelPositionMultiheadAttention",
    "RelativeMultiheadAttention",
......
...@@ -22,11 +22,11 @@ from fairseq.modules.quant_noise import quant_noise
from torch import Tensor

class PDSTransformerEncoderLayer(nn.Module):
    """Encoder layer block for progressive down-sampling method.

    In the original paper each operation (multi-head attention or FFN) is
    post-processed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
...@@ -575,6 +575,7 @@ class TransformerDecoderLayer(nn.Module):
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
......
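The PDSTransformerEncoderLayer docstring above contrasts the two residual orderings; in fairseq the choice is controlled by --encoder-normalize-before. A minimal sketch of the difference:

import torch
import torch.nn as nn

dim = 256
norm, ffn, drop = nn.LayerNorm(dim), nn.Linear(dim, dim), nn.Dropout(0.1)
x = torch.randn(10, dim)

post_norm = norm(x + drop(ffn(x)))    # paper default: dropout -> add residual -> layernorm
pre_norm = x + drop(ffn(norm(x)))     # tensor2tensor variant: layernorm before the sublayer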
...@@ -14,9 +14,8 @@ def PositionalEmbedding(
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
......
...@@ -406,6 +406,22 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
            for item in translation_list:
                f.write("{}\n".format("\t".join(item)))

    if models[0].decoder.gather_attn_weight:
        weights = models[0].decoder.attn_weights
        sort_weights = sorted(weights.items(), key=lambda k: k[0])
        num = sum([k[1] for k in sort_weights])
        with open("weights", "w", encoding="utf-8") as fw:
            for item in sort_weights:
                fw.write("%f\t%f\n" % (item[0], item[1] / num))

    if getattr(models[0].encoder, "gather_cos_sim", False):
        cos_sim = models[0].encoder.cos_sim
        with open("cos_sim", "w", encoding="utf-8") as fw:
            for layer, sim in cos_sim.items():
                sim = sum(sim) / len(sim) * 100
                # if layer >= 10:
                #     layer -= 10
                fw.write("%d\t%f\n" % (layer, sim))

    return scorer
......
...@@ -153,6 +153,23 @@ def main(cfg: FairseqConfig) -> None:
            )
            break

        if getattr(cfg.model, "cl_dropout", False):
            cl_dropout_epoch = getattr(cfg.model, "cl_dropout_epoch", None)
            cl_dropout_strategy = getattr(cfg.model, "cl_dropout_strategy", "linear")
            dropout = getattr(cfg.model, "dropout", False)
            assert cl_dropout_epoch > 0
            curr_epoch = epoch_itr.epoch
            if curr_epoch <= cl_dropout_epoch:
                if curr_epoch == cl_dropout_epoch:
                    curr_dropout = dropout
                else:
                    curr_dropout = curr_epoch / cl_dropout_epoch * dropout
                logger.info("Epoch {}: dropout ratio: {}.".format(curr_epoch, curr_dropout))

                from fairseq.modules.fairseq_dropout import FairseqDropout
                for name, module in trainer.model.named_modules():
                    if isinstance(module, FairseqDropout):
                        module.p = curr_dropout

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
......
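For intuition, the curriculum dropout block above ramps the dropout ratio linearly from near zero up to the configured value over the first cl-dropout-epoch epochs, then holds it fixed; only the "linear" strategy is exercised in this hunk. A minimal sketch of the schedule, assuming dropout 0.1 and cl-dropout-epoch 50:

def cl_dropout_ratio(epoch, cl_dropout_epoch=50, dropout=0.1):
    if epoch >= cl_dropout_epoch:
        return dropout                            # ramp finished, hold the target ratio
    return epoch / cl_dropout_epoch * dropout     # linear warm-up

print([round(cl_dropout_ratio(e), 3) for e in (1, 25, 50, 100)])
# [0.002, 0.05, 0.1, 0.1]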