xuchen / Fairseq-S2T · Commits

Commit 9fe8cd1e, authored Jul 19, 2022 by xuchen
Parent: a201a883

fix some bugs
Showing 8 changed files with 111 additions and 15 deletions:
- fairseq/criterions/ctc.py (+12, -4)
- fairseq/dataclass/configs.py (+10, -1)
- fairseq/models/speech_to_text/__init__.py (+1, -1)
- fairseq/models/speech_to_text/s2t_ctc.py (+82, -6)
- fairseq/models/speech_to_text/s2t_sate.py (+0, -2)
- fairseq/modules/speech_to_text/ctc.py (+3, -0)
- fairseq/scoring/wer.py (+2, -0)
- fairseq_cli/generate.py (+1, -1)
fairseq/criterions/ctc.py

```diff
@@ -44,6 +44,10 @@ class CtcCriterionConfig(FairseqDataclass):
             "See fairseq.data.data_utils.post_process() for full list of options"
         },
     )
     ctc_weight: float = field(
         default=0.0,
         metadata={"help": "weight of CTC loss"},
     )
+    ctc_entropy: float = field(
+        default=0.0,
+        metadata={"help": "weight of CTC entropy"},
+    )

@@ -312,7 +316,8 @@ class CtcCriterion(FairseqCriterion):
         loss = F.kl_div(
             F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
-            F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
+            F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
+            # F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
             log_target=True,
             reduction="none",
         )

@@ -491,7 +496,8 @@ class CtcCriterion(FairseqCriterion):
         ctc_self_distill_num = 0
         non_padding = non_padding_mask

-        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        # if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        if self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
             teacher_logit = ctc_logit
             student_logits = net_output["interleaved_ctc_logits"]
             ctc_self_distill_num = interleaved_ctc_num

@@ -550,15 +556,17 @@ class CtcCriterion(FairseqCriterion):
         if self.target_ctc_weight != 0:
             logger.warning("Target CTC loss %f!" % target_ctc_loss)

+        # CER is not completely accurate and is for reference only.
         if not model.training:
-            if hasattr(model.encoder, "ctc_valid"):
+            encoder = model.encoder.encoder if hasattr(model.encoder, "encoder") else model.encoder
+            if hasattr(encoder, "ctc_valid"):
                 if lprobs is not None:
                     lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
                 if mixup:
                     idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
                     tokens = tokens[idx]

-                c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
+                c_err, c_len, w_errs, w_len, wv_errs = encoder.ctc_valid(
                     lprobs_t, tokens, input_lengths,
                     self.task.source_dictionary, lang="source")
                 logging_output["wv_errors"] = wv_errs
```
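The `kl_div` change above stops detaching the teacher logits, so the self-distillation term now back-propagates into the top CTC layer as well as the intermediate ones. A minimal sketch of the resulting loss term, assuming fairseq's usual `(T, B, V)` logit layout (the function name is illustrative):

```python
# Sketch only: inter-layer CTC self-distillation as in the hunk above.
# Without .detach(), gradients also flow into teacher_logit (top layer).
import torch
import torch.nn.functional as F

def ctc_self_distill_loss(student_logit: torch.Tensor,
                          teacher_logit: torch.Tensor) -> torch.Tensor:
    # With log_target=True, kl_div expects log-probabilities for both
    # the input and the target distributions.
    return F.kl_div(
        F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
        F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
        log_target=True,
        reduction="none",
    )
```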
fairseq/dataclass/configs.py

```diff
@@ -854,9 +854,18 @@ class GenerationConfig(FairseqDataclass):
         default=False,
         metadata={"help": "if set, dont use seed for initializing random generators"},
     )
+    # CTC inference
     ctc_infer: bool = field(
         default=False,
-        metadata={"help": "generate CTC decoding results during inference"}
+        metadata={"help": "generate CTC results during inference"}
     )
+    ctc_self_ensemble: bool = field(
+        default=False,
+        metadata={"help": "ensemble the top representation and intermediate representations for decoding"}
+    )
+    ctc_inter_logit: int = field(
+        default=0,
+        metadata={"help": "use the specific logit (from top to bottom, 0 is the top layer) for inference"}
+    )
```
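Per the help text, `ctc_inter_logit` counts layers from the top, with 0 meaning the final CTC logit. In `CTCDecoder.generate()` below this becomes negative indexing into the interleaved-logit list, which is presumably appended in bottom-up layer order. A toy illustration (the list contents are made up; real entries are `(T, B, V)` tensors):

```python
# Hypothetical contents to show the "from top to bottom" indexing.
inter_logits = ["layer4_logit", "layer8_logit", "layer12_logit"]  # bottom -> top
ctc_inter_logit = 1                       # 1 = topmost intermediate logit
assert ctc_inter_logit <= len(inter_logits)
print(inter_logits[-ctc_inter_logit])     # -> "layer12_logit"
```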
fairseq/models/speech_to_text/__init__.py

```diff
@@ -5,8 +5,8 @@
 from .berard import *  # noqa
 from .convtransformer import *  # noqa
-from .s2t_ctc import *
 from .s2t_transformer import *  # noqa
 from .pdss2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
 from .s2t_dual import *  # noqa
+from .s2t_ctc import *
```
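Moving `from .s2t_ctc import *` to the end is the point of this hunk: as the `s2t_ctc.py` change below shows, that module now imports `S2TSATEModel` and `S2TSATEEncoder` from the `fairseq.models.speech_to_text` package itself, and those names only exist after `.s2t_sate` has been executed; importing `s2t_ctc` first would raise an `ImportError`.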
fairseq/models/speech_to_text/s2t_ctc.py

```diff
@@ -11,8 +11,17 @@ from fairseq.models import (
     register_model_architecture,
 )

-from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
-from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
+# from .s2t_sate import S2TSATEModel, S2TSATEEncoder
+# from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
+# from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
+from fairseq.models.speech_to_text import (
+    S2TTransformerModel,
+    S2TTransformerEncoder,
+    PDSS2TTransformerModel,
+    PDSS2TTransformerEncoder,
+    S2TSATEModel,
+    S2TSATEEncoder,
+)
 from torch import Tensor

@@ -30,6 +39,7 @@ class S2TCTCModel(FairseqEncoderModel):
         """Add model-specific arguments to the parser."""
         S2TTransformerModel.add_args(parser)
         PDSS2TTransformerModel.add_specific_args(parser)
+        S2TSATEModel.add_specific_args(parser)

         # encoder
         parser.add_argument(

@@ -49,7 +59,7 @@ class S2TCTCModel(FairseqEncoderModel):
                 f"{args.load_pretrained_encoder_from}"
             )
             encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+                component=encoder.encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
             )
         return encoder

@@ -108,11 +118,16 @@ class S2TCTCEncoder(FairseqEncoder):
             self.encoder = S2TTransformerEncoder(args, task)
         elif encoder_type == "pds":
             self.encoder = PDSS2TTransformerEncoder(args, task)
+        elif encoder_type == "sate":
+            self.encoder = S2TSATEEncoder(args, task)
         else:
             logger.error("Unsupported architecture: %s." % encoder_type)

         return

+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
+        self.encoder.set_ctc_infer(ctc_infer, post_process, src_dict=src_dict, tgt_dict=tgt_dict, path=path)
+
     def forward(self, src_tokens, src_lengths, **kwargs):
         return self.encoder(src_tokens, src_lengths, **kwargs)

@@ -132,6 +147,16 @@ class CTCDecoder(object):
         self.unk = dictionary.unk()
         self.eos = dictionary.eos()

+        self.ctc_self_ensemble = getattr(args, "ctc_self_ensemble", False)
+        self.ctc_inter_logit = getattr(args, "ctc_inter_logit", 0)
+        assert not (self.ctc_self_ensemble is True and self.ctc_inter_logit is True), \
+            "Self ensemble and inference by intermediate logit can not be True at the same time."
+
+        if self.ctc_self_ensemble:
+            logger.info("Using self ensemble for CTC inference")
+        if self.ctc_inter_logit != 0:
+            logger.info("Using intermediate logit %d for CTC inference" % self.ctc_inter_logit)
+
         self.vocab_size = len(dictionary)
         self.beam_size = args.beam
         # the max beam size is the dictionary size - 1, since we never select pad
```
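One caveat with the new assertion: `ctc_inter_logit` is an `int`, so `self.ctc_inter_logit is True` is always `False` and the assert can never fire; `self.ctc_inter_logit != 0` would express the intended mutual exclusion with `ctc_self_ensemble`.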
```diff
@@ -150,7 +175,11 @@ class CTCDecoder(object):
         if self.lm_model is not None:
             self.lm_model.eval()

+        self.infer = "greedy"
+        if self.beam_size > 1:
+            try:
+                from ctcdecode import CTCBeamDecoder
+                self.infer = "beam"
                 self.ctc_decoder = CTCBeamDecoder(
                     dictionary.symbols,
                     model_path=self.lm_model,

@@ -163,6 +192,8 @@ class CTCDecoder(object):
                     blank_id=self.blank,
                     log_probs_input=False)
+            except ImportError:
+                logger.warning("Cannot import the CTCBeamDecoder library. We use the greedy search for CTC decoding.")

     def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
```
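`ctcdecode` (the package providing `CTCBeamDecoder`) is treated as an optional dependency here: if the import fails, the new `ImportError` branch logs a warning and decoding falls back to greedy search, even when `--beam` is greater than 1.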
```diff
@@ -173,24 +204,34 @@ class CTCDecoder(object):
         src_tokens = net_input["src_tokens"]
         src_lengths = net_input["src_lengths"]
         bsz, src_len = src_tokens.size()[:2]

         beam_size = self.beam_size

         encoder_outs = self.model(src_tokens=src_tokens, src_lengths=src_lengths)
         ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
+        inter_logits = encoder_outs.get("interleaved_ctc_logits", [])
+        inter_logits_num = len(inter_logits)
+
+        if self.ctc_inter_logit != 0:
+            if inter_logits_num != 0:
+                assert self.ctc_inter_logit <= inter_logits_num
+                ctc_logit = inter_logits[-self.ctc_inter_logit].transpose(0, 1)
+
         logit_length = (~encoder_outs["encoder_padding_mask"][0]).long().sum(-1)

+        finalized = []
+        if self.infer == "beam":
             beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode(
                 utils.softmax(ctc_logit, -1), logit_length
             )

-            finalized = []
             for idx in range(bsz):
                 hypos = []
-                for beam_idx in range(beam_size):
+                # for beam_idx in range(beam_size):
+                for beam_idx in range(1):
                     hypo = dict()
                     length = out_lens[idx][beam_idx]
                     scores = beam_scores[idx, beam_idx]
                     hypo["tokens"] = beam_results[idx, beam_idx, :length]
                     hypo["score"] = scores
                     hypo["attention"] = None

@@ -198,6 +239,41 @@ class CTCDecoder(object):
                     hypo["positional_scores"] = torch.Tensor([scores / length] * length)
                     hypos.append(hypo)
                 finalized.append(hypos)
+
+        # elif self.infer == "greedy":
+        else:
+            ctc_probs = utils.log_softmax(ctc_logit, -1)
+            if self.ctc_self_ensemble:
+                if inter_logits_num != 0:
+                    for i in range(inter_logits_num):
+                        inter_logits_prob = utils.log_softmax(inter_logits[i].transpose(0, 1), -1)
+                        ctc_probs += inter_logits_prob
+
+            topk_prob, topk_index = ctc_probs.topk(1, dim=2)
+
+            topk_prob = topk_prob.squeeze(-1)
+            topk_index = topk_index.squeeze(-1)
+
+            real_indexs = topk_index.masked_fill(encoder_outs["encoder_padding_mask"][0], self.blank).cpu()
+            real_probs = topk_prob.masked_fill(topk_index == self.blank, self.blank)
+            scores = -real_probs.sum(-1, keepdim=True).cpu()
+
+            for idx in range(bsz):
+                hypos = []
+                hypo = dict()
+
+                hyp = real_indexs[idx].unique_consecutive()
+                hyp = hyp[hyp != self.blank]
+                length = len(hyp)
+
+                hypo["tokens"] = hyp
+                hypo["score"] = scores[idx]
+                hypo["attention"] = None
+                hypo["alignment"] = None
+                hypo["positional_scores"] = torch.Tensor([hypo["score"] / length] * length)
+
+                hypos.append(hypo)
+                finalized.append(hypos)

         return finalized
```
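The new greedy branch is the standard CTC collapse: take the argmax label per frame, force padded frames to blank, merge consecutive duplicates, and drop blanks. A self-contained sketch of that post-processing, assuming `(B, T, V)` log-probabilities (the names are illustrative):

```python
# Sketch of the greedy CTC collapse performed by the new branch.
import torch

def greedy_ctc_collapse(log_probs: torch.Tensor,
                        padding_mask: torch.Tensor,
                        blank: int = 0) -> list:
    # log_probs: (B, T, V); padding_mask: (B, T), True at padded frames.
    best = log_probs.argmax(dim=-1)               # best label per frame
    best = best.masked_fill(padding_mask, blank)  # padded frames -> blank
    hyps = []
    for row in best:
        hyp = row.unique_consecutive()            # merge repeated labels
        hyps.append(hyp[hyp != blank])            # remove blanks
    return hyps

# Frames "a a - b b - -" (with '-' = blank 0) collapse to "a b".
frames = torch.tensor([[1, 1, 0, 2, 2, 0, 0]])
probs = torch.nn.functional.one_hot(frames, num_classes=3).float().log()
mask = torch.zeros_like(frames, dtype=torch.bool)
print(greedy_ctc_collapse(probs, mask))           # [tensor([1, 2])]
```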
fairseq/models/speech_to_text/s2t_sate.py

```diff
 import logging
 import math
 import os

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from fairseq import checkpoint_utils, utils
 from fairseq.models.transformer import Embedding
 from fairseq.models import (
```
fairseq/modules/speech_to_text/ctc.py

```diff
@@ -39,6 +39,9 @@ class CTC(nn.Module):
         self.post_process = "sentencepiece"
         self.blank_idx = 0

+        self.path = None
+        self.save_stream = None
+
     def set_infer(self, is_infer, text_post_process, dictionary, path):
         self.infer_decoding = is_infer
         self.post_process = text_post_process
```
fairseq/scoring/wer.py

```diff
@@ -46,6 +46,8 @@ class WerScorer(BaseScorer):
         self.ref_length = 0

     def add_string(self, ref, pred):
+        ref = ref.replace("<<unk>>", "@")
+        pred = pred.replace("<<unk>>", "@")
         ref_items = self.tokenizer.tokenize(ref).split()
         pred_items = self.tokenizer.tokenize(pred).split()
         self.distance += self.ed.eval(ref_items, pred_items)
```
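The `<<unk>>` → `@` substitution is presumably there so the WER tokenizer treats the unknown-token placeholder as a single symbol rather than splitting `<<unk>>` into several tokens, which would distort the edit-distance count; mapping both reference and hypothesis to the same single character keeps scoring consistent.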
fairseq_cli/generate.py

```diff
@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
     for model in models:
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
             model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                        src_dict, tgt_dict, translation_path)  # os.path.dirname(translation_path))
+                                        src_dict, tgt_dict, translation_path)

     # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
     task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
```
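Taken together, these changes wire CTC inference end to end: the `ctc_infer`, `ctc_self_ensemble`, and `ctc_inter_logit` dataclass fields surface as generation options (fairseq derives CLI flags such as `--ctc-infer` from the field names), `generate.py` forwards them to any encoder that exposes `set_ctc_infer`, `S2TCTCEncoder` relays the call to its inner encoder, and the `CTC` module records the output path for the decoding results it writes during inference.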