xuchen / Fairseq-S2T / Commits

Commit cabfc4ea ("Daily revision"), authored Aug 22, 2022 by xuchen
Parent: 0a70c5c5

Showing 4 changed files with 49 additions and 18 deletions (+49, -18):
  egs/mustc/st/local/cal_wer.sh                         +2   -2
  fairseq/criterions/ctc.py                             +29  -15
  fairseq/models/speech_to_text/pdss2t_transformer.py   +15  -1
  fairseq/models/speech_to_text/s2t_ctc.py              +3   -0
egs/mustc/st/local/cal_wer.sh

@@ -13,4 +13,4 @@ ctc_infer_sort=${infer_dir}/${tag}_ctc_infer_sort
 cut -f1 ${s2s_infer_file} > ${idx}
 paste ${idx} ${org_ctc_infer_file} > ${ctc_infer}
 sort -n -t $'\t' ${ctc_infer} | cut -f2 > ${ctc_infer_sort}
-python3 ./cal_wer_lcrm.py ${ref} ${ctc_infer_sort}
+python3 ./cal_wer.py ${ref} ${ctc_infer_sort}
\ No newline at end of file
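This hunk swaps the scoring script: the sorted CTC transcripts are now scored with cal_wer.py instead of cal_wer_lcrm.py (judging by the name, the lcrm variant lowercases and removes punctuation before scoring; neither script's contents appear in this commit). The surrounding pipeline pulls the sentence indices out of the beam-search output (cut -f1), attaches them to the CTC transcripts (paste), and sorts numerically on the index to restore corpus order before scoring. As a reference point only, here is a minimal hypothetical stand-in for such a WER script, assuming one whitespace-tokenized sentence per line in both files:

# cal_wer_sketch.py -- hypothetical stand-in for cal_wer.py (not from this commit).
# Usage: python3 cal_wer_sketch.py <ref_file> <hyp_file>
import sys

def edit_distance(ref, hyp):
    # Levenshtein distance over word lists (substitutions, insertions, deletions).
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution / match
        prev = cur
    return prev[-1]

if __name__ == "__main__":
    errors = words = 0
    with open(sys.argv[1]) as f_ref, open(sys.argv[2]) as f_hyp:
        for ref_line, hyp_line in zip(f_ref, f_hyp):
            ref = ref_line.split()
            errors += edit_distance(ref, hyp_line.split())
            words += len(ref)
    print(f"WER: {100.0 * errors / max(words, 1):.2f}%")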
fairseq/criterions/ctc.py

@@ -300,9 +300,9 @@ class CtcCriterion(FairseqCriterion):
         return loss, lprobs

     @staticmethod
-    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask):
-        ctc_self_distill_loss = 0
-        ctc_self_distill_num = 0
+    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask, temperature=1.0):
+        ctc_self_distill_losses = []
         for i in range(distill_num):
             logit = student_logits[i]
             if type(logit) == list:
@@ -315,15 +315,15 @@ class CtcCriterion(FairseqCriterion):
                 continue

             loss = F.kl_div(
-                F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
-                F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
+                F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
+                # F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
+                # F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
+                F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
                 log_target=True,
                 reduction="none",
             )
-            ctc_self_distill_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
-            ctc_self_distill_num += 1
-        return ctc_self_distill_num, ctc_self_distill_loss
+            loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
+            ctc_self_distill_losses.append(loss)
+        return ctc_self_distill_losses

     def get_target_text(self, sample):
         if self.aligned_target_ctc and "aligned_target" in sample:
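Summary of the change to get_ctc_self_distill_loss: each intermediate (student) CTC layer's distribution is pulled toward the top (teacher) layer's via KL divergence, now computed at a configurable temperature and against detached teacher logits, so gradients no longer flow into the teacher branch; the per-layer losses are collected in a list instead of a running sum. The sketch below reproduces the computation in isolation. The shapes (logits T x B x V, non-padding mask B x T) are an assumption inferred from the sum(-1).transpose(0, 1) pattern, not something the commit states.

# Minimal sketch of temperature-scaled CTC self-distillation (assumed shapes:
# logits are T x B x V, non_padding_mask is B x T with True on real frames).
import torch
import torch.nn.functional as F

def self_distill_losses(teacher_logit, student_logits, non_padding_mask, temperature=1.0):
    losses = []
    for student_logit in student_logits:
        loss = F.kl_div(
            F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
            # detach() blocks gradients into the teacher (top) layer
            F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
            log_target=True,
            reduction="none",
        )
        # sum over vocab -> T x B, transpose -> B x T, zero padded frames, sum to scalar
        losses.append(loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum())
    return losses

T, B, V = 50, 4, 100
teacher = torch.randn(T, B, V, requires_grad=True)
students = [torch.randn(T, B, V, requires_grad=True) for _ in range(2)]
mask = torch.ones(B, T, dtype=torch.bool)
losses = self_distill_losses(teacher, students, mask, temperature=2.0)
total = sum(losses) / len(losses)  # the call sites below average the per-layer losses

Note that both log-softmax inputs are divided by the same temperature, flattening both distributions; unlike classic distillation, no T-squared rescaling of the loss is applied here.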
@@ -507,10 +507,17 @@ class CtcCriterion(FairseqCriterion):
             ctc_self_distill_num = interleaved_ctc_num - 1
             if ctc_self_distill_num != 0:
-                ctc_self_distill_num, source_ctc_self_distill_loss = \
+                source_ctc_self_distill_losses = \
                     self.get_ctc_self_distill_loss(
-                        ctc_self_distill_num, teacher_logit, student_logits, non_padding)
-                source_ctc_self_distill_loss /= ctc_self_distill_num
+                        ctc_self_distill_num, teacher_logit, student_logits,
+                        non_padding, self.ctc_self_distill_temperature)
+                ctc_self_distill_num = len(source_ctc_self_distill_losses)
+                source_ctc_self_distill_loss = sum(source_ctc_self_distill_losses) / ctc_self_distill_num

                 logging_output["ctc_self_distill_loss"] = utils.item(source_ctc_self_distill_loss.data)
                 ctc_self_distill_loss += source_ctc_self_distill_loss * self.ctc_self_distill_weight

@@ -529,11 +536,18 @@ class CtcCriterion(FairseqCriterion):
             ctc_self_distill_num = target_interleaved_ctc_num - 1
             if ctc_self_distill_num != 0:
-                ctc_self_distill_num, target_ctc_self_distill_loss = \
+                target_ctc_self_distill_losses = \
                     self.get_ctc_self_distill_loss(
-                        ctc_self_distill_num, teacher_logit, student_logits, non_padding)
-                target_ctc_self_distill_loss /= ctc_self_distill_num
+                        ctc_self_distill_num, teacher_logit, student_logits,
+                        non_padding, self.ctc_self_distill_temperature)
+                ctc_self_distill_num = len(target_ctc_self_distill_losses)
+                target_ctc_self_distill_loss = sum(target_ctc_self_distill_losses) / ctc_self_distill_num

                 logging_output["target_ctc_self_distill_loss"] = utils.item(target_ctc_self_distill_loss.data)
                 ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
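At both call sites (source and target CTC branches) the return-value handling changes to match: the number of distillation terms is now derived with len() from the returned list and the mean is taken explicitly, which is behaviorally equivalent to the old (count, running sum) pair but keeps the bookkeeping in one place. Both calls also pass self.ctc_self_distill_temperature; the criterion option that populates that attribute is presumably defined elsewhere and is not part of this diff.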
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -605,6 +605,20 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         }

         return x, encoder_padding_mask, input_lengths, mixup

+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
+        if hasattr(self, "ctc"):
+            assert src_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, src_dict,
+                               path=path + ".ctc" if path is not None else None)
+
+    def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
+        if hasattr(self, "ctc"):
+            return self.ctc.valid(lprobs, targets, input_lengths, dictionary)
+
+        logger.error("No ctc module in textual encoder")
+
     def forward(self, src_tokens, src_lengths):
         batch = src_tokens.size(0)

@@ -748,7 +762,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)
+            ctc_logit = self.ctc(x, encoder_padding_mask, is_top=True)

         return {
             "encoder_out": [x],  # T x B x C
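The encoder gains two thin wrappers that delegate to its internal ctc module when one exists, and the top-layer CTC call now also receives the encoder padding mask plus an is_top flag. The hasattr guard makes both wrappers no-ops (with an error log in the validation path) on encoder variants built without a CTC head. The CTC module itself is defined elsewhere in the repository; a hypothetical minimal interface consistent with these call sites might look like the following (names, shapes, and bodies are assumptions, not taken from this commit):

# Hypothetical CTC head consistent with the call sites above (not from this commit).
import torch.nn as nn

class CTC(nn.Module):
    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, padding_mask=None, is_top=False):
        # x: T x B x C encoder states; padding_mask/is_top mirror the new
        # self.ctc(x, encoder_padding_mask, is_top=True) call in forward().
        return self.proj(x)

    def set_infer(self, ctc_infer, post_process, dictionary, path=None):
        # Configure decoding of CTC hypotheses at inference time; `path` is where
        # transcripts are dumped (the encoder appends ".ctc" to its output path).
        ...

    def valid(self, lprobs, targets, input_lengths, dictionary):
        # Compute validation metrics (e.g. WER) from CTC log-probabilities.
        ...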
fairseq/models/speech_to_text/s2t_ctc.py

@@ -314,6 +314,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
     args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
     args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
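base_architecture fills in defaults for any option the user did not set, using the getattr(args, name, default) idiom: a value already present on the argparse namespace is kept, otherwise the default is written back. The diff adds a default for encoder_no_scale_embedding; note that args.no_scale_embedding is now assigned twice in this function, which is redundant but harmless since both assignments use the same default. A minimal illustration of the idiom:

from argparse import Namespace

def base_architecture(args):
    # Keep a user-supplied value, otherwise fall back to the default.
    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)

args = Namespace(encoder_no_scale_embedding=True)
base_architecture(args)
print(args.encoder_no_scale_embedding)  # True -- user value preserved

args = Namespace()
base_architecture(args)
print(args.encoder_no_scale_embedding)  # False -- default filled in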