xuchen / Fairseq-S2T / Commits / 8f084189

Commit 8f084189, authored Mar 30, 2022 by xuchen
Parent: 4f679c86

fix the bugs and optimize the code
Showing 20 changed files with 158 additions and 56 deletions (+158, -56); whitespace-only changes are hidden.
egs/iwslt14/mt/conf/base.yaml (+2, -2)
egs/iwslt14/mt/conf/inter.yaml (+15, -0)
egs/iwslt2022/asr/binary.sh (+2, -1)
egs/iwslt2022/asr/conf/basis.yaml (+2, -2)
egs/iwslt2022/mt/conf/base.yaml (+1, -1)
egs/iwslt2022/mt/conf/deep.yaml (+32, -0)
egs/iwslt2022/st/conf/basis.yaml (+1, -0)
egs/libri_trans/asr/conf/debug.yaml (+22, -17)
egs/librispeech/asr/conf/EffecientConformerCTCSmall.yaml (+1, -0)
egs/librispeech/asr/conf/basis.yaml (+1, -0)
egs/librispeech/asr/conf/purectc_base_compare.yaml (+2, -3)
egs/librispeech/asr/conf/purectc_pds_base_8_compare.yaml (+3, -3)
egs/librispeech/asr/conf/purectc_pds_base_8_growth_compare.yaml (+1, -0)
fairseq/models/speech_to_text/pdss2t_transformer.py (+10, -9)
fairseq/models/speech_to_text/s2t_ctc.py (+2, -2)
fairseq/models/speech_to_text/s2t_dual.py (+4, -4)
fairseq/models/speech_to_text/s2t_sate.py (+48, -8)
fairseq/models/speech_to_text/s2t_transformer.py (+5, -0)
fairseq/models/wav2vec/wav2vec2_asr.py (+2, -2)
fairseq/modules/speech_to_text/adapter.py (+2, -2)
egs/iwslt14/mt/conf/base.yaml

-arch: transformer
+arch: transformer_ctc
 share-all-embeddings: True
 optimizer: adam
 clip-norm: 10.0
...
@@ -8,7 +8,7 @@ warmup-updates: 8000
 lr: 1e-3
 adam_betas: (0.9,0.997)
-criterion: label_smoothed_cross_entropy
+criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
 dropout: 0.3
...
egs/iwslt14/mt/conf/inter.yaml (new file, mode 100644)

+#ctc-weight: 0.2
+intermedia-ctc-weight: 0.3
+intermedia-ctc-layers: 2,4
+#target-ctc-weight: 0.3
+#target-ctc-layer: 6
+#target-intermedia-ctc-weight: 0.1
+#target-intermedia-ctc-layers: 2,4
+intermedia-adapter: league
+#intermedia-drop-prob: 0.2
+#intermedia-temperature: 5
+post-process: sentencepiece
\ No newline at end of file
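A sketch of how these knobs are typically combined: intermedia-ctc-weight scales auxiliary CTC losses computed at the listed encoder layers (2 and 4 here) before they are added to the main training loss. The helper name `combine_losses` and the averaging scheme are assumptions for illustration; the actual criterion in the repository may combine them differently.

```python
def combine_losses(main_loss, intermedia_ctc_losses, intermedia_ctc_weight):
    # average the per-layer auxiliary CTC losses, then add them with a weight
    aux = sum(intermedia_ctc_losses) / len(intermedia_ctc_losses)
    return main_loss + intermedia_ctc_weight * aux

total = combine_losses(2.0, [4.0, 6.0], 0.3)  # 2.0 + 0.3 * 5.0
```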
egs/iwslt2022/asr/binary.sh

@@ -23,7 +23,8 @@ asr_vocab_prefix=spm_unigram10000_st_share
 src_lang=en
 tgt_lang=zh
-subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON)
+subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON train_ted)
+#subsets=(train_ted)

 mkdir -p $data_dir
 splits=$(echo ${subsets[*]} | sed 's/ /,/g')
...
egs/iwslt2022/asr/conf/basis.yaml

-#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh
-train-subset: train_mustc_enzh
+train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted,train-clean-100,train-clean-360,train-other-500
+# train-subset: train_mustc_enzh
 valid-subset: dev

 max-epoch: 100
...
egs/iwslt2022/mt/conf/base.yaml

 arch: transformer
-share-decoder-input-output-embed: True
+share-all-embeddings: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
...
egs/iwslt2022/mt/conf/deep.yaml (new file, mode 100644)

+arch: transformer
+share-decoder-input-output-embed: True
+optimizer: adam
+#clip-norm: 10.0
+lr-scheduler: inverse_sqrt
+warmup-init-lr: 1e-7
+warmup-updates: 1000
+lr: 2e-4
+adam_betas: (0.9,0.997)
+criterion: label_smoothed_cross_entropy
+label_smoothing: 0.1
+dropout: 0.1
+attention-dropout: 0.1
+activation-dropout: 0.1
+activation-fn: relu
+encoder-normalize-before: True
+decoder-normalize-before: True
+encoder-embed-dim: 512
+encoder-ffn-embed-dim: 2048
+encoder-layers: 30
+decoder-layers: 6
+encoder-attention-heads: 8
+decoder-embed-dim: 512
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 8
+load-pretrained-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
+load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
egs/iwslt2022/st/conf/basis.yaml

@@ -2,6 +2,7 @@
 train-subset: train_mustc_enzh
 valid-subset: dev
+fp16-scale-tolerance: 0.25

 max-epoch: 100
 max-update: 100000
 patience: 20
...
egs/libri_trans/asr/conf/debug.yaml

-arch: pdss2t_transformer_s_8
-pds-fusion: True
-share-decoder-input-output-embed: True
+arch: s2t_ctc
+encoder-type: transformer
+ctc-layer: 12
+inter_mixup: True
+inter_mixup_layer: 0
+inter_mixup_ratio: 0.2
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 2e-3
+lr: 0.0015
 adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy_with_ctc
-label_smoothing: 0.1
+criterion: ctc
+ctc-weight: 1.0
+subsampling-type: conv2d
+subsampling-layers: 2
+subsampling-filter: 176
+subsampling-kernel: 3
+subsampling-stride: 2
+subsampling-norm: batch2d
+subsampling-activation: swish
 dropout: 0.1
 activation-fn: relu
-encoder-ffn-embed-dim: 2048
-encoder-layers: 12
-decoder-layers: 6
+encoder-embed-dim: 176
+encoder-ffn-embed-dim: 704
+encoder-layers: 16
 encoder-attention-heads: 4
-decoder-embed-dim: 256
-decoder-ffn-embed-dim: 2048
-decoder-attention-heads: 4
+macaron-style: True
+use-cnn-module: True
+cnn-module-kernel: 31
+encoder-activation-fn: swish
+encoder-attention-type: rel_pos
\ No newline at end of file
egs/librispeech/asr/conf/EffecientConformerCTCSmall.yaml

@@ -43,6 +43,7 @@ lr: 0.0015
 adam_betas: (0.9,0.98)
 criterion: ctc
+ctc-weight: 1.0
 post-process: sentencepiece

 dropout: 0.1
...
egs/librispeech/asr/conf/basis.yaml

@@ -7,6 +7,7 @@ patience: 20
 best-checkpoint-metric: loss
 maximize-best-checkpoint-metric: False
+post-process: sentencepiece

 no-epoch-checkpoints: True
 #keep-last-epochs: 10
 keep-best-checkpoints: 10
...
egs/librispeech/asr/conf/ConformerCTCSmall.yaml → egs/librispeech/asr/conf/purectc_base_compare.yaml (renamed)

@@ -2,16 +2,15 @@ arch: s2t_ctc
 encoder-type: transformer
 optimizer: adam
-#clip-norm: 10.0
+clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-weight-decay: 1e-6
 lr: 0.0015
 adam_betas: (0.9,0.98)
 criterion: ctc
-post-process: sentencepiece
+ctc-weight: 1.0
 subsampling-type: conv2d
 subsampling-layers: 2
...
egs/librispeech/asr/conf/purectc_pds_base_8_compare.yaml

@@ -12,7 +12,7 @@ encoder-type: pds
 encoder-embed-dim: 176
 pds-stages: 4
-ctc-layer: 16
+#ctc-layer: 16
 pds-layers: 4_4_4_4
 pds-ratios: 2_2_1_2
 pds-fusion: True
...
@@ -38,11 +38,11 @@ post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-encoder-layers: 12
+encoder-layers: 16
 macaron-style: True
 use-cnn-module: True
-cnn-module-kernel: 31
+cnn-module-kernel: 15
 encoder-activation-fn: swish
 encoder-attention-type: rel_pos
...
egs/librispeech/asr/conf/purectc_pds_base_8_compare2.yaml → egs/librispeech/asr/conf/purectc_pds_base_8_growth_compare.yaml (renamed)

@@ -34,6 +34,7 @@ lr: 0.0015
 adam_betas: (0.9,0.98)
 criterion: ctc
+ctc-weight: 1.0
 post-process: sentencepiece

 dropout: 0.1
...
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -696,8 +696,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         fusion_stages_num = 0
         self.fusion_stages_num = fusion_stages_num

-        args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
-        self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")]
+        args.pds_ctc = getattr(args, "pds_ctc", None)
+        self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None
         inter_ctc_module = None
         inter_adapter = None
...
@@ -708,11 +708,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             embed_dim = self.pds_embed_dims[i]
             kernel_size = self.pds_kernel_sizes[i]
             use_pos_embed = self.pds_position_embed[i]
-            use_ctc = self.pds_ctc[i]
+            use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False

             ffn_ratio = self.pds_ffn_ratios[i]
             num_head = self.pds_attn_heads[i]
-            attn_ds_ratio = self.pds_attn_ds_ratios[i]  # if self.attn_type == "reduced" else -1
+            attn_ds_ratio = self.pds_attn_ds_ratios[i] \
+                if self.pds_conv_strides is not None and self.attn_type == "reduced" else 1
             conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
             attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
             if conv_stride != 1 or attn_stride != 1:
...
@@ -1231,15 +1232,15 @@ def base_architecture(args):
     args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
     args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
-    args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
+    args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
-    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
-    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
+    args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...
@@ -1248,7 +1249,7 @@ def base_architecture(args):
     args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")

     # intermedia CTC
-    args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
+    args.pds_ctc = getattr(args, "pds_ctc", None)
     args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
     args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
...
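The pdss2t_transformer.py change above replaces the dummy default "0_0_0_0" with None and guards every consumer of the option. A minimal sketch of that pattern; `parse_stage_list` and the variable names are illustrative, not the actual fairseq-S2T API:

```python
def parse_stage_list(value):
    """Parse a per-stage option like "0_0_1_1" into ints; pass None through."""
    return [int(n) for n in value.split("_")] if value is not None else None

pds_ctc = parse_stage_list("0_0_1_1")
# each stage checks its own flag, falling back to False when the option is unset
use_ctc = [bool(n) for n in pds_ctc] if pds_ctc is not None else [False] * 4
```

The advantage over a dummy string default is that "option not set" stays distinguishable from "option explicitly set to all zeros".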
fairseq/models/speech_to_text/s2t_ctc.py

@@ -558,6 +558,7 @@ class S2TCTCEncoder(FairseqEncoder):
     def __init__(self, args, task=None):
         super().__init__(None)

+        setattr(args, "ctc_weight", 1.0)
         encoder_type = getattr(args, "encoder_type", "transformer")
         if encoder_type == "transformer":
             from .s2t_transformer import S2TTransformerEncoder
...
@@ -575,8 +576,7 @@ class S2TCTCEncoder(FairseqEncoder):
         return self.encoder(src_tokens, src_lengths, **kwargs)

     def reorder_encoder_out(self, encoder_out, new_order):
-        self.encoder.reorder_encoder_out(encoder_out, new_order)
-        return
+        return self.encoder.reorder_encoder_out(encoder_out, new_order)

 class CTCDecoder(object):
...
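The reorder_encoder_out fix above is a pure delegation bug: the wrapper invoked the inner encoder's method but returned None, so any caller (e.g. beam search reordering batches) lost the result. Stand-in classes below reproduce the corrected behavior; `Inner` and `Wrapper` are illustrative, not the fairseq classes:

```python
class Inner:
    def reorder_encoder_out(self, encoder_out, new_order):
        # reorder batch entries according to new_order
        return [encoder_out[i] for i in new_order]

class Wrapper:
    def __init__(self):
        self.encoder = Inner()

    def reorder_encoder_out(self, encoder_out, new_order):
        # return the delegate's result instead of dropping it
        return self.encoder.reorder_encoder_out(encoder_out, new_order)

w = Wrapper()
reordered = w.reorder_encoder_out(["a", "b"], [1, 0])  # → ["b", "a"]
```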
fairseq/models/speech_to_text/s2t_dual.py

@@ -494,15 +494,15 @@ def base_architecture(args):
     args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
     args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
-    args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
+    args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
-    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
-    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
+    args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...
fairseq/models/speech_to_text/s2t_sate.py

@@ -112,6 +112,22 @@ class S2TSATEModel(S2TTransformerModel):
             type=str,
             help="intermedia ctc layers for target sentence",
         )
+        # freeze
+        parser.add_argument(
+            "--freeze-acoustic-encoder",
+            action="store_true",
+            help="freeze the parameters of the acoustic encoder",
+        )
+        parser.add_argument(
+            "--freeze-textual-encoder",
+            action="store_true",
+            help="freeze the parameters of the acoustic encoder",
+        )
+        parser.add_argument(
+            "--freeze-decoder",
+            action="store_true",
+            help="freeze the parameters of the decoder",
+        )
         pass

     @classmethod
...
@@ -150,6 +166,18 @@ class S2TSATEModel(S2TTransformerModel):
         return encoder

+    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+        """
+        The forward method inherited from the base class has a **kwargs
+        argument in its input, which is not supported in torchscript. This
+        method overwrites the forward method definition without **kwargs.
+        """
+        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+        decoder_out = self.decoder(
+            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
+        )
+        return decoder_out

 class TextEncoder(FairseqEncoder):
     def __init__(self, args, dictionary, embed_tokens=None):
...
@@ -224,7 +252,8 @@ class TextEncoder(FairseqEncoder):
         elif args.intermedia_adapter == "league":
             strategy = getattr(args, "intermedia_distribution_cutoff", None)
         self.adapter = Adapter(embed_dim, args.intermedia_adapter,
-                               len(dictionary), embed_tokens=embed_tokens,
+                               len(dictionary),
+                               # embed_tokens=embed_tokens,
                                strategy=strategy)
         self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
         self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...
@@ -294,7 +323,7 @@ class S2TSATEEncoder(FairseqEncoder):
         # adapter
         self.temperature = args.temperature
-        # self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
         strategy = None
         if args.adapter == "shrink":
             strategy = getattr(args, "ctc_compress_strategy", "avg")
...
@@ -318,6 +347,9 @@ class S2TSATEEncoder(FairseqEncoder):
         args.encoder_attention_type = acoustic_encoder_attention_type

+        self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
+        self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
+
         if getattr(args, "use_enc_dlcl", False):
             layer_num = args.encoder_layers + args.text_encoder_layers + 2
             self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num)
...
@@ -328,7 +360,11 @@ class S2TSATEEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

-        acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
+        if self.freeze_acoustic_encoder:
+            with torch.no_grad():
+                acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
+        else:
+            acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)

         encoder_out = acoustic_encoder_out["encoder_out"][0]
         encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
...
@@ -354,7 +390,11 @@ class S2TSATEEncoder(FairseqEncoder):
             self.history.push(x)

-        x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
+        if self.freeze_textual_encoder:
+            with torch.no_grad():
+                x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
+        else:
+            x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)

         return {
             "encoder_out": [x],  # T x B x C
...
@@ -482,15 +522,15 @@ def base_architecture(args):
     args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
     args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
-    args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
+    args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
-    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
-    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
+    args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...
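The freezing pattern added to S2TSATEEncoder.forward can be sketched in isolation: when the freeze flag is set, the sub-encoder runs under torch.no_grad(), so no gradients flow into its parameters while the rest of the model still trains. `TwoStageEncoder` and the layer sizes below are illustrative stand-ins, not the fairseq modules:

```python
import torch
import torch.nn as nn

class TwoStageEncoder(nn.Module):
    def __init__(self, freeze_first=False):
        super().__init__()
        self.first = nn.Linear(4, 4)   # stands in for the acoustic encoder
        self.second = nn.Linear(4, 4)  # stands in for the textual encoder
        self.freeze_first = freeze_first

    def forward(self, x):
        if self.freeze_first:
            with torch.no_grad():
                x = self.first(x)  # no gradient is recorded for this stage
        else:
            x = self.first(x)
        return self.second(x)

model = TwoStageEncoder(freeze_first=True)
out = model(torch.randn(2, 4))
out.sum().backward()
# the frozen stage accumulated no gradient; the trainable one did
```

Compared with setting requires_grad=False per parameter, wrapping the call in no_grad() also skips storing activations for the frozen stage, saving memory.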
fairseq/models/speech_to_text/s2t_transformer.py

@@ -598,6 +598,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             strategy = getattr(args, "intermedia_distribution_cutoff", None)
             self.adapter = Adapter(dim, args.intermedia_adapter,
                                    len(task.source_dictionary), strategy=strategy)
+            # embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection)
             self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
             self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...
@@ -700,6 +701,10 @@ class S2TTransformerEncoder(FairseqEncoder):
                 logit = self.ctc(norm_x)
                 intermedia_ctc_logits.append(logit)

+                logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
+                                    max=1e8 if logit.dtype == torch.float32 else 1e4)
+
                 prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                 x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
...
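The clamp inserted before the intermedia softmax guards against overflowed logits (common under fp16 training) producing NaN probabilities. A self-contained sketch with the same bounds as the committed code; `clamp_logits` is an illustrative helper name:

```python
import torch

def clamp_logits(logit):
    # fp32 logits are clipped to +/-1e8, lower-precision ones to +/-1e4,
    # so softmax never sees +/-inf
    bound = 1e8 if logit.dtype == torch.float32 else 1e4
    return logit.clamp(min=-bound, max=bound)

logit = torch.tensor([float("inf"), 0.0, -float("inf")])
safe = clamp_logits(logit)
prob = torch.softmax(safe, dim=-1)  # finite everywhere
```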
fairseq/models/wav2vec/wav2vec2_asr.py

@@ -161,8 +161,8 @@ class Wav2VecCtc(BaseFairseqModel):
 (changed lines are whitespace-only and hidden by the whitespace filter)
         padding = net_output["padding_mask"]
         if padding is not None and padding.any():
             padding = padding.T
             logits[padding][..., 0] = 0
             logits[padding][..., 1:] = float('-inf')

         return logits
...
fairseq/modules/speech_to_text/adapter.py

@@ -130,7 +130,7 @@ class Adapter(nn.Module):
             out = coef * linear_out + (1 - coef) * soft_out
         elif self.adapter_type == "inter_league":
-            soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
+            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
             out = representation + soft_out
         elif self.adapter_type == "none":
...
@@ -153,7 +153,7 @@ class Adapter(nn.Module):
         # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
         representation = representation.permute(1, 2, 0)
-        compressed_output = representation.float().bmm(weights_matrix).type_as(representation)  # B x C x T'
+        compressed_output = representation.bmm(weights_matrix).type_as(representation)  # B x C x T'
         out = compressed_output.permute(2, 0, 1)
         out_lengths = lengths.new(new_lengths)
...
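One plausible reading of the inter_league fix above: if embed_adapter stores an embedding-style weight of shape V x C (vocabulary by hidden dim), the flattened distribution of shape (T·B) x V must be multiplied by the weight directly; multiplying by its transpose only type-checks when V happens to equal C. The sketch below only verifies the shapes; all names and dimensions are illustrative assumptions, not the repository's actual values:

```python
import torch

seq_len, bsz, vocab, dim = 5, 2, 7, 3
# per-position distribution over the vocabulary, flattened to (T*B) x V
distribution = torch.softmax(torch.randn(seq_len * bsz, vocab), dim=-1)
weight = torch.randn(vocab, dim)  # embedding-style weight: V x C

# (T*B) x V @ V x C -> (T*B) x C, then reshape back to T x B x C
soft_out = torch.mm(distribution, weight).view(seq_len, bsz, -1)
```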