xuchen / Fairseq-S2T / Commits / d946bc3b

Commit d946bc3b, authored May 27, 2022 by xuchen
Validated the results of embedding norm and no scale embedding for the speech-to-text encoder. Yeah, it is better.
Parent: 2de89089
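For context, a minimal runnable sketch of what the two validated options control in the encoder's embedding path (a simplification of the diffs below, with hypothetical tensor shapes; not the fairseq API itself):

import math
import torch
import torch.nn as nn

embed_dim = 256
embed_ln = nn.LayerNorm(embed_dim)

def embed_forward(x, encoder_embed_norm=True, encoder_no_scale_embedding=True):
    # encoder-embed-norm: LayerNorm right after the (sub-sampled) embeddings.
    if encoder_embed_norm:
        x = embed_ln(x)
    # encoder-no-scale-embedding: drop the usual sqrt(embed_dim) scaling.
    embed_scale = 1.0 if encoder_no_scale_embedding else math.sqrt(embed_dim)
    return embed_scale * x

x = torch.randn(8, 4, embed_dim)  # T x B x C, hypothetical speech features
y = embed_forward(x)              # normalized, unscaled embeddings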
Showing 7 changed files with 55 additions and 30 deletions.
egs/mustc/st/conf/base.yaml                        +3   -0
egs/mustc/st/conf/sate.yaml                        +3   -0
fairseq/criterions/ctc.py                          +1   -1
fairseq/models/speech_to_text/s2t_sate.py          +7   -5
fairseq/models/speech_to_text/s2t_transformer.py   +17  -11
fairseq/models/transformer_ctc.py                  +23  -13
fairseq_cli/train.py                               +1   -0
egs/mustc/st/conf/base.yaml

@@ -19,6 +19,9 @@ subsampling-stride: 2
 subsampling-norm: none
 subsampling-activation: glu
+encoder-embed-norm: True
+encoder-no-scale-embedding: True
 dropout: 0.1
 activation-fn: relu
 encoder-embed-dim: 256
egs/mustc/st/conf/sate.yaml

@@ -22,6 +22,9 @@ subsampling-stride: 2
 subsampling-norm: none
 subsampling-activation: glu
+encoder-embed-norm: True
+encoder-no-scale-embedding: True
 dropout: 0.1
 activation-fn: relu
 encoder-embed-dim: 256
fairseq/criterions/ctc.py

@@ -340,7 +340,7 @@ class CtcCriterion(FairseqCriterion):
         ctc_self_distill_loss = 0
         ctc_self_distill_num = 0
         if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
-                torch.rand() < self.ctc_self_distill_prob:
+                torch.rand(1).uniform_() < self.ctc_self_distill_prob:
             for i in range(interleaved_ctc_num):
                 out = net_output["interleaved_ctc_logits"][i]
                 if type(out) == list:
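This one-line fix matters because torch.rand() with no size argument raises a TypeError, so the stochastic self-distillation branch would crash rather than sample; torch.rand(1) draws a 1-element tensor from [0, 1), and the chained .uniform_() merely re-fills it in place with the same distribution (redundant but harmless). A sketch of the sampling pattern:

import torch

ctc_self_distill_prob = 0.5

# torch.rand() would raise: rand() needs a size argument.
draw = torch.rand(1).uniform_() < ctc_self_distill_prob  # tensor([True]) or tensor([False])
if draw:  # a 1-element bool tensor is valid in a boolean context
    print("apply CTC self-distillation on this batch")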
fairseq/models/speech_to_text/s2t_sate.py

@@ -258,12 +258,12 @@ class TextualEncoder(FairseqEncoder):
         self.embed_tokens = embed_tokens

         self.embed_scale = math.sqrt(embed_dim)
-        if args.no_scale_embedding:
+        if args.encoder_no_scale_embedding:
             self.embed_scale = 1.0
         self.padding_idx = dictionary.pad_index

-        self.embed_norm = getattr(args, "embed_norm", False)
-        if self.embed_norm:
+        self.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
+        if self.encoder_embed_norm:
             self.embed_ln = LayerNorm(embed_dim)

         self.dropout_module = FairseqDropout(

@@ -339,7 +339,7 @@ class TextualEncoder(FairseqEncoder):
     def forward(self, x, encoder_padding_mask=None, history=None):
-        if self.embed_norm:
+        if self.encoder_embed_norm:
             x = self.embed_ln(x)
         x = self.embed_scale * x
         positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)

@@ -599,9 +599,11 @@ def base_architecture(args):
     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.embed_linear = getattr(args, "embed_linear", False)
+    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)

     # CTC
     args.ctc_layer = getattr(args, "ctc_layer", 0)
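Note that the rename is not backward compatible: a config that still sets the old embed-norm key now falls through to the False default, since only encoder_embed_norm is read. If legacy configs mattered, a fallback lookup would be one possible shim (hypothetical, not part of this commit):

from argparse import Namespace

# Hypothetical shim (not in this commit): prefer the new option name,
# but honor the legacy one if an old config still sets it.
def resolve_embed_norm(args):
    return getattr(args, "encoder_embed_norm",
                   getattr(args, "embed_norm", False))

assert resolve_embed_norm(Namespace(embed_norm=True)) is True       # legacy config
assert resolve_embed_norm(Namespace(encoder_embed_norm=False)) is False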
fairseq/models/speech_to_text/s2t_transformer.py

@@ -236,6 +236,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
+        parser.add_argument(
+            "--encoder-no-scale-embedding",
+            action="store_true",
+            help="if True, dont scale embeddings in encoder",
+        )
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion'),

@@ -392,12 +397,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             help="Kernel size of convolution module.",
         )
         parser.add_argument(
-            "--embed-linear",
+            "--encoder-embed-linear",
             action="store_true",
             help="use linear transform after down-sampling",
         )
         parser.add_argument(
-            "--embed-norm",
+            "--encoder-embed-norm",
             action="store_true",
             help="use layer norm after down-sampling",
         )

@@ -590,16 +595,16 @@ class S2TTransformerEncoder(FairseqEncoder):
             p=args.dropout, module_name=self.__class__.__name__
         )
         self.embed_scale = math.sqrt(dim)
-        if args.no_scale_embedding:
+        if args.encoder_no_scale_embedding:
             self.embed_scale = 1.0
         self.padding_idx = 1

         self.subsample = subsampling(args)

-        self.embed_linear = getattr(args, "embed_linear", False)
-        self.embed_norm = getattr(args, "embed_norm", False)
-        if self.embed_linear:
+        self.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+        self.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
+        if self.encoder_embed_linear:
             self.linear = nn.Linear(dim, dim)
-        if self.embed_norm:
+        if self.encoder_embed_norm:
             self.embed_ln = LayerNorm(dim)

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")

@@ -814,7 +819,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         if encoder_padding_mask is not None:
             x = x * (1 - encoder_padding_mask.transpose(0, 1).unsqueeze(-1).type_as(x))

-        if self.embed_norm:
+        if self.encoder_embed_norm:
             x = self.embed_ln(x)
             self.show_debug(x, "x after embed norm")

@@ -835,7 +840,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             positions = None
             self.show_debug(x, "x after position embedding")

-        if self.embed_linear:
+        if self.encoder_embed_linear:
             x = self.linear(x)
             self.show_debug(x, "x after embed linear")

@@ -1061,10 +1066,11 @@ def base_architecture(args):
     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.embed_linear = getattr(args, "embed_linear", False)
-    args.embed_norm = getattr(args, "embed_norm", False)
+    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)

     # CTC
     args.ctc_layer = getattr(args, "ctc_layer", 0)
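Since all three encoder options are plain store_true flags, the YAML keys in the configs above and the command-line flags are interchangeable spellings of the same settings. A self-contained sketch of how argparse resolves the renamed flags (a standalone parser for illustration, not the full fairseq one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--encoder-no-scale-embedding", action="store_true",
                    help="if True, dont scale embeddings in encoder")
parser.add_argument("--encoder-embed-linear", action="store_true",
                    help="use linear transform after down-sampling")
parser.add_argument("--encoder-embed-norm", action="store_true",
                    help="use layer norm after down-sampling")

args = parser.parse_args(["--encoder-embed-norm", "--encoder-no-scale-embedding"])
assert args.encoder_embed_norm and args.encoder_no_scale_embedding
assert not args.encoder_embed_linear  # store_true flags default to False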
fairseq/models/transformer_ctc.py

@@ -713,18 +713,28 @@ class TransformerCTCEncoder(FairseqEncoder):
         if ratio <= 1:
             return x

         if len(x.size()) == 3:
             bsz, seq_len, dim = x.size()
             up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
         else:
             bsz, seq_len = x.size()
             up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
         up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)

-        output_length = int(seq_len * ratio * 2 / 3)
-        select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
-        select_matrix[:, 1::ratio] = 1
-        threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
-        select_matrix = (select_matrix > threshold)
-        assert all(select_matrix.sum(dim=-1).eq(output_length))
-        out_x = up_x[select_matrix, :].reshape(bsz, -1, dim).contiguous()
-        out_padding = up_padding[select_matrix].reshape(bsz, -1).contiguous()
+        # output_length = int(seq_len * ratio * 2/3)
+        # select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
+        # select_matrix[:, 1::ratio] = 1
+        # mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
+        # mask = mask.sort(dim=-1)[0]
+        #
+        # if len(x.size()) == 3:
+        #     out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
+        # else:
+        #     out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
+        # out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
+        out_x = up_x
+        out_padding = up_padding

         return out_x, out_padding

     def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
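With the random down-selection now commented out, the upsampling always returns exactly ratio * seq_len frames: every time step is repeated ratio times via unsqueeze/expand/reshape. A runnable check of the core op:

import torch

bsz, seq_len, dim, ratio = 2, 3, 4, 2
x = torch.randn(bsz, seq_len, dim)
padding = torch.zeros(bsz, seq_len, dtype=torch.bool)

up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)

assert up_x.shape == (bsz, seq_len * ratio, dim)
assert up_padding.shape == (bsz, seq_len * ratio)
assert torch.equal(up_x[:, 0], up_x[:, 1])  # frame 0 is duplicated at position 1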
@@ -773,17 +783,17 @@ class TransformerCTCEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

-        ctc_padding_mask = encoder_padding_mask
-        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
-            src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
-            ctc_padding_mask = encoder_padding_mask
-
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

         # account for padding while computing the representation
         if encoder_padding_mask is not None:
             x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

+        ctc_padding_mask = encoder_padding_mask
+        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
+            x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
+            ctc_padding_mask = encoder_padding_mask
+
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
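This hunk reorders the CTC upsampling: previously the raw src_tokens (B x T) were upsampled before embedding, now the embedded features x (B x T x C) are upsampled after the padding mask is applied. For a pointwise embedding the two orders are equivalent, but embedding now runs on T rather than ratio * T positions. A toy check of that equivalence (hypothetical shapes, not fairseq code):

import torch
import torch.nn as nn

B, T, C, ratio = 2, 4, 8, 2
tokens = torch.randint(1, 50, (B, T))
embed = nn.Embedding(50, C)

# Before: repeat token ids, then embed (ratio * T embedding lookups).
old = embed(tokens.unsqueeze(2).expand(-1, -1, ratio).reshape(B, -1))
# After: embed once, then repeat the feature vectors.
new = embed(tokens).unsqueeze(2).expand(-1, -1, ratio, -1).reshape(B, -1, C)

assert torch.equal(old, new)  # same values, fewer embedding lookups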
fairseq_cli/train.py

@@ -67,6 +67,7 @@ def main(cfg: FairseqConfig) -> None:
     # Print args
     logger.info(cfg)

+    if distributed_utils.is_master(cfg.distributed_training):
         with open(os.path.join(cfg.checkpoint.save_dir, "config.yaml"), 'w') as f:
             f.write("%s" % OmegaConf.to_yaml(cfg))
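The added guard ensures that only the master process dumps config.yaml; without it, every distributed worker would open and write the same file. A minimal sketch of the same pattern using torch.distributed directly (hypothetical save_dir; fairseq's distributed_utils wraps a check like this):

import os
import torch.distributed as dist

def is_master() -> bool:
    # Single-process runs count as master; otherwise only global rank 0 does.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0

save_dir = "checkpoints"  # hypothetical save directory
os.makedirs(save_dir, exist_ok=True)
if is_master():
    with open(os.path.join(save_dir, "config.yaml"), "w") as f:
        f.write("# training config would be dumped here\n")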