Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
8f084189
Commit
8f084189
authored
Mar 30, 2022
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bugs and optimize the code
parent
4f679c86
显示空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
156 行增加
和
54 行删除
+156
-54
egs/iwslt14/mt/conf/base.yaml
+2
-2
egs/iwslt14/mt/conf/inter.yaml
+15
-0
egs/iwslt2022/asr/binary.sh
+2
-1
egs/iwslt2022/asr/conf/basis.yaml
+2
-2
egs/iwslt2022/mt/conf/base.yaml
+1
-1
egs/iwslt2022/mt/conf/deep.yaml
+32
-0
egs/iwslt2022/st/conf/basis.yaml
+1
-0
egs/libri_trans/asr/conf/debug.yaml
+22
-17
egs/librispeech/asr/conf/EffecientConformerCTCSmall.yaml
+1
-0
egs/librispeech/asr/conf/basis.yaml
+1
-0
egs/librispeech/asr/conf/purectc_base_compare.yaml
+2
-3
egs/librispeech/asr/conf/purectc_pds_base_8_compare.yaml
+3
-3
egs/librispeech/asr/conf/purectc_pds_base_8_growth_compare.yaml
+1
-0
fairseq/models/speech_to_text/pdss2t_transformer.py
+10
-9
fairseq/models/speech_to_text/s2t_ctc.py
+2
-2
fairseq/models/speech_to_text/s2t_dual.py
+4
-4
fairseq/models/speech_to_text/s2t_sate.py
+46
-6
fairseq/models/speech_to_text/s2t_transformer.py
+5
-0
fairseq/models/wav2vec/wav2vec2_asr.py
+2
-2
fairseq/modules/speech_to_text/adapter.py
+2
-2
没有找到文件。
egs/iwslt14/mt/conf/base.yaml
查看文件 @
8f084189
arch
:
transformer
arch
:
transformer
_ctc
share-all-embeddings
:
True
optimizer
:
adam
clip-norm
:
10.0
...
...
@@ -8,7 +8,7 @@ warmup-updates: 8000
lr
:
1e-3
adam_betas
:
(0.9,0.997)
criterion
:
label_smoothed_cross_entropy
criterion
:
label_smoothed_cross_entropy
_with_ctc
label_smoothing
:
0.1
dropout
:
0.3
...
...
egs/iwslt14/mt/conf/inter.yaml
0 → 100644
查看文件 @
8f084189
#ctc-weight: 0.2
intermedia-ctc-weight
:
0.3
intermedia-ctc-layers
:
2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter
:
league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process
:
sentencepiece
\ No newline at end of file
egs/iwslt2022/asr/binary.sh
查看文件 @
8f084189
...
...
@@ -23,7 +23,8 @@ asr_vocab_prefix=spm_unigram10000_st_share
src_lang
=
en
tgt_lang
=
zh
subsets
=(
train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON
)
subsets
=(
train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON train_ted
)
#subsets=(train_ted)
mkdir
-p
$data_dir
splits
=
$(
echo
${
subsets
[*]
}
| sed
's/ /,/g'
)
...
...
egs/iwslt2022/asr/conf/basis.yaml
查看文件 @
8f084189
#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh
train-subset
:
train_mustc_enzh
train-subset
:
train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted,train-clean-100,train-clean-360,train-other-500
#
train-subset: train_mustc_enzh
valid-subset
:
dev
max-epoch
:
100
...
...
egs/iwslt2022/mt/conf/base.yaml
查看文件 @
8f084189
arch
:
transformer
share-
decoder-input-output-embed
:
True
share-
all-embeddings
:
True
optimizer
:
adam
clip-norm
:
10.0
lr-scheduler
:
inverse_sqrt
...
...
egs/iwslt2022/mt/conf/deep.yaml
0 → 100644
查看文件 @
8f084189
arch
:
transformer
share-decoder-input-output-embed
:
True
optimizer
:
adam
#clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-updates
:
1000
lr
:
2e-4
adam_betas
:
(0.9,0.997)
criterion
:
label_smoothed_cross_entropy
label_smoothing
:
0.1
dropout
:
0.1
attention-dropout
:
0.1
activation-dropout
:
0.1
activation-fn
:
relu
encoder-normalize-before
:
True
decoder-normalize-before
:
True
encoder-embed-dim
:
512
encoder-ffn-embed-dim
:
2048
encoder-layers
:
30
decoder-layers
:
6
encoder-attention-heads
:
8
decoder-embed-dim
:
512
decoder-ffn-embed-dim
:
2048
decoder-attention-heads
:
8
load-pretrained-encoder-from
:
/home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
load-pretrained-decoder-from
:
/home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
egs/iwslt2022/st/conf/basis.yaml
查看文件 @
8f084189
...
...
@@ -2,6 +2,7 @@
train-subset
:
train_mustc_enzh
valid-subset
:
dev
fp16-scale-tolerance
:
0.25
max-epoch
:
100
max-update
:
100000
patience
:
20
...
...
egs/libri_trans/asr/conf/debug.yaml
查看文件 @
8f084189
arch
:
pdss2t_transformer_s_8
pds-fusion
:
True
ctc-layer
:
12
arch
:
s2t_ctc
encoder-type
:
transformer
inter_mixup
:
True
inter_mixup_layer
:
0
inter_mixup_ratio
:
0.2
share-decoder-input-output-embed
:
True
optimizer
:
adam
clip-norm
:
10.0
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-updates
:
10000
lr
:
2e-3
lr
:
0.0015
adam_betas
:
(0.9,0.98)
criterion
:
label_smoothed_cross_entropy_with_ctc
label_smoothing
:
0.1
criterion
:
ctc
ctc-weight
:
1.0
subsampling-type
:
conv2d
subsampling-layers
:
2
subsampling-filter
:
176
subsampling-kernel
:
3
subsampling-stride
:
2
subsampling-norm
:
batch2d
subsampling-activation
:
swish
dropout
:
0.1
activation-fn
:
relu
encoder-
ffn-embed-dim
:
2048
encoder-
layers
:
12
decoder-layers
:
6
encoder-
embed-dim
:
176
encoder-
ffn-embed-dim
:
704
encoder-layers
:
1
6
encoder-attention-heads
:
4
decoder-embed-dim
:
256
decoder-ffn-embed-dim
:
2048
decoder-attention-heads
:
4
macaron-style
:
True
use-cnn-module
:
True
cnn-module-kernel
:
31
encoder-activation-fn
:
swish
encoder-attention-type
:
rel_pos
\ No newline at end of file
egs/librispeech/asr/conf/EffecientConformerCTCSmall.yaml
查看文件 @
8f084189
...
...
@@ -43,6 +43,7 @@ lr: 0.0015
adam_betas
:
(0.9,0.98)
criterion
:
ctc
ctc-weight
:
1.0
post-process
:
sentencepiece
dropout
:
0.1
...
...
egs/librispeech/asr/conf/basis.yaml
查看文件 @
8f084189
...
...
@@ -7,6 +7,7 @@ patience: 20
best-checkpoint-metric
:
loss
maximize-best-checkpoint-metric
:
False
post-process
:
sentencepiece
no-epoch-checkpoints
:
True
#keep-last-epochs: 10
keep-best-checkpoints
:
10
...
...
egs/librispeech/asr/conf/
ConformerCTCSmall
.yaml
→
egs/librispeech/asr/conf/
purectc_base_compare
.yaml
查看文件 @
8f084189
...
...
@@ -2,16 +2,15 @@ arch: s2t_ctc
encoder-type
:
transformer
optimizer
:
adam
#
clip-norm: 10.0
clip-norm
:
10.0
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-updates
:
10000
weight-decay
:
1e-6
lr
:
0.0015
adam_betas
:
(0.9,0.98)
criterion
:
ctc
post-process
:
sentencepiece
ctc-weight
:
1.0
subsampling-type
:
conv2d
subsampling-layers
:
2
...
...
egs/librispeech/asr/conf/purectc_pds_base_8_compare.yaml
查看文件 @
8f084189
...
...
@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim
:
176
pds-stages
:
4
ctc-layer
:
16
#
ctc-layer: 16
pds-layers
:
4_4_4_4
pds-ratios
:
2_2_1_2
pds-fusion
:
True
...
...
@@ -38,11 +38,11 @@ post-process: sentencepiece
dropout
:
0.1
activation-fn
:
relu
encoder-layers
:
1
2
encoder-layers
:
1
6
macaron-style
:
True
use-cnn-module
:
True
cnn-module-kernel
:
31
cnn-module-kernel
:
15
encoder-activation-fn
:
swish
encoder-attention-type
:
rel_pos
...
...
egs/librispeech/asr/conf/purectc_pds_base_8_
compare2
.yaml
→
egs/librispeech/asr/conf/purectc_pds_base_8_
growth_compare
.yaml
查看文件 @
8f084189
...
...
@@ -34,6 +34,7 @@ lr: 0.0015
adam_betas
:
(0.9,0.98)
criterion
:
ctc
ctc-weight
:
1.0
post-process
:
sentencepiece
dropout
:
0.1
...
...
fairseq/models/speech_to_text/pdss2t_transformer.py
查看文件 @
8f084189
...
...
@@ -696,8 +696,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_stages_num
=
0
self
.
fusion_stages_num
=
fusion_stages_num
args
.
pds_ctc
=
getattr
(
args
,
"pds_ctc"
,
"0_0_0_0"
)
self
.
pds_ctc
=
[
int
(
n
)
for
n
in
args
.
pds_ctc
.
split
(
"_"
)]
args
.
pds_ctc
=
getattr
(
args
,
"pds_ctc"
,
None
)
self
.
pds_ctc
=
[
int
(
n
)
for
n
in
args
.
pds_ctc
.
split
(
"_"
)]
if
args
.
pds_ctc
is
not
None
else
None
inter_ctc_module
=
None
inter_adapter
=
None
...
...
@@ -708,11 +708,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim
=
self
.
pds_embed_dims
[
i
]
kernel_size
=
self
.
pds_kernel_sizes
[
i
]
use_pos_embed
=
self
.
pds_position_embed
[
i
]
use_ctc
=
self
.
pds_ctc
[
i
]
use_ctc
=
self
.
pds_ctc
[
i
]
if
self
.
pds_ctc
is
not
None
else
False
ffn_ratio
=
self
.
pds_ffn_ratios
[
i
]
num_head
=
self
.
pds_attn_heads
[
i
]
attn_ds_ratio
=
self
.
pds_attn_ds_ratios
[
i
]
# if self.attn_type == "reduced" else -1
attn_ds_ratio
=
self
.
pds_attn_ds_ratios
[
i
]
\
if
self
.
pds_conv_strides
is
not
None
and
self
.
attn_type
==
"reduced"
else
1
conv_stride
=
self
.
pds_conv_strides
[
i
]
if
self
.
pds_conv_strides
is
not
None
else
1
attn_stride
=
self
.
pds_attn_strides
[
i
]
if
self
.
pds_attn_strides
is
not
None
else
1
if
conv_stride
!=
1
or
attn_stride
!=
1
:
...
...
@@ -1231,15 +1232,15 @@ def base_architecture(args):
args
.
pds_ds_method
=
getattr
(
args
,
"pds_ds_method"
,
"conv"
)
args
.
pds_embed_dims
=
getattr
(
args
,
"pds_embed_dims"
,
None
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Tru
e
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Fals
e
)
args
.
pds_position_embed
=
getattr
(
args
,
"pds_position_embed"
,
None
)
args
.
pds_attn_heads
=
getattr
(
args
,
"pds_attn_heads"
,
None
)
args
.
pds_ffn_ratios
=
getattr
(
args
,
"pds_ffn_ratios"
,
None
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
"1_1_1_1"
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
"1_1_1_1"
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
"1_1_1_1"
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
None
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
None
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
None
)
args
.
ctc_layer
=
getattr
(
args
,
"ctc_layer"
,
0
)
args
.
pds_dropout
=
getattr
(
args
,
"pds_dropout"
,
args
.
dropout
)
...
...
@@ -1248,7 +1249,7 @@ def base_architecture(args):
args
.
pds_fusion_method
=
getattr
(
args
,
"pds_fusion_method"
,
"all_conv"
)
# intermedia CTC
args
.
pds_ctc
=
getattr
(
args
,
"pds_ctc"
,
"0_0_0_0"
)
args
.
pds_ctc
=
getattr
(
args
,
"pds_ctc"
,
None
)
args
.
intermedia_adapter
=
getattr
(
args
,
"intermedia_adapter"
,
"none"
)
args
.
intermedia_drop_prob
=
getattr
(
args
,
"intermedia_drop_prob"
,
0
)
...
...
fairseq/models/speech_to_text/s2t_ctc.py
查看文件 @
8f084189
...
...
@@ -558,6 +558,7 @@ class S2TCTCEncoder(FairseqEncoder):
def
__init__
(
self
,
args
,
task
=
None
):
super
()
.
__init__
(
None
)
setattr
(
args
,
"ctc_weight"
,
1.0
)
encoder_type
=
getattr
(
args
,
"encoder_type"
,
"transformer"
)
if
encoder_type
==
"transformer"
:
from
.s2t_transformer
import
S2TTransformerEncoder
...
...
@@ -575,8 +576,7 @@ class S2TCTCEncoder(FairseqEncoder):
return
self
.
encoder
(
src_tokens
,
src_lengths
,
**
kwargs
)
def
reorder_encoder_out
(
self
,
encoder_out
,
new_order
):
self
.
encoder
.
reorder_encoder_out
(
encoder_out
,
new_order
)
return
return
self
.
encoder
.
reorder_encoder_out
(
encoder_out
,
new_order
)
class
CTCDecoder
(
object
):
...
...
fairseq/models/speech_to_text/s2t_dual.py
查看文件 @
8f084189
...
...
@@ -494,15 +494,15 @@ def base_architecture(args):
args
.
pds_ds_method
=
getattr
(
args
,
"pds_ds_method"
,
"conv"
)
args
.
pds_embed_dims
=
getattr
(
args
,
"pds_embed_dims"
,
None
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Tru
e
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Fals
e
)
args
.
pds_position_embed
=
getattr
(
args
,
"pds_position_embed"
,
None
)
args
.
pds_attn_heads
=
getattr
(
args
,
"pds_attn_heads"
,
None
)
args
.
pds_ffn_ratios
=
getattr
(
args
,
"pds_ffn_ratios"
,
None
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
"1_1_1_1"
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
"1_1_1_1"
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
"1_1_1_1"
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
None
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
None
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
None
)
args
.
ctc_layer
=
getattr
(
args
,
"ctc_layer"
,
0
)
args
.
pds_dropout
=
getattr
(
args
,
"pds_dropout"
,
args
.
dropout
)
...
...
fairseq/models/speech_to_text/s2t_sate.py
查看文件 @
8f084189
...
...
@@ -112,6 +112,22 @@ class S2TSATEModel(S2TTransformerModel):
type
=
str
,
help
=
"intermedia ctc layers for target sentence"
,
)
# freeze
parser
.
add_argument
(
"--freeze-acoustic-encoder"
,
action
=
"store_true"
,
help
=
"freeze the parameters of the acoustic encoder"
,
)
parser
.
add_argument
(
"--freeze-textual-encoder"
,
action
=
"store_true"
,
help
=
"freeze the parameters of the acoustic encoder"
,
)
parser
.
add_argument
(
"--freeze-decoder"
,
action
=
"store_true"
,
help
=
"freeze the parameters of the decoder"
,
)
pass
@classmethod
...
...
@@ -150,6 +166,18 @@ class S2TSATEModel(S2TTransformerModel):
return
encoder
def
forward
(
self
,
src_tokens
,
src_lengths
,
prev_output_tokens
):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out
=
self
.
encoder
(
src_tokens
=
src_tokens
,
src_lengths
=
src_lengths
)
decoder_out
=
self
.
decoder
(
prev_output_tokens
=
prev_output_tokens
,
encoder_out
=
encoder_out
)
return
decoder_out
class
TextEncoder
(
FairseqEncoder
):
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
=
None
):
...
...
@@ -224,7 +252,8 @@ class TextEncoder(FairseqEncoder):
elif
args
.
intermedia_adapter
==
"league"
:
strategy
=
getattr
(
args
,
"intermedia_distribution_cutoff"
,
None
)
self
.
adapter
=
Adapter
(
embed_dim
,
args
.
intermedia_adapter
,
len
(
dictionary
),
embed_tokens
=
embed_tokens
,
len
(
dictionary
),
# embed_tokens=embed_tokens,
strategy
=
strategy
)
self
.
intermedia_drop_prob
=
getattr
(
args
,
"intermedia_drop_prob"
,
0
)
self
.
intermedia_temperature
=
getattr
(
args
,
"intermedia_temperature"
,
1
)
...
...
@@ -294,7 +323,7 @@ class S2TSATEEncoder(FairseqEncoder):
# adapter
self
.
temperature
=
args
.
temperature
# self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
strategy
=
None
if
args
.
adapter
==
"shrink"
:
strategy
=
getattr
(
args
,
"ctc_compress_strategy"
,
"avg"
)
...
...
@@ -318,6 +347,9 @@ class S2TSATEEncoder(FairseqEncoder):
args
.
encoder_attention_type
=
acoustic_encoder_attention_type
self
.
freeze_acoustic_encoder
=
getattr
(
args
,
"freeze_acoustic_encoder"
,
False
)
self
.
freeze_textual_encoder
=
getattr
(
args
,
"freeze_textual_encoder"
,
False
)
if
getattr
(
args
,
"use_enc_dlcl"
,
False
):
layer_num
=
args
.
encoder_layers
+
args
.
text_encoder_layers
+
2
self
.
history
=
DynamicLinearCombination
(
args
,
is_encoder
=
True
,
layer_num
=
layer_num
)
...
...
@@ -328,6 +360,10 @@ class S2TSATEEncoder(FairseqEncoder):
if
self
.
history
is
not
None
:
self
.
history
.
clean
()
if
self
.
freeze_acoustic_encoder
:
with
torch
.
no_grad
():
acoustic_encoder_out
=
self
.
acoustic_encoder
(
src_tokens
,
src_lengths
)
else
:
acoustic_encoder_out
=
self
.
acoustic_encoder
(
src_tokens
,
src_lengths
)
encoder_out
=
acoustic_encoder_out
[
"encoder_out"
][
0
]
...
...
@@ -354,6 +390,10 @@ class S2TSATEEncoder(FairseqEncoder):
self
.
history
.
push
(
x
)
if
self
.
freeze_textual_encoder
:
with
torch
.
no_grad
():
x
,
target_ctc_logit
,
target_intermedia_ctc_logits
=
self
.
text_encoder
(
x
,
encoder_padding_mask
,
self
.
history
)
else
:
x
,
target_ctc_logit
,
target_intermedia_ctc_logits
=
self
.
text_encoder
(
x
,
encoder_padding_mask
,
self
.
history
)
return
{
...
...
@@ -482,15 +522,15 @@ def base_architecture(args):
args
.
pds_ds_method
=
getattr
(
args
,
"pds_ds_method"
,
"conv"
)
args
.
pds_embed_dims
=
getattr
(
args
,
"pds_embed_dims"
,
None
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Tru
e
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
Fals
e
)
args
.
pds_position_embed
=
getattr
(
args
,
"pds_position_embed"
,
None
)
args
.
pds_attn_heads
=
getattr
(
args
,
"pds_attn_heads"
,
None
)
args
.
pds_ffn_ratios
=
getattr
(
args
,
"pds_ffn_ratios"
,
None
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
"1_1_1_1"
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
"1_1_1_1"
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
"1_1_1_1"
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
None
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
None
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
None
)
args
.
ctc_layer
=
getattr
(
args
,
"ctc_layer"
,
0
)
args
.
pds_dropout
=
getattr
(
args
,
"pds_dropout"
,
args
.
dropout
)
...
...
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
8f084189
...
...
@@ -598,6 +598,7 @@ class S2TTransformerEncoder(FairseqEncoder):
strategy
=
getattr
(
args
,
"intermedia_distribution_cutoff"
,
None
)
self
.
adapter
=
Adapter
(
dim
,
args
.
intermedia_adapter
,
len
(
task
.
source_dictionary
),
strategy
=
strategy
)
# embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection)
self
.
intermedia_drop_prob
=
getattr
(
args
,
"intermedia_drop_prob"
,
0
)
self
.
intermedia_temperature
=
getattr
(
args
,
"intermedia_temperature"
,
1
)
...
...
@@ -700,6 +701,10 @@ class S2TTransformerEncoder(FairseqEncoder):
logit
=
self
.
ctc
(
norm_x
)
intermedia_ctc_logits
.
append
(
logit
)
logit
=
logit
.
clamp
(
min
=-
1e8
if
logit
.
dtype
==
torch
.
float32
else
-
1e4
,
max
=
1e8
if
logit
.
dtype
==
torch
.
float32
else
1e4
)
prob
=
utils
.
softmax
(
logit
/
self
.
intermedia_temperature
,
dim
=-
1
)
x
,
encoder_padding_mask
=
self
.
adapter
([
x
,
prob
],
encoder_padding_mask
)
...
...
fairseq/models/wav2vec/wav2vec2_asr.py
查看文件 @
8f084189
...
...
@@ -161,8 +161,8 @@ class Wav2VecCtc(BaseFairseqModel):
padding
=
net_output
[
"padding_mask"
]
if
padding
is
not
None
and
padding
.
any
():
padding
=
padding
.
T
logits
[
padding
][
...
,
0
]
=
0
logits
[
padding
][
...
,
1
:]
=
float
(
'-inf'
)
logits
[
padding
][
...
,
0
]
=
0
logits
[
padding
][
...
,
1
:]
=
float
(
'-inf'
)
return
logits
...
...
fairseq/modules/speech_to_text/adapter.py
查看文件 @
8f084189
...
...
@@ -130,7 +130,7 @@ class Adapter(nn.Module):
out
=
coef
*
linear_out
+
(
1
-
coef
)
*
soft_out
elif
self
.
adapter_type
==
"inter_league"
:
soft_out
=
torch
.
mm
(
distribution
,
self
.
embed_adapter
.
weight
.
t
()
)
.
view
(
seq_len
,
bsz
,
-
1
)
soft_out
=
torch
.
mm
(
distribution
,
self
.
embed_adapter
.
weight
)
.
view
(
seq_len
,
bsz
,
-
1
)
out
=
representation
+
soft_out
elif
self
.
adapter_type
==
"none"
:
...
...
@@ -153,7 +153,7 @@ class Adapter(nn.Module):
# x is T x B x C -> B x C x T; weights_matrix is B x T x T'
representation
=
representation
.
permute
(
1
,
2
,
0
)
compressed_output
=
representation
.
float
()
.
bmm
(
weights_matrix
)
.
type_as
(
representation
)
# B x C x T'
compressed_output
=
representation
.
bmm
(
weights_matrix
)
.
type_as
(
representation
)
# B x C x T'
out
=
compressed_output
.
permute
(
2
,
0
,
1
)
out_lengths
=
lengths
.
new
(
new_lengths
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论