xuchen / Fairseq-S2T · Commits

Commit 2de89089, authored 2 years ago by xuchen
fix the bugs of sae for MT
Parent: 380d7794
Showing 8 changed files with 47 additions and 35 deletions.
egs/mustc/mt/conf/debug.yaml               +1  -1
egs/mustc/mt/decode.sh                     +2  -2
egs/mustc/st/conf/base.yaml                +0  -1
egs/mustc/st/conf/inter.yaml               +2  -0
fairseq/criterions/ctc.py                  +10 -4
fairseq/models/speech_to_text/s2t_sate.py  +2  -2
fairseq/models/transformer_ctc.py          +29 -24
fairseq/modules/speech_to_text/ctc.py      +1  -1
egs/mustc/mt/conf/debug.yaml

@@ -41,7 +41,7 @@ interleaved-ctc-weight: 0.3
 interleaved-ctc-layers: 6,9
 interleaved-ctc-temperature: 1.0
 interleaved-ctc-drop-prob: 0
-interleaved_ctc_upsampling_ratio: 2
+interleaved_ctc_upsampling_ratio: 3
 sae-adapter: league
 sae-drop-prob: 0.0
egs/mustc/mt/decode.sh

@@ -3,7 +3,7 @@
 gpu_num=1
 data_dir=
-test_subset=(test)
+test_subset=(valid test)
 exp_name=
 if [ "$#" -eq 1 ]; then
@@ -14,7 +14,7 @@ sacrebleu=1
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=80000
+max_tokens=20000
 dec_model=checkpoint_best.pt
 cmd="./run.sh
egs/mustc/st/conf/base.yaml
查看文件 @
2de89089
arch
:
s2t_transformer_s
share-decoder-input-output-embed
:
True
share-ctc-and-embed
:
True
optimizer
:
adam
clip-norm
:
10.0
lr-scheduler
:
inverse_sqrt
...
...
This diff is collapsed.
Click to expand it.
egs/mustc/st/conf/inter.yaml
查看文件 @
2de89089
ctc-weight
:
0.3
share-ctc-and-embed
:
True
interleaved-ctc-weight
:
0.2
interleaved-ctc-layers
:
6,9
interleaved-ctc-temperature
:
1.0
...
...
This diff is collapsed.
Click to expand it.
fairseq/criterions/ctc.py

@@ -11,6 +11,7 @@ from omegaconf import II
 from typing import Optional
 import numpy as np
 import logging
+import editdistance
 import torch
 import torch.nn.functional as F
@@ -65,6 +66,10 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight of the self distillation CTC loss"},
     )
+    ctc_self_distill_prob: float = field(
+        default=0.1,
+        metadata={"help": "probability to use distillation loss"},
+    )
     wer_kenlm_model: Optional[str] = field(
         default=None,
@@ -137,6 +142,7 @@ class CtcCriterion(FairseqCriterion):
         self.target_ctc_weight = cfg.target_ctc_weight
         self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
         self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
+        self.ctc_self_distill_prob = cfg.ctc_self_distill_prob
         self.ctc_entropy = cfg.ctc_entropy
         self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
         self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
@@ -333,7 +339,8 @@ class CtcCriterion(FairseqCriterion):
         # calculate the self distillation CTC loss
         ctc_self_distill_loss = 0
         ctc_self_distill_num = 0
-        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
+                torch.rand() < self.ctc_self_distill_prob:
             for i in range(interleaved_ctc_num):
                 out = net_output["interleaved_ctc_logits"][i]
                 if type(out) == list:
@@ -347,7 +354,8 @@ class CtcCriterion(FairseqCriterion):
                 loss = F.kl_div(
                     F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
-                    F.softmax(ctc_logit, dim=-1, dtype=torch.float32),
+                    F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
+                    log_target=True,
                     reduction="none",
                 )
                 loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
@@ -379,8 +387,6 @@ class CtcCriterion(FairseqCriterion):
             logger.warning("Target CTC loss %f!" % target_ctc_loss)

         if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
-            import editdistance
-
             with torch.no_grad():
                 lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
                 target = tokens
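Taken together, the criterion changes do two things: the self-distillation term is now applied stochastically, gated by the new ctc_self_distill_prob option, and the KL term compares student and teacher in log space with the teacher detached. A minimal, self-contained sketch of the combined behavior, with illustrative tensor names (T x B x V logits, a B x T non-padding mask; log_target requires PyTorch >= 1.6):

```python
import torch
import torch.nn.functional as F

def self_distill_loss(inter_ctc_logit, ctc_logit, non_padding_mask,
                      ctc_self_distill_prob=0.1):
    """Sketch of the gated self-distillation loss after this commit."""
    # Stochastic gate: most batches skip the distillation term entirely.
    if torch.rand(1).item() >= ctc_self_distill_prob:
        return None

    loss = F.kl_div(
        F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
        # The teacher (final CTC head) is detached so no gradient flows back
        # into it, and is passed as log-probabilities to match log_target=True.
        F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
        log_target=True,
        reduction="none",
    )
    # Sum over the vocabulary, move to B x T, and zero out padded positions.
    loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
    return loss.sum()
```

Relative to the old F.softmax target, the log-space target is the numerically safer formulation of the same KL divergence, and the detach() makes the direction of distillation explicit: the interleaved layers learn from the final CTC head, not the other way around.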
fairseq/models/speech_to_text/s2t_sate.py

@@ -399,9 +399,9 @@ class S2TSATEEncoder(FairseqEncoder):
         # acoustic encoder
         acoustic_encoder_type = args.acoustic_encoder
         if acoustic_encoder_type == "transformer":
-            self.acoustic_encoder = S2TTransformerEncoder(args, task)
+            self.acoustic_encoder = S2TTransformerEncoder(args, task, decoder_embed_tokens)
         elif acoustic_encoder_type == "pds":
-            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task)
+            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task, decoder_embed_tokens)
         else:
             logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
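Passing decoder_embed_tokens into the acoustic encoders lines up with the share-ctc-and-embed option in the configs above: the encoder's CTC output projection can then be tied to the decoder embedding table. A hedged sketch of that kind of tying; the class and argument names below are illustrative, not the repository's exact API:

```python
import torch.nn as nn

class TiedCTCProjection(nn.Module):
    """Hypothetical CTC head whose vocabulary projection shares weights
    with a decoder embedding table (the share-ctc-and-embed idea)."""

    def __init__(self, embed_dim, vocab_size, embed_tokens=None):
        super().__init__()
        self.proj = nn.Linear(embed_dim, vocab_size, bias=False)
        if embed_tokens is not None:
            # One parameter tensor serves both as the decoder input embedding
            # and as the CTC softmax projection.
            self.proj.weight = embed_tokens.weight

    def forward(self, x):
        return self.proj(x)

# Usage sketch:
#   embed = nn.Embedding(vocab_size, embed_dim)
#   head = TiedCTCProjection(embed_dim, vocab_size, embed_tokens=embed)
```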
fairseq/models/transformer_ctc.py

@@ -708,18 +708,29 @@ class TransformerCTCEncoder(FairseqEncoder):
             return_all_hiddens, token_embeddings)

-    def upsampling(self, x):
+    def upsampling(self, x, padding):
         ratio = self.interleaved_ctc_upsampling_ratio
         if ratio <= 1:
             return x

-        seq_len, bsz, dim = x.size()
-        x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
-        return x
-
-    def set_ctc_infer(self, ctc_infer, post_process):
+        bsz, seq_len, dim = x.size()
+        up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
+        up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
+
+        output_length = int(seq_len * ratio * 2 / 3)
+        select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
+        select_matrix[:, 1::ratio] = 1
+        threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
+        select_matrix = (select_matrix > threshold)
+        assert all(select_matrix.sum(dim=-1).eq(output_length))
+        out_x = up_x[select_matrix, :].reshape(bsz, -1, dim).contiguous()
+        out_padding = up_padding[select_matrix].reshape(bsz, -1).contiguous()
+
+        return out_x, out_padding
+
+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
         if hasattr(self, "ctc"):
-            self.ctc.set_infer(ctc_infer, post_process)
+            assert tgt_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, tgt_dict)

     # TorchScript doesn't support super() method so that the scriptable Subclass
     # can't access the base class model in Torchscript.
@@ -768,21 +779,19 @@ class TransformerCTCEncoder(FairseqEncoder):
         if encoder_padding_mask is not None:
             x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

+        ctc_padding_mask = encoder_padding_mask
+        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
+            x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
+            ctc_padding_mask = encoder_padding_mask

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
-        bsz = x.size(1)

         encoder_states = []

         if return_all_hiddens:
             encoder_states.append(x)

-        org_encoder_padding_mask = encoder_padding_mask
-        ctc_padding_mask = encoder_padding_mask
-        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
-            ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
-                expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)

         # add emb into history
         if self.history is not None:
             self.history.push(x)
@@ -795,10 +804,6 @@ class TransformerCTCEncoder(FairseqEncoder):
             if self.history is not None:
                 x = self.history.pop()

-            if layer_idx + 1 in self.interleaved_ctc_layers:
-                x = self.upsampling(x)
-                encoder_padding_mask = ctc_padding_mask
-
             x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None)
@@ -809,7 +814,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             # CTC
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(self.upsampling(x.clone()), ctc_padding_mask)
+                ctc_logit = self.ctc(x.clone(), ctc_padding_mask)

             # Interleaved CTC
             if layer_idx in self.interleaved_ctc_layers:
@@ -826,10 +831,10 @@ class TransformerCTCEncoder(FairseqEncoder):
                 x, _ = self.sae([norm_x, prob])
-                x = x.permute(1, 2, 0)
-                x = self.pool(x)
-                x = x.permute(2, 0, 1)
-                encoder_padding_mask = org_encoder_padding_mask
+                # x = x.permute(1, 2, 0)
+                # x = self.pool(x)
+                # x = x.permute(2, 0, 1)
+                # encoder_padding_mask = org_encoder_padding_mask

             if self.history is not None:
                 self.history.push(x)
@@ -841,7 +846,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(self.upsampling(x), ctc_padding_mask)
+            ctc_logit = self.ctc(x, ctc_padding_mask)

         # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
         # `forward` so we use a dictionary instead.
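The core of this file's change is the rewritten upsampling: it now runs once on the B x T x C encoder input, expands the padding mask alongside the hidden states, and randomly keeps about two thirds of the upsampled positions, while select_matrix[:, 1::ratio] = 1 pins one copy of every original frame so it always survives. A standalone toy version to sanity-check the shapes (mirrors the diff above; assumes the random draws produce no ties, as the original assert also does):

```python
import torch

bsz, seq_len, dim, ratio = 2, 6, 8, 3
x = torch.randn(bsz, seq_len, dim)                     # B x T x C
padding = torch.zeros(bsz, seq_len, dtype=torch.bool)  # B x T, True = pad

# Repeat every position `ratio` times, for states and mask alike.
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)

# Keep ~2/3 of the upsampled positions per sentence.
output_length = int(seq_len * ratio * 2 / 3)
select_matrix = torch.rand(bsz, ratio * seq_len)
select_matrix[:, 1::ratio] = 1  # pin one copy of each original frame
threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
select_matrix = select_matrix > threshold

out_x = up_x[select_matrix, :].reshape(bsz, -1, dim)
out_padding = up_padding[select_matrix].reshape(bsz, -1)
print(out_x.shape, out_padding.shape)
# torch.Size([2, 12, 8]) torch.Size([2, 12])  -- 6 frames * 3 * 2/3 = 12
```

Because the pinned copies are set to exactly 1 and the threshold is taken strictly below them, every row keeps exactly output_length positions, which is what lets the selected entries reshape back to a rectangular B x T' batch.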
fairseq/modules/speech_to_text/ctc.py

@@ -78,7 +78,7 @@ class CTC(nn.Module):
             pred_units = self.dictionary.string(pred_units_arr)
             pred_words_raw = post_process(pred_units, self.post_process).split()
-            print(pred_words_raw)
+            logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))

     def valid(self, logits_or_probs, target, lengths):
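The last change swaps a stray debug print for the module logger. A minimal sketch of the pattern, standard library only:

```python
import logging

logger = logging.getLogger(__name__)

def report_prediction(pred_words_raw):
    # Route CTC predictions through logging instead of stdout, so they follow
    # the configured fairseq log level and formatting.
    logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))
```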