Commit 380d7794 by xuchen

I optimized the S2T implementation.

Some problems still confuse me:
1. Whether to apply scaling in the input layer (I am trying to replace it with layer normalization);
2. The exact setup of weight sharing between the output projection matrix and the embedding matrix in the adapter (I notice that inconsistent variance leads to bad results);
3. The biggest confusion is that the variance increases layer by layer during the computation (I am not sure whether this phenomenon is reasonable; I will compare the behavior on the latest code).
Finally, the implementation details matter greatly for the final performance, even when the difference is subtle (see the sketch below).
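A minimal PyTorch sketch of points 2 and 3, for reference only: it is not the Fairseq-S2T code, and the names embed_tokens, ctc_projection, and encoder_layers merely mirror identifiers used in the diff below with made-up shapes. It ties one weight matrix between an embedding and an output projection, then prints the per-layer output variance of a stack of Transformer layers to check whether it grows with depth.

# Illustrative sketch only (plain PyTorch, not the Fairseq-S2T implementation).
import torch
import torch.nn as nn

vocab_size, embed_dim, num_layers = 1000, 256, 12

embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
ctc_projection = nn.Linear(embed_dim, vocab_size, bias=False)
# Weight sharing: both modules now hold the same Parameter, so they get
# identical initialization and identical gradient updates (point 2 above).
ctc_projection.weight = embed_tokens.weight

encoder_layers = nn.ModuleList(
    [nn.TransformerEncoderLayer(embed_dim, nhead=4, dim_feedforward=2048)
     for _ in range(num_layers)]
)
encoder_layers.eval()  # disable dropout so the variance trend is easy to read

x = torch.randn(50, 8, embed_dim)  # (time, batch, dim) dummy encoder input
with torch.no_grad():
    for i, layer in enumerate(encoder_layers, start=1):
        x = layer(x)
        # If this value keeps increasing with depth, the activation variance
        # accumulates layer by layer (point 3 above).
        print(f"layer {i:2d}: output variance = {x.var().item():.3f}")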
parent 03076942
......@@ -4,9 +4,9 @@ interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
sae-drop-prob: 0.2
sae-distribution-cutoff: 10
sae-adapter: inter_league
sae-drop-prob: 0.0
sae-distribution-cutoff: 0
share-ctc-and-sae: False
ctc-self-distill-weight: 0
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
inter_mixup_ratio: 0.2
inter_mixup_beta: 0.2
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
#load-pretrained-encoder-from:
\ No newline at end of file
#ctc-layer:
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
#ctc-weight: 0.2
interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 2,4
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
sae-adapter: league
sae-adapter: inter_league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
#share-ctc-and-sae: True
ctc-self-distill-weight: 0
\ No newline at end of file
......@@ -8,6 +8,7 @@ max_tokens=8192
exp_tag=baseline
config_list=(base)
config_list=(base_ctc inter)
# exp full name
exp_name=
......
arch: s2t_transformer
arch: s2t_transformer_s
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
......@@ -34,28 +34,19 @@ decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_ratio: 0.2
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-temperature: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-interleaved-ctc-weight: 0.1
#target-interleaved-ctc-layers: 2,4
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
share-ctc-and-sae: False
sae-drop-prob: 0.2
interleaved-ctc-drop-prob: 0.2
sae-drop-prob: 0.0
sae-distribution-cutoff: 10
share-ctc-and-sae: False
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
......@@ -9,7 +9,6 @@ adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
subsampling-type: conv1d
subsampling-layers: 2
......
#ctc-layer:
#ctc-weight: 0.2
interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-layers: 8
interleaved-ctc-temperature: 1
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
sae-adapter: league
sae-adapter: inter_league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
#share-ctc-and-sae: True
ctc-self-distill-weight: 0
\ No newline at end of file
ctc-self-distill-weight: 0
......@@ -8,6 +8,7 @@ best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
#fp16-scale-tolerance: 0.25
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
......
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
ctc-weight: 0.3
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
#target-ctc-weight: 0.3
#target-ctc-layer: 6
target-interleaved-ctc-weight: 0.1
target-interleaved-ctc-layers: 2,4
#target-interleaved-ctc-weight: 0.1
#target-interleaved-ctc-layers: 2,4
sae-adapter: league
sae-adapter: inter_league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
share-target-ctc-and-sae: False
#sae-distribution-cutoff: 0
#share-ctc-and-sae: True
#share-target-ctc-and-sae: True
ctc-self-distill-weight: 0
\ No newline at end of file
......@@ -37,6 +37,11 @@ decoder-attention-heads: 4
acoustic-encoder: transformer
adapter: league
#adapter-embed-norm: True
#adapter-out-norm: True
#share-adapter-and-ctc: True
#share-adapter-and-embed: True
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
......
......@@ -21,6 +21,7 @@ stop_stage=0
######## hardware ########
# devices
#device=()
use_auto=0
gpu_num=8
update_freq=1
......@@ -215,6 +216,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: ST Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1;
if [[ ${use_auto} -eq 1 ]]; then
device=(-1)
fi
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=""
......@@ -330,13 +334,20 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "${time} | ${device} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
tail -n 50 ${log} > tmp.log
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
if [[ ${use_auto} -eq 1 ]]; then
cmd=$(echo ${cmd} | tr -d "\n")
auto_run -c "${cmd}" -n ${gpu_num}
else
export CUDA_VISIBLE_DEVICES=${device}
eval $cmd
fi
sleep 5s
if [[ -f ${model_dir}/train.log ]]; then
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
fi
fi
fi
wait
......@@ -359,6 +370,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
dec_model=${dec_model}
fi
if [[ ${use_auto} -eq 1 ]]; then
device=(-1)
fi
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=""
......@@ -367,7 +381,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
device=$(get_devices $gpu_num 0)
fi
fi
export CUDA_VISIBLE_DEVICES=${device}
suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
if [[ ${n_average} -ne 1 ]]; then
......@@ -409,7 +422,13 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then
eval $cmd
if [[ ${use_auto} -eq 1 ]]; then
cmd=$(echo ${cmd} | tr -d "\n")
auto_run -c ${cmd} -n ${gpu_num}
else
export CUDA_VISIBLE_DEVICES=${device}
eval $cmd
fi
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
......
set -e
eval=1
lcrm=0
tokenizer=0
root_dir=~/st/Fairseq-S2T
data_dir=~/st/data/test
vocab_dir=~/st/data/mustc/st/en-de
asr_vocab_prefix=spm_unigram10000_st_share
src_lang=en
tgt_lang=de
subsets=(2019)
cp -r ${vocab_dir}/${asr_vocab_prefix}.* ${data_dir}/${src_lang}-${tgt_lang}
rm -rf ${data_dir}/${src_lang}-${tgt_lang}/fbank80.zip
splits=$(echo ${subsets[*]} | sed 's/ /,/g')
cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
--data-root ${data_dir}
--output-root ${data_dir}
--splits ${splits}
--task asr
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--add-src
--share
--asr-prefix ${asr_vocab_prefix}
--cmvn-type utterance"
if [[ ${lcrm} -eq 1 ]]; then
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
if [[ ${tokenizer} -eq 1 ]]; then
cmd="$cmd
--tokenizer"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
arch: s2t_transformer_s
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
#lr: 5e-4
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
#subsampling-filter: 2048
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
load-pretrained-encoder-from: /home/xuchen/st/checkpoints/aishell/asr/0506_sp_char_base_ctc_sample1024/avg_10_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/aishell/asr/0506_sp_char_base_ctc_sample1024/avg_10_checkpoint.pt
load-pretrained-encoder-from: /home/xuchen/st/checkpoints/librispeech/asr/base_baseline/avg_10_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/librispeech/asr/base_baseline/avg_10_checkpoint.pt
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/librispeech/asr/base_conformer_baseline_batch50k_16/avg_10_checkpoint.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/librispeech/asr/base_conformer_baseline_batch50k_16/avg_10_checkpoint.pt
train-subset: train
valid-subset: dev
max-epoch: 100
max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
num-workers: 8
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
arch: s2t_transformer_m
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.15
activation-fn: relu
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_transformer_m
share-decoder-input-output-embed: True
share-ctc-and-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 512
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: relu
dropout: 0.15
activation-fn: relu
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
cnn-module-norm: layer_norm
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-attention-type: rel_pos
encoder-activation-fn: swish
ctc-weight: 0.3
post-process: sentencepiece
use-enc-dlcl: True
use-dec-dlcl: True
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
encoder-attention-type: local
hard-mask-window: 0
gauss-mask-sigma: 3
init-mask-weight: 0
\ No newline at end of file
arch: pdss2t_transformer_s_8
pds-fusion: True
ctc-layer: 12
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: pdss2t_transformer_s_16
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 2_2_6_2
pds-ratios: 2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: pdss2t_transformer_s_32
encoder-embed-dim: 256
pds-stages: 5
ctc-layer: 12
pds-layers: 2_2_3_3_2
pds-ratios: 2_2_2_2_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1_1
pds-kernel-sizes: 5_5_5_5_5
pds-ffn-ratios: 8_8_8_8_8
pds-attn-heads: 4_4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 4_2_2_4
pds-ratios: 2_2_1_1
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 4_2_2_4
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: pdss2t_transformer_s_8
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 18
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
decoder-layers: 6
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: pdss2t_transformer_m_8
encoder-embed-dim: 512
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 512_512_512_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 4_4_4_4
pds-attn-heads: 8_8_8_8
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.15
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
load-pretrained-encoder-from: /home/xuchen/st/checkpoints/aishell/asr/0506_sp_char_base_ctc_sample1024/avg_10_checkpoint.pt
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/librispeech/asr/base_baseline/avg_10_checkpoint.pt
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 8_4_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
#load-pretrained-encoder-from:
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
encoder-attention-type: rel_pos
#encoder-attention-type: relative
#max-encoder-relative-length: 100
#! /bin/bash
gpu_num=1
data_dir=
test_subset=(dev test)
exp_name=
if [ "$#" -eq 1 ]; then
exp_name=$1
fi
cer=1
n_average=10
beam_size=5
len_penalty=1.0
max_tokens=10000
dec_model=checkpoint_best.pt
cmd="./run.sh
--stage 2
--stop_stage 2
--gpu_num ${gpu_num}
--exp_name ${exp_name}
--n_average ${n_average}
--cer ${cer}
--beam_size ${beam_size}
--len_penalty ${len_penalty}
--max_tokens ${max_tokens}
--dec_model ${dec_model}
"
if [[ -n ${data_dir} ]]; then
cmd="$cmd --data_dir ${data_dir}"
fi
if [[ ${#test_subset[@]} -ne 0 ]]; then
subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
cmd="$cmd --test_subset ${subsets}"
fi
echo $cmd
eval $cmd
gpu_num=4
cmd="sh train.sh"
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
sleep 60s
else
echo "Run $cmd"
eval $cmd
sleep 10s
exit
fi
done
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
get_devices(){
gpu_num=$1
use_cpu=$2
device=()
while :
do
record=$(mktemp -t temp.record.XXXXXX)
gpustat > $record
all_devices=$(seq 0 "$(sed '1,2d' ${record} | wc -l)");
count=0
for dev in ${all_devices[@]}
do
line=$((dev + 2))
use=$(head -n $line ${record} | tail -1 | cut -d '|' -f3 | cut -d '/' -f1)
if [[ $use -lt 100 ]]; then
device[$count]=$dev
count=$((count + 1))
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
if [[ $use_cpu -eq 1 ]]; then
device=(-1)
else
sleep 60s
fi
else
break
fi
done
echo ${device[*]} | sed 's/ /,/g'
return $?
}
#! /bin/bash
# Processing MuST-C Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
#set -u
set -o pipefail
export PYTHONIOENCODING=UTF-8
eval=1
time=$(date "+%m%d_%H%M")
stage=0
stop_stage=0
######## hardware ########
# devices
#device=()
gpu_num=8
update_freq=1
root_dir=~/st
code_dir=${root_dir}/Fairseq-S2T
pwd_dir=$PWD
# dataset
src_lang=ti
tgt_lang=de
lang=${src_lang}-${tgt_lang}
dataset=tibetan
task=speech_to_text
vocab_type=unigram
vocab_type=char
#vocab_type=word
vocab_size=1700
speed_perturb=0
lcrm=0
tokenizer=1
use_raw_audio=0
use_specific_dict=0
specific_prefix=st
specific_dir=${root_dir}/data/mustc/st
asr_vocab_prefix=spm_unigram10000_st_share
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/asr_char
data_dir=${root_dir}/data/${dataset}/asr_word
#data_dir=${root_dir}/data/${dataset}/asr
train_split=train
valid_split=dev
test_split=test
test_subset=test
# exp
exp_prefix=$(date "+%m%d")
extra_tag=
extra_parameter=
exp_tag=baseline
exp_name=
# config
train_config=base
data_config=config.yaml
# training setting
fp16=1
max_tokens=40000
step_valid=0
# decoding setting
cer=0
dec_model=checkpoint_best.pt
n_average=10
beam_size=5
len_penalty=1.0
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
exp_prefix=${exp_prefix}_sp
fi
if [[ ${lcrm} -eq 1 ]]; then
data_dir=${data_dir}_lcrm
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
data_dir=${data_dir}_${specific_prefix}
exp_prefix=${exp_prefix}_${specific_prefix}
fi
if [[ ${tokenizer} -eq 1 ]]; then
data_dir=${data_dir}_tok
exp_prefix=${exp_prefix}_tok
fi
if [[ ${use_raw_audio} -eq 1 ]]; then
data_dir=${data_dir}_raw
exp_prefix=${exp_prefix}_raw
fi
. ./local/parse_options.sh || exit 1;
if [[ -z ${exp_name} ]]; then
config_string=${train_config//,/_}
exp_name=${exp_prefix}_${config_string}_${exp_tag}
if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag}
fi
fi
model_dir=${root_dir}/checkpoints/${dataset}/asr/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
# pass
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
### Task dependent. You have to do the following data preparation part by yourself.
### But you can utilize Kaldi recipes in most cases
echo "stage 0: ASR Data Preparation"
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
feature_zip=fbank80.zip
if [[ ${speed_perturb} -eq 1 ]]; then
feature_zip=fbank80_sp.zip
fi
if [[ ! -f ${data_dir}/${feature_zip} && -f ${data_dir}/../${feature_zip} ]]; then
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
cmd="python ${code_dir}/examples/speech_to_text/prep_audio_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--task asr
--src-lang ${src_lang}
--splits ${valid_split},${test_split},${train_split}
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ ${use_raw_audio} -eq 1 ]]; then
cmd="$cmd
--raw"
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}
cmd="$cmd
--asr-prefix ${asr_vocab_prefix}"
fi
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
if [[ ${lcrm} -eq 1 ]]; then
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
if [[ ${tokenizer} -eq 1 ]]; then
cmd="$cmd
--tokenizer"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
if [[ -f ${data_dir}/${feature_zip} && ! -f ${data_dir}/../${feature_zip} ]]; then
mv ${data_dir}/${feature_zip} ${data_dir}/..
ln -s ${data_dir}/../${feature_zip} ${data_dir}
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: ASR Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1;
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=""
else
source ./local/utils.sh
device=$(get_devices $gpu_num 0)
fi
fi
echo -e "dev=${device} data=${data_dir} model=${model_dir}"
if [[ ! -d ${model_dir} ]]; then
mkdir -p ${model_dir}
else
echo "${model_dir} exists."
fi
cp ${BASH_SOURCE[0]} ${model_dir}
cp ${PWD}/train.sh ${model_dir}
extra_parameter="${extra_parameter}
--train-config ${pwd_dir}/conf/basis.yaml"
cp ${pwd_dir}/conf/basis.yaml ${model_dir}
config_list="${train_config//,/ }"
idx=1
for config in ${config_list[@]}
do
config_path=${pwd_dir}/conf/${config}.yaml
if [[ ! -f ${config_path} ]]; then
echo "No config file ${config_path}"
exit
fi
cp ${config_path} ${model_dir}
extra_parameter="${extra_parameter}
--train-config${idx} ${config_path}"
idx=$((idx + 1))
done
cmd="python3 -u ${code_dir}/fairseq_cli/train.py
${data_dir}
--config-yaml ${data_config}
--task ${task}
--max-tokens ${max_tokens}
--skip-invalid-size-inputs-valid-test
--update-freq ${update_freq}
--log-interval 100
--save-dir ${model_dir}
--tensorboard-logdir ${model_dir}"
if [[ -n ${extra_parameter} ]]; then
cmd="${cmd}
${extra_parameter}"
fi
if [[ ${gpu_num} -gt 0 ]]; then
cmd="${cmd}
--distributed-world-size $gpu_num
--ddp-backend no_c10d"
fi
if [[ $fp16 -eq 1 ]]; then
cmd="${cmd}
--fp16"
fi
if [[ $step_valid -eq 1 ]]; then
validate_interval=1
save_interval=1
keep_last_epochs=10
no_epoch_checkpoints=0
save_interval_updates=500
keep_interval_updates=10
else
validate_interval=1
keep_last_epochs=10
fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd
--no-epoch-checkpoints"
fi
if [[ -n $validate_interval ]]; then
cmd="${cmd}
--validate-interval $validate_interval "
fi
if [[ -n $save_interval ]]; then
cmd="${cmd}
--save-interval $save_interval "
fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then
cmd="${cmd}
--save-interval-updates $save_interval_updates"
if [[ -n $keep_interval_updates ]]; then
cmd="${cmd}
--keep-interval-updates $keep_interval_updates"
fi
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
# save info
log=./history.log
echo "${time} | ${device} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
tail -n 50 ${log} > tmp.log
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n "$(wc -l ${model_dir}/train.log | awk '{print $1+1}')" -f ${model_dir}/train.log
fi
fi
wait
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: ASR Decoding"
if [[ ${n_average} -ne 1 ]]; then
# Average models
dec_model=avg_${n_average}_checkpoint.pt
if [[ ! -f ${model_dir}/${dec_model} ]]; then
cmd="python ${code_dir}/scripts/average_checkpoints.py
--inputs ${model_dir}
--num-best-checkpoints ${n_average}
--output ${model_dir}/${dec_model}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd
fi
else
dec_model=${dec_model}
fi
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=""
else
source ./local/utils.sh
device=$(get_devices $gpu_num 0)
fi
fi
export CUDA_VISIBLE_DEVICES=${device}
result_file=${model_dir}/decode_result
[[ -f ${result_file} ]] && rm ${result_file}
test_subset=${test_subset//,/ }
for subset in ${test_subset[@]}; do
subset=${subset}
cmd="python ${code_dir}/fairseq_cli/generate.py
${data_dir}
--config-yaml ${data_config}
--gen-subset ${subset}
--task speech_to_text
--path ${model_dir}/${dec_model}
--results-path ${model_dir}
--max-tokens ${max_tokens}
--beam ${beam_size}
--lenpen ${len_penalty}
--scoring wer
"
if [[ ${cer} -eq 1 ]]; then
cmd="${cmd}
--wer-char-level"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
fi
done
cat ${result_file}
fi
#! /bin/bash
# training the model
gpu_num=1
update_freq=1
max_tokens=20000
#extra_tag=lr0.0005
#extra_tag=lr0.001
#extra_tag=char
extra_tag=word
extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
exp_tag=batch5w
exp_tag=pretrain
#exp_tag=batch5w_pre_libri
config_list=(purectc)
config_list=(base)
#config_list=(base ctc)
#config_list=(base conformer)
config_list=(big_wenet conformer ctc)
#config_list=(pds_base_4 ctc)
#config_list=(pds_base_8 ctc)
# exp full name
exp_name=
train_config=$(echo ${config_list[*]} | sed 's/ /,/g')
cmd="./run.sh
--stage 1
--stop_stage 1
--gpu_num ${gpu_num}
--update_freq ${update_freq}
--train_config ${train_config}
--max_tokens ${max_tokens}
"
if [[ -n ${exp_name} ]]; then
cmd="$cmd --exp_name ${exp_name}"
fi
if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}"
fi
if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo ${cmd}
eval ${cmd}
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 16000
lr: 2e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 20
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#ctc-weight: 0.2
interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 10,15
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
sae-adapter: inter_league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: True
share-ctc-and-embed: True
ctc-self-distill-weight: 0
......@@ -14,6 +14,8 @@ import logging
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
......@@ -43,6 +45,10 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of CTC entropy"},
)
ctc_entropy_cutoff: int = field(
default=0,
metadata={"help": "cutoff for CTC entropy computation"},
)
interleaved_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of interleaved CTC loss"},
......@@ -132,6 +138,7 @@ class CtcCriterion(FairseqCriterion):
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
......@@ -218,13 +225,16 @@ class CtcCriterion(FairseqCriterion):
ctc_loss += self.get_loss(lprobs, flat, input_lengths, lengths) * coef
if self.ctc_entropy > 0:
from torch.distributions import Categorical
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
# cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
if self.ctc_entropy_cutoff != 0:
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
else:
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
......@@ -328,14 +338,13 @@ class CtcCriterion(FairseqCriterion):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
non_padding_mask = ~out[1]
else:
inter_ctc_logit = out
if inter_ctc_logit.size() != ctc_logit.size():
continue
ctc_self_distill_num += 1
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
F.softmax(ctc_logit, dim=-1, dtype=torch.float32),
......@@ -344,9 +353,11 @@ class CtcCriterion(FairseqCriterion):
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_num += 1
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
if ctc_self_distill_num != 0:
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
loss = \
self.ctc_weight * ctc_loss + \
......
......@@ -52,7 +52,17 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
3) logging outputs to display while training
"""
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
from fairseq.models.speech_to_text import S2TDualModel
if isinstance(model, S2TDualModel):
assert "transcript" in sample
text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"]
encoder_out = model.encoder(src_tokens, src_lengths,
text_src_tokens, text_src_lengths)
else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
......
......@@ -854,6 +854,10 @@ class GenerationConfig(FairseqDataclass):
default=False,
metadata={"help": "if set, dont use seed for initializing random generators"},
)
ctc_infer: bool = field(
default=False,
metadata={"help": "generate CTC decoding results during inference"}
)
@dataclass
......
......@@ -595,7 +595,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 + pad2
encoder_padding_mask = pad1 & pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
......
......@@ -26,6 +26,7 @@ from fairseq.models.speech_to_text import (
S2TTransformerEncoder,
PDSS2TTransformerModel,
PDSS2TTransformerEncoder,
S2TSATEModel,
S2TSATEEncoder,
)
from fairseq.modules.speech_to_text import Adapter, CTC
......@@ -58,73 +59,18 @@ class S2TDualModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
PDSS2TTransformerModel.add_args(parser)
# SATE setting
parser.add_argument(
"--text-encoder-layers",
default=6,
type=int,
help="layers of the text encoder",
)
parser.add_argument(
"--text-attention-type",
default="selfattn",
type=str,
help="attention type of the textual encoder",
)
parser.add_argument(
"--adapter",
default="league",
type=str,
help="adapter type",
)
parser.add_argument(
"--ctc-compress-strategy",
default="avg",
type=str,
help="compress strategy, such as avg, weighted, and softmax",
)
parser.add_argument(
"--share-ctc-and-adapter",
default=False,
action="store_true",
help="share the projection weights of the ctc and adapter",
)
parser.add_argument(
"--temperature",
default=1.0,
type=float,
help="temperature of the CTC softmax",
)
parser.add_argument(
"--acoustic-encoder",
default="transformer",
type=str,
help="the architecture of the acoustic encoder",
)
parser.add_argument(
"--target-ctc-layers",
default=None,
type=str,
help="ctc layers for target sentence",
)
parser.add_argument(
"--load-pretrained-acoustic-encoder-from",
type=str,
metavar="STR",
help="model to take acoustic encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-text-encoder-from",
type=str,
metavar="STR",
help="model to take text encoder weights from (for initialization)",
)
S2TTransformerModel.add_args(parser)
PDSS2TTransformerModel.add_specific_args(parser)
S2TSATEModel.add_specific_args(parser)
S2TDualModel.add_specific_args(parser)
@staticmethod
def add_specific_args(parser):
# multi-encoder
parser.add_argument(
"--asr-encoder",
default="transformer",
choices=["transformer", "pds", "sate", "wav2vec"],
type=str,
help="the architecture of the ASR encoder",
)
......@@ -134,12 +80,11 @@ class S2TDualModel(FairseqEncoderDecoderModel):
type=str,
help="the architecture of the MT encoder",
)
# parser.add_argument(
# "--mt-encoder-dim",
# default="transformer",
# type=str,
# help="the dimension of the MT encoder",
# )
parser.add_argument(
"--mt-encoder-dim",
type=int,
help="the dimension of the MT encoder",
)
parser.add_argument(
"--mt-encoder-layers",
default=6,
......@@ -148,13 +93,13 @@ class S2TDualModel(FairseqEncoderDecoderModel):
)
parser.add_argument(
"--encoder-asr-ratio",
default=1,
default=0.5,
type=float,
help="the ratio of the asr representation",
)
parser.add_argument(
"--encoder-mt-ratio",
default=1,
default=0.5,
type=float,
help="the ratio of the mt representation",
)
......@@ -165,9 +110,9 @@ class S2TDualModel(FairseqEncoderDecoderModel):
)
parser.add_argument(
"--encoder-drop-net-prob",
default=0.5,
default=0.2,
type=float,
help="the probability of dropping",
help="probability of dropping one of the representations",
)
parser.add_argument(
"--encoder-drop-net-mix",
......@@ -268,9 +213,6 @@ class S2TDualModel(FairseqEncoderDecoderModel):
tgt_dict, args.decoder_embed_dim
)
setattr(args, "encoder_s1_ratio", args.encoder_asr_ratio)
setattr(args, "encoder_s2_ratio", args.encoder_mt_ratio)
encoder = cls.build_encoder(args, task, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
......@@ -326,8 +268,11 @@ class S2TDualEncoder(FairseqEncoder):
else:
logger.error("Unsupported ASR architecture: %s." % asr_encoder_type)
attn_type = args.encoder_attention_type
setattr(args, "encoder_s1_ratio", args.encoder_asr_ratio)
setattr(args, "encoder_s2_ratio", args.encoder_mt_ratio)
setattr(args, "encoder_layers", args.mt_encoder_layers)
attn_type = args.encoder_attention_type
setattr(args, "encoder_attention_type", "selfattn")
self.mt_encoder = TransformerS2Encoder(args, task.source_dictionary, embed_tokens)
setattr(args, "encoder_attention_type", attn_type)
......
......@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.models.transformer import Embedding
from fairseq.models import (
FairseqEncoder,
register_model,
......@@ -42,8 +43,12 @@ class S2TSATEModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
PDSS2TTransformerModel.add_args(parser)
S2TTransformerModel.add_args(parser)
PDSS2TTransformerModel.add_specific_args(parser)
S2TSATEModel.add_specific_args(parser)
@staticmethod
def add_specific_args(parser):
# SATE setting
parser.add_argument(
"--text-encoder-layers",
......@@ -70,10 +75,16 @@ class S2TSATEModel(S2TTransformerModel):
help="compress strategy, such as avg, weighted, and softmax",
)
parser.add_argument(
"--share-ctc-and-adapter",
"--share-adapter-and-ctc",
default=False,
action="store_true",
help="share the projection weights of the adapter and ctc",
)
parser.add_argument(
"--share-adapter-and-embed",
default=False,
action="store_true",
help="share the projection weights of the ctc and adapter",
help="share the projection weights of the adapter and embed",
)
parser.add_argument(
"--adapter-temperature",
......@@ -82,6 +93,18 @@ class S2TSATEModel(S2TTransformerModel):
help="temperature of the CTC softmax in adapter",
)
parser.add_argument(
"--adapter-embed-norm",
default=False,
action="store_true",
help="use the layer norm for embed output",
)
parser.add_argument(
"--adapter-out-norm",
default=False,
action="store_true",
help="use the layer norm for final output",
)
parser.add_argument(
"--acoustic-encoder",
default="transformer",
type=str,
......@@ -107,15 +130,20 @@ class S2TSATEModel(S2TTransformerModel):
help="ctc layer for target sentence",
)
parser.add_argument(
"--share-target-ctc-and-embed",
action="store_true",
help="share the weight of target ctc and embed",
)
parser.add_argument(
"--target-interleaved-ctc-layers",
default=None,
type=str,
help="interleaved ctc layers for target sentence",
)
parser.add_argument(
"--share-target-ctc-and-sae",
"--share-target-sae-and-ctc",
action="store_true",
help="share the weight of target ctc and sae",
help="share the weight of target sae and ctc",
)
# freeze
parser.add_argument(
......@@ -133,7 +161,6 @@ class S2TSATEModel(S2TTransformerModel):
action="store_true",
help="freeze the parameters of the decoder",
)
pass
@classmethod
def build_encoder(cls, args, task=None, decoder_embed_tokens=None):
......@@ -145,7 +172,9 @@ class S2TSATEModel(S2TTransformerModel):
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
component=encoder,
checkpoint=args.load_pretrained_encoder_from,
strict=False
)
if getattr(args, "load_pretrained_acoustic_encoder_from", None):
......@@ -154,7 +183,9 @@ class S2TSATEModel(S2TTransformerModel):
f"{args.load_pretrained_acoustic_encoder_from}"
)
encoder.acoustic_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.acoustic_encoder, checkpoint=args.load_pretrained_acoustic_encoder_from, strict=False
component=encoder.acoustic_encoder,
checkpoint=args.load_pretrained_acoustic_encoder_from,
strict=False
)
if getattr(args, "load_pretrained_text_encoder_from", None):
......@@ -162,15 +193,46 @@ class S2TSATEModel(S2TTransformerModel):
f"loaded pretrained text encoder from: "
f"{args.load_pretrained_text_encoder_from}"
)
encoder.text_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.text_encoder, checkpoint=args.load_pretrained_text_encoder_from, strict=False
encoder.textual_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.textual_encoder,
checkpoint=args.load_pretrained_text_encoder_from,
strict=False
)
if args.share_ctc_and_adapter and hasattr(encoder.adapter, "embed_adapter"):
encoder.acoustic_encoder.ctc.ctc_projection.weight = encoder.adapter.embed_adapter.weight
if args.share_adapter_and_ctc and hasattr(encoder.adapter, "embed_adapter"):
encoder.adapter.embed_adapter.weight = encoder.acoustic_encoder.ctc.ctc_projection.weight
return encoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
decoder_embed_tokens = build_embedding(
task.target_dictionary, args.decoder_embed_dim
)
encoder = cls.build_encoder(args, task, decoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
if args.share_adapter_and_embed and hasattr(encoder.adapter, "embed_adapter"):
encoder.adapter.embed_adapter.weight = decoder_embed_tokens.weight
return cls(encoder, decoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
"""
The forward method inherited from the base class has a **kwargs
......@@ -184,7 +246,7 @@ class S2TSATEModel(S2TTransformerModel):
return decoder_out
class TextEncoder(FairseqEncoder):
class TextualEncoder(FairseqEncoder):
def __init__(self, args, dictionary, embed_tokens=None):
super().__init__(None)
......@@ -200,6 +262,10 @@ class TextEncoder(FairseqEncoder):
self.embed_scale = 1.0
self.padding_idx = dictionary.pad_index
self.embed_norm = getattr(args, "embed_norm", False)
if self.embed_norm:
self.embed_ln = LayerNorm(embed_dim)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
......@@ -228,7 +294,8 @@ class TextEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
if embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
......@@ -250,10 +317,13 @@ class TextEncoder(FairseqEncoder):
self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary),
dropout=args.dropout)
if embed_tokens is not None:
if embed_tokens is not None and args.share_target_ctc_and_embed and \
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = {
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
......@@ -262,13 +332,15 @@ class TextEncoder(FairseqEncoder):
self.sae = Adapter(embed_dim, args.sae_adapter,
len(dictionary),
strategy=strategy)
if args.share_target_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
if args.share_target_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None):
if self.embed_norm:
x = self.embed_ln(x)
x = self.embed_scale * x
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
x = positions + x
......@@ -297,7 +369,7 @@ class TextEncoder(FairseqEncoder):
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
if history is not None:
history.push(x)
......@@ -313,6 +385,9 @@ class TextEncoder(FairseqEncoder):
return x, target_ctc_logit, target_interleaved_ctc_logits
def reorder_encoder_out(self, encoder_out, new_order):
pass
class S2TSATEEncoder(FairseqEncoder):
"""Speech-to-text Conformer encoder that consists of input subsampler and
......@@ -333,24 +408,28 @@ class S2TSATEEncoder(FairseqEncoder):
# adapter
self.adapter_temperature = args.adapter_temperature
strategy = {
"embed_norm": getattr(args, "adapter_embed_norm", False),
"out_norm": getattr(args, "adapter_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
"distribution_cutoff": getattr(args, "adapter_distribution_cutoff", None),
"drop_prob": getattr(args, "adapter_drop_prob", 0),
}
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
len(task.source_dictionary),
decoder_embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy)
if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"):
self.acoustic_encoder.ctc.ctc_projection.weight = self.adapter.embed_adapter.weight
assert not (args.share_adapter_and_ctc and args.share_adapter_and_embed), "Cannot both be True at the same time"
if args.share_adapter_and_ctc and hasattr(self.adapter, "embed_adapter"):
self.adapter.embed_adapter.weight = self.acoustic_encoder.ctc.ctc_projection.weight
if args.share_adapter_and_embed and hasattr(self.adapter, "embed_adapter"):
self.adapter.embed_adapter.weight = decoder_embed_tokens.weight
acoustic_encoder_attention_type = args.encoder_attention_type
args.encoder_attention_type = args.text_attention_type
# textual encoder
self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens)
self.textual_encoder = TextualEncoder(args, task.source_dictionary, decoder_embed_tokens)
args.encoder_attention_type = acoustic_encoder_attention_type
......@@ -363,7 +442,14 @@ class S2TSATEEncoder(FairseqEncoder):
else:
self.history = None
def forward(self, src_tokens, src_lengths):
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
if hasattr(self, "ctc"):
assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
if hasattr(self.textual_encoder, "ctc"):
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def forward(self, src_tokens, src_lengths=None, **kwargs):
if self.history is not None:
self.history.clean()
......@@ -403,11 +489,11 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
else:
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
return {
"encoder_out": [x], # T x B x C
......@@ -519,7 +605,9 @@ def base_architecture(args):
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.target_ctc_layer = getattr(args, "target_ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
args.share_target_ctc_and_embed = getattr(args, "share_target_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
......@@ -556,8 +644,8 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_target_ctc_and_sae = getattr(args, "share_target_ctc_and_sae", False)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
......@@ -599,7 +687,10 @@ def base_architecture(args):
args.adapter_temperature = getattr(args, "adapter_temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
args.share_adapter_and_ctc = getattr(args, "share_adapter_and_ctc", False)
args.share_adapter_and_embed = getattr(args, "share_adapter_and_embed", False)
args.adapter_embed_norm = getattr(args, "adapter_embed_norm", False)
args.adapter_out_norm = getattr(args, "adapter_out_norm", False)
@register_model_architecture("s2t_sate", "s2t_sate_s")
......
......@@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
......@@ -397,6 +396,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="use linear transform after down-sampling",
)
parser.add_argument(
"--embed-norm",
action="store_true",
help="use layer norm after down-sampling",
)
# interleaved CTC layers
parser.add_argument(
......@@ -438,10 +442,22 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--share-ctc-and-sae",
"--share-sae-and-ctc",
action="store_true",
help="share the weight of ctc and sae",
)
parser.add_argument(
"--sae-embed-norm",
default=False,
action="store_true",
help="use the layer norm for embed output",
)
parser.add_argument(
"--sae-out-norm",
default=False,
action="store_true",
help="use the layer norm for final output",
)
# Mixup
parser.add_argument(
......@@ -580,8 +596,11 @@ class S2TTransformerEncoder(FairseqEncoder):
self.subsample = subsampling(args)
self.embed_linear = getattr(args, "embed_linear", False)
self.embed_norm = getattr(args, "embed_norm", False)
if self.embed_linear:
self.linear = nn.Linear(dim, dim)
if self.embed_norm:
self.embed_ln = LayerNorm(dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos":
......@@ -645,13 +664,16 @@ class S2TTransformerEncoder(FairseqEncoder):
if not self.use_ctc:
self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout)
dropout=args.dropout,
)
if getattr(args, "share_ctc_and_embed", False) and \
task.source_dictionary == task.target_dictionary and \
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = {
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
......@@ -661,7 +683,7 @@ class S2TTransformerEncoder(FairseqEncoder):
len(task.source_dictionary),
strategy=strategy,
)
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
if args.share_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
# mixup
......@@ -682,6 +704,17 @@ class S2TTransformerEncoder(FairseqEncoder):
self.dis = 2
self.cos_sim = dict()
# debug the variance
self.debug_var = False
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
if hasattr(self, "ctc"):
assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict)
def set_debug_var(self, debug_var_flag):
self.debug_var = debug_var_flag
@staticmethod
def pooling_ratio():
return 4
......@@ -730,7 +763,7 @@ class S2TTransformerEncoder(FairseqEncoder):
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 + pad2
encoder_padding_mask = pad1 & pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
......@@ -741,7 +774,15 @@ class S2TTransformerEncoder(FairseqEncoder):
}
return x, encoder_padding_mask, input_lengths, mixup
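In apply_mixup above, the padding masks of the two mixed utterances are now combined with a logical AND: a mixed frame counts as padding only when both source frames are padding. Boolean addition acts as an OR, which would mark a frame as padding whenever either source is padded and truncate the valid length to the shorter utterance. A small check with illustrative masks:

    import torch

    pad1 = torch.tensor([False, False, True, True])   # True = padding
    pad2 = torch.tensor([False, True,  True, True])

    print((pad1 + pad2).int())   # OR:  [0, 1, 1, 1] -> valid length 1
    print((pad1 & pad2).int())   # AND: [0, 0, 1, 1] -> valid length 2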
def forward(self, src_tokens, src_lengths):
def show_debug(self, x, text=None):
if not self.debug_var:
return
if text:
logger.info("--- Variance of %s: %f." % (text, x.var()))
else:
logger.info("--- Variance: %f." % (x.var()))
def forward(self, src_tokens, src_lengths=None, **kwargs):
layer_idx = -1
mixup = None
......@@ -753,6 +794,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x = src_tokens.transpose(0, 1)
input_lengths = src_lengths
self.show_debug(x, "input x")
# gather cosine similarity
cos_sim_idx = -1
dis = self.dis
......@@ -766,11 +808,19 @@ class S2TTransformerEncoder(FairseqEncoder):
# down-sampling
x, input_lengths = self.subsample(x, input_lengths)
self.show_debug(x, "x after subsampling")
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.transpose(0, 1).unsqueeze(-1).type_as(x))
if self.embed_norm:
x = self.embed_ln(x)
self.show_debug(x, "x after embed norm")
# embedding scaling
x = self.embed_scale * x
self.show_debug(x, "x after scale")
# position embedding
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
......@@ -783,9 +833,12 @@ class S2TTransformerEncoder(FairseqEncoder):
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
x += positions
positions = None
self.show_debug(x, "x after position embedding")
if self.embed_linear:
x = self.linear(x)
self.show_debug(x, "x after embed linear")
x = self.dropout_module(x)
# add emb into history
......@@ -806,6 +859,7 @@ class S2TTransformerEncoder(FairseqEncoder):
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
self.show_debug(x, "x before encoding")
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
......@@ -813,13 +867,14 @@ class S2TTransformerEncoder(FairseqEncoder):
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
self.show_debug(x, "x after layer %d" % layer_idx)
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
ctc_logit = self.ctc(x.clone(), encoder_padding_mask)
# interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -829,7 +884,7 @@ class S2TTransformerEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
logit = self.ctc(norm_x, encoder_padding_mask)
interleaved_ctc_logits.append(logit)
......@@ -837,7 +892,8 @@ class S2TTransformerEncoder(FairseqEncoder):
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
self.show_debug(x, "x after sae")
# gather cosine similarity
if self.gather_cos_sim:
......@@ -850,11 +906,14 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None:
x = self.history.pop()
self.show_debug(x, "x after encoding")
if self.layer_norm is not None:
x = self.layer_norm(x)
self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x)
ctc_logit = self.ctc(x, encoder_padding_mask)
self.show_debug(x, "x after ctc")
return {
"encoder_out": [x], # T x B x C
......@@ -1005,6 +1064,7 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.embed_linear = getattr(args, "embed_linear", False)
args.embed_norm = getattr(args, "embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
......@@ -1045,7 +1105,9 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
......
......@@ -1063,7 +1063,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
self_attn_padding_mask = pad1 & pad2
# decoder layers
avg_attn = None
......
......@@ -606,6 +606,7 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout,
dictionary=dictionary,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
......@@ -627,6 +628,7 @@ class TransformerCTCEncoder(FairseqEncoder):
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dictionary=dictionary,
dropout=args.dropout)
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
......@@ -715,6 +717,10 @@ class TransformerCTCEncoder(FairseqEncoder):
x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
return x
def set_ctc_infer(self, ctc_infer, post_process):
if hasattr(self, "ctc"):
self.ctc.set_infer(ctc_infer, post_process)
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
......@@ -764,12 +770,19 @@ class TransformerCTCEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
bsz = x.size(1)
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
org_encoder_padding_mask = encoder_padding_mask
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
# add emb into history
if self.history is not None:
self.history.push(x)
......@@ -782,6 +795,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.history is not None:
x = self.history.pop()
if layer_idx + 1 in self.interleaved_ctc_layers:
x = self.upsampling(x)
encoder_padding_mask = ctc_padding_mask
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
)
......@@ -792,7 +809,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(self.upsampling(x.clone()))
ctc_logit = self.ctc(self.upsampling(x.clone()), ctc_padding_mask)
# Interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -802,17 +819,17 @@ class TransformerCTCEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
up_x = self.upsampling(norm_x)
up_logit = self.ctc(up_x)
logit = self.ctc(norm_x, ctc_padding_mask)
interleaved_ctc_logits.append(up_logit)
up_prob = utils.softmax(up_logit / self.interleaved_ctc_temperature, dim=-1)
interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
up_prob = up_prob.permute(1, 2, 0)
prob = self.pool(up_prob)
prob = prob.permute(2, 0, 1)
x, _ = self.sae([norm_x, prob])
x, _ = self.sae([x, prob])
x = x.permute(1, 2, 0)
x = self.pool(x)
x = x.permute(2, 0, 1)
encoder_padding_mask = org_encoder_padding_mask
if self.history is not None:
self.history.push(x)
......@@ -824,13 +841,7 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(self.upsampling(x))
ctc_padding_mask = encoder_padding_mask
if ctc_logit is not None or len(interleaved_ctc_logits) != 0:
bsz = encoder_padding_mask.size(0)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
ctc_logit = self.ctc(self.upsampling(x), ctc_padding_mask)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
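The reworked interleaved-CTC path above upsamples the hidden states before the tagged layer, computes the CTC logits and runs the sae at the upsampled length, and then pools the states back to the original length before restoring the padding mask. A rough sketch of the length bookkeeping; the ratio, shapes, and the choice of an average pool are illustrative only (the diff's self.pool may use a different pooling):

    import torch
    import torch.nn as nn

    ratio, T, B, C = 2, 5, 3, 8
    x = torch.randn(T, B, C)

    # upsample: repeat every frame `ratio` times along the time axis
    up_x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, B, C)   # (T*ratio) x B x C

    # ... CTC projection and sae would operate on up_x here ...

    # pool back to the original length with a stride-`ratio` window
    pool = nn.AvgPool1d(kernel_size=ratio, stride=ratio)
    down = pool(up_x.permute(1, 2, 0)).permute(2, 0, 1)                 # T x B x C
    assert down.shape == x.shape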
......@@ -1226,7 +1237,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
self_attn_padding_mask = pad1 & pad2
# decoder layers
avg_attn = None
......
......@@ -194,7 +194,7 @@ class TransformerS2Encoder(TransformerEncoder):
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_out_s2": [x2], # T x B x C
"encoder_out_s2": [x2], # T x B x C
"encoder_padding_mask_s2": [x2_encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
......@@ -317,7 +317,7 @@ class TransformerS2Decoder(TransformerDecoder):
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
self_attn_padding_mask = pad1 & pad2
# decoder layers
avg_attn = None
......
......@@ -69,15 +69,19 @@ class Adapter(nn.Module):
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.cal_linear = True
self.linear_adapter = nn.Sequential(
nn.Linear(dim, dim),
LayerNorm(dim),
nn.Linear(dim, 2 * dim),
nn.ReLU(),
nn.Linear(2 * dim, dim),
LayerNorm(dim),
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5)
self.embed_norm = strategy.get("embed_norm", False)
if self.embed_norm:
self.embed_ln = LayerNorm(dim)
if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight
......@@ -90,7 +94,7 @@ class Adapter(nn.Module):
# additional strategy
if self.adapter_type == "shrink":
assert strategy is not None
ctc_compress_strategy = getattr(strategy, "ctc_compress_strategy", "avg")
ctc_compress_strategy = strategy.get("ctc_compress_strategy", "avg")
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
......@@ -103,6 +107,9 @@ class Adapter(nn.Module):
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
self.out_norm = strategy.get("out_norm", False)
if self.out_norm:
self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None):
......@@ -119,13 +126,19 @@ class Adapter(nn.Module):
if self.cal_context:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
# )
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
distribution = torch.where(
org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
)
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm:
soft_out = self.embed_ln(soft_out)
if self.adapter_type == "linear":
out = linear_out
......@@ -134,7 +147,7 @@ class Adapter(nn.Module):
out = soft_out
elif self.adapter_type == "league":
if self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5:
out = linear_out
else:
......@@ -178,4 +191,7 @@ class Adapter(nn.Module):
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
if self.out_norm:
out = self.out_ln(out)
return out, padding
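The new distribution_cutoff branch above changes per-token truncation into a per-position confidence gate: the probabilities of the top cutoff tokens are summed, and the full distribution is kept only where that mass exceeds the hard-coded 0.9, otherwise the position is zeroed before the soft embedding lookup. A self-contained illustration with made-up probabilities:

    import torch

    cutoff = 2
    org = torch.tensor([[[0.70, 0.25, 0.03, 0.02],     # confident: top-2 mass = 0.95
                         [0.30, 0.25, 0.25, 0.20]]])   # flat:      top-2 mass = 0.55

    top_mass = org.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
    gated = torch.where(top_mass > 0.9, org, torch.zeros_like(org))
    print(gated)   # the confident position is kept as-is, the flat one is zeroed out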
import logging
import editdistance
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -8,13 +9,15 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
)
from fairseq.data.data_utils import post_process
logger = logging.getLogger(__name__)
class CTC(nn.Module):
def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
def __init__(self, embed_dim, dictionary_size, dropout,
need_layernorm=False, dictionary=None):
super(CTC, self).__init__()
self.embed_dim = embed_dim
......@@ -26,15 +29,31 @@ class CTC(nn.Module):
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
)
self.need_layernorm = need_layernorm
if self.need_layernorm:
self.LayerNorm = LayerNorm(embed_dim)
def forward(self, x):
self.dictionary = dictionary
self.infer_decoding = False
self.post_process = "sentencepiece"
def set_infer(self, is_infer, text_post_process, dictionary):
self.infer_decoding = is_infer
self.post_process = text_post_process
self.dictionary = dictionary
def forward(self, x, padding=None):
if self.need_layernorm:
x = self.LayerNorm(x)
x = self.ctc_projection(self.ctc_dropout_module(x))
if not self.training and self.infer_decoding:
assert self.dictionary is not None
input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths)
return x
def softmax(self, x, temperature=1.0):
......@@ -45,3 +64,55 @@ class CTC(nn.Module):
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
def infer(self, logits_or_probs, lengths):
for lp, inp_l in zip(
logits_or_probs,
lengths,
):
lp = lp[:inp_l].unsqueeze(0)
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.dictionary.bos()].tolist()
pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
print(pred_words_raw)
def valid(self, logits_or_probs, target, lengths):
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
for lp, t, inp_l in zip(
logits_or_probs,
target,
lengths,
):
lp = lp[:inp_l].unsqueeze(0)
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
w_len += len(targ_words)
\ No newline at end of file
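The infer() method above is greedy CTC decoding: per-frame argmax, collapse of consecutive repeats, and removal of the bos index, which this code uses in place of the CTC blank, followed by dictionary lookup and post-processing. A toy run of the same collapse:

    import torch

    blank = 0                                            # bos index standing in for the blank
    frame_preds = torch.tensor([0, 5, 5, 0, 0, 7, 7, 7, 0, 5])

    toks = frame_preds.unique_consecutive()              # [0, 5, 0, 7, 0, 5]
    print(toks[toks != blank].tolist())                  # [5, 7, 5]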
......@@ -242,7 +242,11 @@ class TransformerS2EncoderLayer(nn.Module):
x = self.dropout_module(x)
if x2 is not None:
x2, _ = self.s2_attn(x, x2, x2, x2_encoder_padding_mask)
x2, _ = self.s2_attn(
query=x,
key=x2,
value=x2,
key_padding_mask=x2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
......@@ -557,6 +561,8 @@ class TransformerS2DecoderLayer(nn.Module):
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
# notice here
# self.s2_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
......
......@@ -137,7 +137,8 @@ class BeamSearch(Search):
scores_buf = top_prediction[0]
indices_buf = top_prediction[1]
# Project back into relative indices and beams
beams_buf = indices_buf // vocab_size
# beams_buf = indices_buf // vocab_size
beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='trunc')
indices_buf = indices_buf.fmod(vocab_size)
# At this point, beams_buf and indices_buf are single-dim and contain relative indices
......
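The change above avoids the floor-division deprecation warning that newer PyTorch versions raise for // on integer tensors; for the non-negative indices involved here, truncating division gives the same result. A quick check with illustrative values:

    import torch

    indices_buf = torch.tensor([0, 7, 12, 23])
    vocab_size = 10

    beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='trunc')   # tensor([0, 0, 1, 2])
    tokens = indices_buf.fmod(vocab_size)                                   # tensor([0, 7, 2, 3])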
......@@ -14,6 +14,8 @@ from fairseq.models import FairseqIncrementalDecoder
from torch import Tensor
from fairseq.ngram_repeat_block import NGramRepeatBlock
from fairseq.models.speech_to_text import S2TDualModel
class SequenceGenerator(nn.Module):
def __init__(
......@@ -197,11 +199,9 @@ class SequenceGenerator(nn.Module):
)
net_input = sample["net_input"]
# if "transcript" in sample:
# text_src_tokens = sample["transcript"]["tokens"]
# text_src_lengths = sample["transcript"]["lengths"]
# net_input["text_src_tokens"] = text_src_tokens
# net_input["text_src_lengths"] = text_src_lengths
if "transcript" in sample:
net_input["text_src_tokens"] = sample["transcript"]["tokens"]
net_input["text_src_lengths"] = sample["transcript"]["lengths"]
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
......@@ -662,7 +662,8 @@ class SequenceGenerator(nn.Module):
idx = bbsz_idx[i]
score = eos_scores[i]
# sentence index in the current (possibly reduced) batch
unfin_idx = idx // beam_size
# unfin_idx = idx // beam_size
unfin_idx = torch.div(idx, beam_size, rounding_mode='trunc')
# sentence index in the original (unreduced) batch
sent = unfin_idx + cum_unfin[unfin_idx]
# Cannot create dict for key type '(int, int)' in torchscript.
......@@ -760,7 +761,18 @@ class EnsembleModel(nn.Module):
def forward_encoder(self, net_input: Dict[str, Tensor]):
if not self.has_encoder():
return None
return [model.encoder.forward_torchscript(net_input) for model in self.models]
encoder_outs = []
for model in self.models:
if not isinstance(model, S2TDualModel):
if "text_src_tokens" in net_input:
net_input.pop("text_src_tokens", None)
if "text_src_lengths" in net_input:
net_input.pop("text_src_lengths", None)
encoder_outs.append(model.encoder.forward_torchscript(net_input))
return encoder_outs
# return [model.encoder.forward_torchscript(net_input) for model in self.models]
@torch.jit.export
def forward_decoder(
......
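One detail worth noting in forward_encoder above: the keys are popped from the shared net_input dict, so once one non-dual model removes text_src_tokens, any S2TDualModel that comes later in the ensemble will not see it either. If mixed ensembles matter, a non-mutating variant (same class check, filtered per model) could look like this sketch:

    from fairseq.models.speech_to_text import S2TDualModel

    def forward_encoder_per_model(models, net_input):
        encoder_outs = []
        for model in models:
            model_input = net_input
            if not isinstance(model, S2TDualModel):
                # drop the text inputs only in this model's view of the batch
                model_input = {k: v for k, v in net_input.items()
                               if k not in ("text_src_tokens", "text_src_lengths")}
            encoder_outs.append(model.encoder.forward_torchscript(model_input))
        return encoder_outs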
......@@ -105,6 +105,10 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, cfg.common_eval.post_process,
src_dict, tgt_dict)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
......