xuchen / S2T

Commit cb2f2bcb, authored Nov 28, 2023 by xuchen
Commit message: 2023.11
Parent: 51395037

Showing 15 changed files with 433 additions and 58 deletions:

egs/librispeech/asr/run.sh                                    +7    -5
examples/speech_to_text/prep_audio_data.py                    +104  -4
fairseq/criterions/ctc.py                                     +125  -31
fairseq/criterions/label_smoothed_cross_entropy.py            +4    -4
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py   +3    -3
fairseq/data/audio/speech_to_text_dataset.py                  +0    -0
fairseq/dataclass/configs.py                                  +14   -0
fairseq/models/speech_to_text/pdss2t_transformer.py           +2    -1
fairseq/models/speech_to_text/s2t_ctc.py                      +5    -4
fairseq/models/speech_to_text/s2t_transformer.py              +114  -2
fairseq/modules/speech_to_text/ctc.py                         +13   -0
fairseq/tasks/fairseq_task.py                                 +6    -0
fairseq/trainer.py                                            +11   -2
fairseq/utils.py                                              +16   -0
fairseq_cli/generate.py                                       +9    -2
egs/librispeech/asr/run.sh

@@ -83,10 +83,12 @@ epoch_ensemble=0
 best_ensemble=1
 infer_debug=0
 infer_score=0
-infer_tag=
-infer_parameter=
+infer_tag=ee6
+infer_parameters="--early-exit-count 6"
+#infer_parameters="--early-exit-layer 12"
+#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"

 data_config=config.yaml

@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         cmd="${cmd} --score-reference"
     fi
-    if [[ -n ${infer_parameter} ]]; then
-        cmd="${cmd} ${infer_parameter}"
+    if [[ -n ${infer_parameters} ]]; then
+        cmd="${cmd} ${infer_parameters}"
     fi
     echo -e "\033[34mRun command: \n${cmd}\033[0m"
examples/speech_to_text/prep_audio_data.py

@@ -37,7 +37,7 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)

-MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]


 class AudioDataset(Dataset):

@@ -398,6 +398,7 @@ def process(args):
             if args.add_src and src_utt is not None:
                 manifest["src_text"].append(src_utt)
             manifest["tgt_text"].append(tgt_utt)
+            manifest["tgt_lang"].append(tgt_lang)

     if is_train_split:
         if args.task == "st" and args.add_src and args.share:

@@ -454,8 +455,8 @@ def process(args):
             # if task == "st" and args.add_src and args.share:
             if args.add_src and args.share:
                 for e in reader:
                     if "src_text" in dict(e):
                         src_utt = dict(e)["src_text"]
                         tgt_utt = dict(e)["tgt_text"]
                         if args.lowercase_src:
                             src_utt = src_utt.lower()
                         if args.rm_punc_src:

@@ -463,6 +464,8 @@ def process(args):
                                 src_utt = src_utt.replace(w, "")
                             src_utt = " ".join(src_utt.split(" "))
                         train_text.append(src_utt)
+                    tgt_utt = dict(e)["tgt_text"]
+                    train_text.append(tgt_utt)
             else:
                 tgt_text = [(dict(e))["tgt_text"] for e in reader]

@@ -471,11 +474,16 @@ def process(args):
         with NamedTemporaryFile(mode="w") as f:
             for t in train_text:
                 f.write(t + "\n")
+            special_symbols = None
+            if args.add_syms:
+                special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
             gen_vocab(
                 Path(f.name),
                 output_root / spm_filename_prefix,
                 args.vocab_type,
                 args.vocab_size,
+                special_symbols=special_symbols
             )

     # Generate config YAML

@@ -491,9 +499,94 @@ def process(args):
         cmvn_type=args.cmvn_type,
         gcmvn_path=(
             output_root / "gcmvn.npz" if args.cmvn_type == "global" else None
         ),
         asr_spm_filename=asr_spm_filename,
-        share_src_and_tgt=True if task == "asr" else False,
+        share_src_and_tgt=True if task == "asr" and not args.add_src else False,
+        prepend_tgt_lang_tag=(args.add_syms),
     )


+def process_joint(args):
+    cur_root = Path(args.data_root).absolute()
+    task = args.task
+    languages = args.languages.split(",")
+    assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
+        "do not have downloaded data available for all languages"
+
+    if args.output_root is None:
+        output_root = cur_root
+    else:
+        output_root = Path(args.output_root).absolute()
+
+    # Generate vocab
+    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    asr_spm_filename = None
+
+    if args.add_src:
+        if args.share:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
+            asr_spm_filename = spm_filename_prefix + ".model"
+        else:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            assert args.asr_prefix is not None
+            asr_spm_filename = args.asr_prefix + ".model"
+    elif task == "asr":
+        if args.asr_prefix is not None:
+            spm_filename_prefix = args.asr_prefix
+
+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
+    with NamedTemporaryFile(mode="w") as f:
+        for lang in languages:
+            tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+            df = load_df_from_tsv(tsv_path)
+            for t in df["tgt_text"]:
+                f.write(t + "\n")
+            if args.add_src:
+                for src_utt in df["src_text"]:
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in punctuation_str:
+                            src_utt = src_utt.replace(w, "")
+                        src_utt = " ".join(src_utt.split(" "))
+                    f.write(src_utt + "\n")
+
+        special_symbols = None
+        if args.task == 'st':
+            special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
+        gen_vocab(
+            Path(f.name),
+            output_root / spm_filename_prefix,
+            args.vocab_type,
+            args.vocab_size,
+            special_symbols=special_symbols
+        )
+
+    # Generate config YAML
+    yaml_filename = f"config.yaml"
+    if task == "st" and args.add_src and args.share:
+        yaml_filename = f"config_share.yaml"
+
+    gen_config_yaml(
+        output_root,
+        spm_filename_prefix + ".model",
+        yaml_filename=yaml_filename,
+        specaugment_policy="ld2",
+        asr_spm_filename=asr_spm_filename,
+        share_src_and_tgt=True if task == "asr" else False,
+        prepend_tgt_lang_tag=(args.task == "st"),
+    )
+
+    # Make symbolic links to manifests
+    for lang in languages:
+        for split in args.splits.split(","):
+            src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
+            desc_path = output_root / f"{split}_{lang}.tsv"
+            if not desc_path.is_symlink():
+                shutil.copy(src_path, desc_path)
+
+
 def main():
     parser = argparse.ArgumentParser()

@@ -501,8 +594,12 @@ def main():
     parser.add_argument("--data-root", "-d", required=True, type=str)
     parser.add_argument("--output-root", "-o", default=None, type=str)
     parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
-    parser.add_argument("--src-lang", type=str, required=True, help="source language")
     parser.add_argument("--joint", action="store_true", help="")
+    parser.add_argument("--add-syms", action="store_true", help="")
+    parser.add_argument("--src-lang", type=str, help="source language")
     parser.add_argument("--tgt-lang", type=str, help="target language")
+    parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
+    parser.add_argument("--languages", type=str, help="languages for multilingual training")
     parser.add_argument("--splits", type=str, default="train,dev,test", help="dataset splits")

@@ -569,6 +666,9 @@ def main():
     args = parser.parse_args()

+    if args.joint:
+        process_joint(args)
+    else:
         process(args)
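For orientation, a minimal self-contained sketch of the language-tag construction used in process_joint above. It assumes language pairs named "src-tgt" (e.g. "en-de"), as the loop's lang.split("-")[1] implies; the example values are hypothetical:

    # Illustration only: how process_joint derives the special vocabulary symbols
    # that are registered in the SentencePiece vocab and later prepended to the
    # target side when prepend_tgt_lang_tag is enabled.
    languages = ["en-de", "en-fr", "en-es"]  # hypothetical --languages value
    special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
    print(special_symbols)  # ['<lang:de>', '<lang:fr>', '<lang:es>']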
fairseq/criterions/ctc.py

@@ -125,6 +125,18 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0,
         metadata={"help": "consistent regularization for inter CTC loss in mixup"},
     )
+    xctc_mixup_consistent_weight: float = field(
+        default=0,
+        metadata={"help": "consistent regularization for XCTC loss in mixup"},
+    )
+    inter_xctc_mixup_consistent_weight: float = field(
+        default=0,
+        metadata={"help": "consistent regularization for Inter XCTC loss in mixup"},
+    )
+    ctc_mixup_consistent_hard_target: bool = field(
+        default=False,
+        metadata={"help": "use hard distribution during mixup consistent learning"},
+    )
     wer_kenlm_model: Optional[str] = field(
         default=None,

@@ -156,7 +168,11 @@ class CtcCriterionConfig(FairseqDataclass):
 @register_criterion("ctc", dataclass=CtcCriterionConfig)
 class CtcCriterion(FairseqCriterion):
     def __init__(
-        self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0, save_dir=None
+        self,
+        cfg: CtcCriterionConfig,
+        task: FairseqTask,
+        ctc_weight=1.0,
+        save_dir=None,
+        mixup_no_hard_loss=False,
     ):
         super().__init__(task)

@@ -224,6 +240,10 @@ class CtcCriterion(FairseqCriterion):
         self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
         self.inter_ctc_mixup_consistent_weight = cfg.inter_ctc_mixup_consistent_weight
+        self.xctc_mixup_consistent_weight = cfg.xctc_mixup_consistent_weight
+        self.inter_xctc_mixup_consistent_weight = cfg.inter_xctc_mixup_consistent_weight
+        self.mixup_no_hard_loss = mixup_no_hard_loss
+        self.ctc_mixup_consistent_hard_target = cfg.ctc_mixup_consistent_hard_target

         self.all_ctc_weight = (
             self.ctc_weight

@@ -441,6 +461,7 @@ class CtcCriterion(FairseqCriterion):
         target_lengths,
         loss_coef,
         force_emit=None,
+        loss_mask_flag=None,
     ):
         lprobs = model.get_normalized_probs(
             [ctc_logit], log_probs=True

@@ -470,6 +491,9 @@ class CtcCriterion(FairseqCriterion):
                     input_lengths,
                     item_target_lengths,
                 )
+                if loss_mask_flag is not None:
+                    item_loss = item_loss * loss_mask_flag
+
                 loss += (item_loss * item_coef).sum()

         return loss, lprobs

@@ -518,6 +542,7 @@ class CtcCriterion(FairseqCriterion):
             target_tokens != self.eos_idx
         )

+        mixup_flag = None
         if "mixup" in net_output and net_output["mixup"] is not None:
             mixup_coef = net_output["mixup"]["coef"]
             mixup_idx1 = net_output["mixup"]["index1"]

@@ -532,12 +557,14 @@ class CtcCriterion(FairseqCriterion):
             target_tokens = [target_tokens1, target_tokens2]
             target_lengths = [target_lengths1, target_lengths2]
             loss_coef = [mixup_coef, 1 - mixup_coef]
+            if self.mixup_no_hard_loss:
+                mixup_flag = ~net_output["mixup"]["mixup_flag"]
         else:
             target_tokens = [target_tokens.masked_select(target_pad_mask)]
             target_lengths = [target_pad_mask.sum(-1)]
             loss_coef = [1]

-        return target_tokens, target_lengths, loss_coef
+        return target_tokens, target_lengths, loss_coef, mixup_flag

     def compute_ctc_loss(self, model, sample, net_output, logging_output):
         if "transcript" in sample:

@@ -557,7 +584,7 @@ class CtcCriterion(FairseqCriterion):
         nfeatures = input_lengths.sum().item()
         logging_output["nfeatures"] = nfeatures

-        transcripts, transcript_lengths, loss_coef = self.get_targets_for_ctc_loss(tokens, net_output)
+        transcripts, transcript_lengths, loss_coef, mixup_flag = self.get_targets_for_ctc_loss(tokens, net_output)

         all_ctc_logits = dict()
         self.ctc_names = []

@@ -570,17 +597,17 @@ class CtcCriterion(FairseqCriterion):
         if "inter_ctc_logits" in net_output:
             inter_ctc_num = len(net_output["inter_ctc_logits"])

-        # calculate the inter CTC loss
+        # calculate the Inter CTC loss
         if self.inter_ctc_weight > 0 and inter_ctc_num > 0:
             logits = net_output["inter_ctc_logits"]
             for i in range(inter_ctc_num):
-                inter_transcripts, inter_transcript_lengths, inter_loss_coef = transcripts, transcript_lengths, loss_coef
+                inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = transcripts, transcript_lengths, loss_coef, mixup_flag

                 if self.inter_ctc_mlo is not None:
                     order = self.inter_ctc_mlo[i]
                     tokens_key = "transcript%s" % order
                     if sample.get(tokens_key, None):
                         inter_tokens = sample[tokens_key]["tokens"]
-                        inter_transcripts, inter_transcript_lengths, inter_loss_coef = self.get_targets_for_ctc_loss(inter_tokens, net_output)
+                        inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = self.get_targets_for_ctc_loss(inter_tokens, net_output)

                 logit = logits[i]
                 force_emit = None

@@ -625,6 +652,7 @@ class CtcCriterion(FairseqCriterion):
                     inter_transcript_lengths,
                     inter_loss_coef,
                     force_emit,
+                    inter_mixup_flag,
                 )
                 inter_ctc_loss += inter_loss
                 lprobs = inter_lprobs

@@ -641,7 +669,6 @@ class CtcCriterion(FairseqCriterion):
         ):
             use_ctc = True
             logit = net_output["ctc_logit"][0]
-            # all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
             force_emit = None
             if type(logit) == list:

@@ -664,6 +691,7 @@ class CtcCriterion(FairseqCriterion):
                 transcript_lengths,
                 loss_coef,
                 force_emit,
+                mixup_flag,
             )

             if self.ctc_entropy_weight > 0:

@@ -687,7 +715,7 @@ class CtcCriterion(FairseqCriterion):
         if self.use_axctc:
             aligned_target_tokens = self.get_aligned_target_text(sample)
-            target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss(
+            target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
                 aligned_target_tokens, net_output
             )

@@ -711,7 +739,6 @@ class CtcCriterion(FairseqCriterion):
                     inter_axctc_logit = logit
                     inter_input_lengths = input_lengths

-                # all_ctc_logits["inter_axctc_logit%d" % i] = [inter_axctc_logit, inter_input_lengths]
                 inter_loss, target_inter_lprobs = self.get_ctc_loss(
                     model,
                     inter_axctc_logit,

@@ -720,6 +747,7 @@ class CtcCriterion(FairseqCriterion):
                     target_lengths,
                     loss_coef,
                     force_emit,
+                    target_mixup_flag,
                 )
                 inter_axctc_loss += inter_loss
                 target_lprobs = target_inter_lprobs

@@ -730,7 +758,6 @@ class CtcCriterion(FairseqCriterion):
         if self.axctc_weight > 0:
             assert "axctc_logit" in net_output
             logit = net_output["axctc_logit"][0]
-            # all_ctc_logits["axctc_logit"] = [axctc_logit, input_lengths]
             force_emit = None
             if type(logit) == list:

@@ -753,6 +780,7 @@ class CtcCriterion(FairseqCriterion):
                 target_lengths,
                 loss_coef,
                 force_emit,
+                target_mixup_flag,
             )

             logging_output["axctc_loss"] = utils.item(axctc_loss.data)

@@ -762,7 +790,7 @@ class CtcCriterion(FairseqCriterion):
         if self.use_xctc:
             ctc_target_tokens = self.get_ctc_target_text(sample)
-            target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss(
+            target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
                 ctc_target_tokens, net_output
             )

@@ -787,7 +815,6 @@ class CtcCriterion(FairseqCriterion):
                     inter_xctc_logit = logit
                     inter_input_lengths = input_lengths

-                # all_ctc_logits["inter_xctc_logit%d" % i] = [inter_xctc_logit, inter_input_lengths]
                 inter_loss, target_inter_lprobs = self.get_ctc_loss(
                     model,
                     inter_xctc_logit,

@@ -796,6 +823,7 @@ class CtcCriterion(FairseqCriterion):
                     target_lengths,
                     loss_coef,
                     force_emit,
+                    target_mixup_flag,
                 )
                 inter_xctc_loss += inter_loss
                 target_lprobs = target_inter_lprobs

@@ -819,7 +847,6 @@ class CtcCriterion(FairseqCriterion):
                 force_emit = logit[2]
             else:
                 xctc_logit = logit

-            # all_ctc_logits["xctc_logit"] = [xctc_logit, input_lengths]
             xctc_loss, target_lprobs = self.get_ctc_loss(
                 model,

@@ -829,6 +856,7 @@ class CtcCriterion(FairseqCriterion):
                 target_lengths,
                 loss_coef,
                 force_emit,
+                target_mixup_flag,
             )
             logging_output["xctc_loss"] = utils.item(xctc_loss.data)

@@ -928,21 +956,8 @@ class CtcCriterion(FairseqCriterion):
                 xctc_self_distill_loss * self.xctc_self_distill_weight
             )

-        ctc_mixup_consistent_loss = 0
-        inter_ctc_mixup_consistent_loss = 0
-        if use_ctc and mixup is True:
-            mixup_coef = net_output["mixup"]["coef"]
-            mixup_idx1 = net_output["mixup"]["index1"]
-            mixup_idx2 = net_output["mixup"]["index2"]
-
-            mixup_pos = mixup_idx1 != mixup_idx2
-            mixup_real_coef = mixup_coef[mixup_pos]
-            loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
-            mixup_real_idx1 = mixup_idx1[mixup_pos]
-            mixup_real_idx2 = mixup_idx2[mixup_pos]
-
-            def get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask):
+        # calculate KD loss for interpolation augmentation
+        def get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2):
             mixup_consistent_loss = 0
             mixup_real_logit = ctc_logit[:, mixup_pos, :]
             no_mixup_logit = ctc_logit[:, ~mixup_pos, :]

@@ -958,9 +973,16 @@ class CtcCriterion(FairseqCriterion):
             for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
+                if self.ctc_mixup_consistent_hard_target:
+                    loss = F.kl_div(
+                        F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
+                        utils.distribution_soft_to_hard(logit.detach()).to(torch.float32),
+                        log_target=False,
+                        reduction="none",
+                    )
+                else:
                     loss = F.kl_div(
                         F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
                         # F.log_softmax(logit, dim=-1, dtype=torch.float32),
                         F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
                         log_target=True,
                         reduction="none",

@@ -970,12 +992,33 @@ class CtcCriterion(FairseqCriterion):
                     ).sum()

             return mixup_consistent_loss

+        ctc_mixup_consistent_loss = 0
+        inter_ctc_mixup_consistent_loss = 0
+        xctc_mixup_consistent_loss = 0
+        inter_xctc_mixup_consistent_loss = 0
+        if use_ctc and mixup is True:
+            mixup_coef = net_output["mixup"]["coef"]
+            mixup_idx1 = net_output["mixup"]["index1"]
+            mixup_idx2 = net_output["mixup"]["index2"]
+
+            mixup_pos = mixup_idx1 != mixup_idx2
+            mixup_real_coef = mixup_coef[mixup_pos]
+            loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
+            mixup_real_idx1 = mixup_idx1[mixup_pos]
+            mixup_real_idx2 = mixup_idx2[mixup_pos]
+
             if self.ctc_mixup_consistent_weight > 0:
                 ctc_logit = net_output["ctc_logit"][0]
-                ctc_mixup_consistent_loss = get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask)
+                ctc_mixup_consistent_loss = get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
                 logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)

+            if self.xctc_mixup_consistent_weight > 0:
+                xctc_logit = net_output["xctc_logit"][0]
+                xctc_mixup_consistent_loss = get_mixup_consistent_loss(xctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
+                logging_output["xctc_mixup_consistent_loss"] = utils.item(xctc_mixup_consistent_loss.data)
+
             if self.inter_ctc_mixup_consistent_weight > 0:
                 if inter_ctc_num > 0:

@@ -989,12 +1032,40 @@ class CtcCriterion(FairseqCriterion):
                         inter_ctc_logit = logit
                         inter_non_padding_mask = non_padding_mask

-                    inter_ctc_mixup_consistent_loss += get_ctc_mixup_consistent_loss(inter_ctc_logit, inter_non_padding_mask)
+                    inter_ctc_mixup_consistent_loss += get_mixup_consistent_loss(inter_ctc_logit, inter_non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
                 logging_output["inter_ctc_mixup_consistent_loss"] = utils.item(inter_ctc_mixup_consistent_loss.data)

+            if self.inter_xctc_mixup_consistent_weight > 0:
+                if inter_xctc_num > 0:
+                    logits = net_output["inter_xctc_logits"]
+                    for i in range(inter_xctc_num):
+                        logit = logits[i]
+                        if type(logit) == list:
+                            inter_xctc_logit = logit[0]
+                            inter_non_padding_mask = ~logit[1] if logit[1] is not None else non_padding_mask
+                        else:
+                            inter_xctc_logit = logit
+                            inter_non_padding_mask = non_padding_mask
+
+                        inter_xctc_mixup_consistent_loss += get_mixup_consistent_loss(inter_xctc_logit, inter_non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
+                    logging_output["inter_xctc_mixup_consistent_loss"] = utils.item(inter_xctc_mixup_consistent_loss.data)
+
         if len(ctc_entropy) != 0:
             ctc_entropy = sum(ctc_entropy) / len(ctc_entropy)
             logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)

@@ -1012,6 +1083,8 @@ class CtcCriterion(FairseqCriterion):
             + self.ctc_entropy_weight * ctc_entropy
             + self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
             + self.inter_ctc_mixup_consistent_weight * inter_ctc_mixup_consistent_loss
+            + self.xctc_mixup_consistent_weight * xctc_mixup_consistent_loss
+            + self.inter_xctc_mixup_consistent_weight * inter_xctc_mixup_consistent_loss
         )

         if loss != 0:

@@ -1137,6 +1210,13 @@ class CtcCriterion(FairseqCriterion):
         inter_ctc_mixup_consistent_loss = utils.item(
             sum(log.get("inter_ctc_mixup_consistent_loss", 0) for log in logging_outputs)
         )
+        xctc_mixup_consistent_loss = utils.item(
+            sum(log.get("xctc_mixup_consistent_loss", 0) for log in logging_outputs)
+        )
+        inter_xctc_mixup_consistent_loss = utils.item(
+            sum(log.get("inter_xctc_mixup_consistent_loss", 0) for log in logging_outputs)
+        )
         all_ctc_loss_sum = utils.item(
             sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
         )

@@ -1245,6 +1325,20 @@ class CtcCriterion(FairseqCriterion):
                 sample_size,
                 round=3,
             )
+        if xctc_mixup_consistent_loss > 0:
+            metrics.log_scalar(
+                "xctc_mixup_consistent_loss",
+                xctc_mixup_consistent_loss / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
+        if inter_xctc_mixup_consistent_loss > 0:
+            metrics.log_scalar(
+                "inter_xctc_mixup_consistent_loss",
+                inter_xctc_mixup_consistent_loss / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )

         metrics.log_scalar("ntokens", ntokens)
         metrics.log_scalar("nsentences", nsentences)
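For orientation, a small self-contained sketch of the two consistency targets that get_mixup_consistent_loss switches between: a hard one-hot target when ctc_mixup_consistent_hard_target is set, and soft log-probabilities otherwise. The one-hot construction mirrors utils.distribution_soft_to_hard added later in this commit; the tensor shapes are illustrative only.

    import torch
    import torch.nn.functional as F

    def soft_to_hard(logit_or_prob):
        # One-hot distribution at the argmax position
        # (mirrors utils.distribution_soft_to_hard from this commit).
        idx = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
        arange = torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device)
        return (idx == arange.unsqueeze(0)).to(logit_or_prob.dtype)

    mixed_logit = torch.randn(5, 10)   # logits of the mixed (interpolated) sample
    clean_logit = torch.randn(5, 10)   # logits of the corresponding clean sample

    log_p_mixed = F.log_softmax(mixed_logit, dim=-1)

    # Soft target: the clean prediction is passed as log-probabilities.
    soft_loss = F.kl_div(log_p_mixed, F.log_softmax(clean_logit.detach(), dim=-1),
                         log_target=True, reduction="none").sum()

    # Hard target: a one-hot distribution at the clean prediction's argmax.
    hard_loss = F.kl_div(log_p_mixed, soft_to_hard(clean_logit.detach()),
                         log_target=False, reduction="none").sum()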
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "the weight for consistency regularization of mixup"},
     )
-    cal_mixup_loss: bool = field(
+    mixup_no_hard_loss: bool = field(
         default=False,
         metadata={"help": "calculate the loss for the mixed samples"},
     )

@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         label_smoothing,
         ignore_prefix_size=0,
         report_accuracy=False,
-        cal_mixup_loss=True,
+        mixup_no_hard_loss=False,
         mixup_consistent_weight=0.0,
     ):
         super().__init__(task)

@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.eps = float(label_smoothing)
         self.ignore_prefix_size = ignore_prefix_size
         self.report_accuracy = report_accuracy
-        self.cal_mixup_loss = cal_mixup_loss
+        self.mixup_no_hard_loss = mixup_no_hard_loss
         self.mixup_consistent_weight = mixup_consistent_weight

     def forward(self, model, sample, reduce=True):

@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
             loss_coef = [mixup_coef, 1 - mixup_coef]

-            if self.cal_mixup_loss:
+            if not self.mixup_no_hard_loss:
                 for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
                     batch_size = item_target.size(0)
                     item_loss, item_nll_loss = label_smoothed_nll_loss(
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py

@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         cfg: CtcCriterionConfig,
         ctc_weight=0.0,
         save_dir=None,
-        cal_mixup_loss=True,
+        mixup_no_hard_loss=False,
         mixup_consistent_weight=0.0,
         only_train_enc_prob=0.0,
         get_oracle_when_only_train_enc=False,
     ):
         super().__init__(
             task,
             sentence_avg,
             label_smoothing,
             report_accuracy=True,
-            cal_mixup_loss=cal_mixup_loss,
+            mixup_no_hard_loss=mixup_no_hard_loss,
             mixup_consistent_weight=mixup_consistent_weight,
         )
         self.report_accuracy = True
         self.ctc_weight = ctc_weight
-        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
+        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
         self.save_dir = save_dir
         self.only_train_enc_prob = only_train_enc_prob
fairseq/data/audio/speech_to_text_dataset.py
(no textual changes shown; +0 -0)
fairseq/dataclass/configs.py

@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
             "help": "training steps in each epoch"
         }
     )
+    sharded_data_load: bool = field(
+        default=False,
+        metadata={"help": "Use sharded data for efficient data load"},
+    )


 @dataclass

@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight for ctc probs for lm fusion"},
     )
+    early_exit_count: int = field(
+        default=0,
+        metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
+    )
+    early_exit_layer: int = field(
+        default=0,
+        metadata={"help": "early exit during decoding at layer n"},
+    )
     # arguments for iterative refinement generator
     iter_decode_eos_penalty: float = field(
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
     ):
         if hasattr(self, "ctc"):
+            import os
             assert src_dict is not None
             self.ctc.set_infer(
                 ctc_infer,
                 post_process,
                 src_dict,
-                path=path + ".ctc" if path is not None else None,
+                path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
             )

     def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
fairseq/models/speech_to_text/s2t_ctc.py

@@ -250,7 +250,6 @@ class CTCDecoder(object):
             logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-            print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))

             from torchprofile import profile_macs
             macs = profile_macs(self.model, [src_tokens, src_lengths])
             gmacs = macs / 1e9

@@ -269,20 +268,22 @@ class CTCDecoder(object):
         inter_logits = encoder_outs.get("inter_xctc_logits", [])
         if ctc_logit is None:
             ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
-        if len(inter_logits) > 0:
+        if len(inter_logits) == 0:
             inter_logits = encoder_outs.get("inter_ctc_logits", [])
         inter_logits_num = len(inter_logits)

         encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]

         if self.ctc_inter_logit != 0:
-            assert inter_logits_num >= self.ctc_inter_logit
+            if inter_logits_num != 0:
+                assert self.ctc_inter_logit <= inter_logits_num
                 ctc_logit_item = inter_logits[-self.ctc_inter_logit]
                 if isinstance(ctc_logit_item, list):
                     ctc_logit = ctc_logit_item[0].transpose(0, 1)
                     if len(ctc_logit_item) >= 2:
                         encoder_padding_mask = ctc_logit_item[1]
                 else:
                     ctc_logit = ctc_logit_item.transpose(0, 1)

         logit_length = (~encoder_padding_mask).long().sum(-1)
         finalized = []

@@ -318,7 +319,7 @@ class CTCDecoder(object):
                 else:
                     logit = inter_logits[i]

-                inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1)
+                inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
                 ctc_probs += inter_logits_prob

         topk_prob, topk_index = ctc_probs.topk(1, dim=2)
fairseq/models/speech_to_text/s2t_transformer.py

@@ -888,6 +888,8 @@ class S2TTransformerEncoder(FairseqEncoder):
         super().__init__(None)

         dim = args.encoder_embed_dim
+        self.source_dictionary = task.source_dictionary
+        self.target_dictionary = task.target_dictionary
         layer_num = args.encoder_layers
         self.dropout_module = FairseqDropout(
             p=args.dropout, module_name=self.__class__.__name__

@@ -1027,6 +1029,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                     )
                 ),
                 dropout=args.dropout,
+                dictionary=task.source_dictionary,
             )
             setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)

             # inter_layer_norm = LayerNorm(dim)

@@ -1038,6 +1041,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dim,
                 dictionary_size=len(task.source_dictionary),
                 dropout=args.dropout,
+                dictionary=task.source_dictionary,
             )

             if (
                 getattr(args, "share_ctc_and_embed", False)

@@ -1116,6 +1120,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 else len(task.target_dictionary),
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_xctc else False,
+                dictionary=task.target_dictionary,
             )

             if (

@@ -1375,6 +1380,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.mixup_infer = False
         self.rep_dict = dict()

+        self.early_exit_count = 0
+        self.early_exit_layer_record = []
+        self.early_exit_layer = 0
+
     @staticmethod
     def build_encoder_layer(args):
         return S2TTransformerEncoderLayer(args)

@@ -1400,6 +1409,10 @@ class S2TTransformerEncoder(FairseqEncoder):
                 layer, "dump") else None
             print("Early exit layer.", file=fstream)
+            if self.early_exit_count != 0:
+                print("\n".join([str(l) for l in self.early_exit_layer_record]), file=fstream)

         if self.gather_cos_sim:
             print("\nCosine similarity of distance %d" % self.gather_cos_sim_dis,

@@ -1540,13 +1553,17 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.mixup_infer = kwargs.get("mixup_infer", False)
         self.gather_cos_sim = kwargs.get("gather_cos_sim", False)
         self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2)
+        self.early_exit_layer = kwargs.get("early_exit_layer", 0)
+        if self.early_exit_layer != 0:
+            logger.info("Using the logit in layer %d to infer." % self.early_exit_layer)

         if self.mixup_infer:
             self.mixup_keep_org = True

     def set_ctc_infer(
-        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
+        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None, early_exit_count=0
     ):
+        self.early_exit_count = early_exit_count
         if hasattr(self, "ctc"):
             assert src_dict is not None
             self.ctc.set_infer(

@@ -1711,6 +1728,27 @@ class S2TTransformerEncoder(FairseqEncoder):
                 org_x = x[:, ~flag, :].mean(0)
             rep_dict[layer_idx].append(org_x)

+    def early_exit_or_not(self, history, new_logit, count):
+        history.append(new_logit)
+        length = len(history)
+        if count == 0 or length < count:
+            return False
+        else:
+            # for logit in history[length - count: length - 1]:
+            #     if new_logit.size() != logit.size() or not (new_logit == logit).all():
+            #         return False
+            # return True
+            hit = 0
+            for logit in history[:length - 1]:
+                if new_logit.size() == logit.size() and (new_logit == logit).all():
+                    hit += 1
+            if hit >= count:
+                return True
+            else:
+                return False
+
     def forward(self, src_tokens, src_lengths=None, **kwargs):
         layer_idx = -1

@@ -1727,6 +1765,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         # (B, T, D) -> (T, B, D)
         x = src_tokens.transpose(0, 1)
         input_lengths = src_lengths

+        org_bsz = x.size(1)
         if (
             (self.training or self.mixup_infer)

@@ -1821,6 +1860,16 @@ class S2TTransformerEncoder(FairseqEncoder):
             xctc_oracle_mask = None
             xctc_force_emit = None

+        # Infer early exit
+        batch_idx_dict = dict()
+        inter_ctc_logits_history = dict()
+        final_ctc_logits = dict()
+        final_encoder_padding_mask = dict()
+        early_exit_layer = dict()
+        for i in range(x.size(1)):
+            inter_ctc_logits_history[i] = []
+            batch_idx_dict[i] = i
+
         for layer in self.layers:
             if self.history is not None:
                 x = self.history.pop()

@@ -1879,7 +1928,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             # Inter CTC
             if layer_idx in self.inter_ctc_layers:
-                if self.inter_ctc_drop_prob > 0:
+                if self.training and self.inter_ctc_drop_prob > 0:
                     p = torch.rand(1).uniform_()
                     if p < self.inter_ctc_drop_prob:
                         break

@@ -1945,6 +1994,44 @@ class S2TTransformerEncoder(FairseqEncoder):
                 inter_ctc_logits.append(inter_logit)

+                if not self.training and self.early_exit_layer == layer_idx:
+                    ctc_logit = inter_logit[0]
+                    break
+
+                if not self.training and self.early_exit_count != 0:
+                    predicts = inter_ctc.predict(inter_logit[0], encoder_padding_mask)
+                    if len(inter_ctc_logits) < self.early_exit_count:
+                        for i in range(x.size(1)):
+                            inter_ctc_logits_history[i].append(predicts[i])
+                    else:
+                        if org_bsz == 1:
+                            early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
+                            if early_exit_flag:
+                                ctc_logit = inter_logit[0]
+                                self.early_exit_layer_record.append(layer_idx)
+                                break
+                        else:
+                            idx = 0
+                            keep_idx = []
+                            new_batch_idx_dict = dict()
+                            for i in range(x.size(1)):
+                                real_idx = batch_idx_dict[i]
+                                early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[real_idx], predicts[i], self.early_exit_count)
+                                if early_exit_flag:
+                                    final_ctc_logits[real_idx] = inter_logit[0][:, i, :]
+                                    final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
+                                    early_exit_layer[real_idx] = layer_idx
+                                else:
+                                    keep_idx.append(i)
+                                    new_batch_idx_dict[idx] = real_idx
+                                    idx += 1
+                            if idx == 0:
+                                break
+                            if idx < x.size(1):
+                                batch_idx_dict = new_batch_idx_dict
+                                x = x[:, keep_idx, :].contiguous()
+                                encoder_padding_mask = encoder_padding_mask[keep_idx, :].contiguous()
+
             if layer_idx in self.compression_layers:
                 ctc_prob = utils.softmax(logit, dim=-1)  # (T B C)
                 blank_prob = ctc_prob[:, :, 0]

@@ -2133,6 +2220,25 @@ class S2TTransformerEncoder(FairseqEncoder):
             )
             self.show_debug(x, "x after xctc")

+        if not self.training and self.early_exit_count != 0 and org_bsz != 1:
+            if layer_idx == len(self.layers) + 1:
+                for i in range(x.size(1)):
+                    real_idx = batch_idx_dict[i]
+                    final_ctc_logits[real_idx] = ctc_logit[:, i, :]
+                    final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
+                    early_exit_layer[real_idx] = layer_idx - 1
+
+            output_logits = []
+            output_encoder_padding_mask = []
+            output_layers = []
+            for i in range(len(final_ctc_logits)):
+                output_logits.append(final_ctc_logits[i])
+                output_encoder_padding_mask.append(final_encoder_padding_mask[i])
+                output_layers.append(early_exit_layer[i])
+            ctc_logit = torch.stack(output_logits, dim=0).transpose(0, 1)
+            encoder_padding_mask = torch.stack(output_encoder_padding_mask, dim=0)
+            self.early_exit_layer_record.extend(output_layers)
+
         if ctc_force_emit is not None:
             ctc_logit = [ctc_logit, None, ctc_force_emit]

@@ -2174,6 +2280,11 @@ class S2TTransformerEncoder(FairseqEncoder):
             if len(encoder_out["xctc_logit"]) == 0
             else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]]
         )
+        new_inter_ctc_logits = (
+            []
+            if len(encoder_out["inter_ctc_logits"]) == 0
+            else [[x[0].index_select(1, new_order)].extend(x[1:]) if isinstance(x, list) else x.index_select(1, new_order)
+                  for x in encoder_out["inter_ctc_logits"] if x is not None]
+        )
         new_encoder_padding_mask = (
             []
             if len(encoder_out["encoder_padding_mask"]) == 0

@@ -2200,6 +2311,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             "encoder_out": new_encoder_out,  # T x B x C
             "ctc_logit": new_ctc_logit,  # T x B x C
             "xctc_logit": new_xctc_logit,
+            "inter_ctc_logits": new_inter_ctc_logits,
             "encoder_padding_mask": new_encoder_padding_mask,  # B x T
             "encoder_embedding": new_encoder_embedding,  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
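To make the early-exit rule above easier to follow, here is a stripped-down, self-contained sketch of the same idea for a single utterance: decoding stops at the first intermediate layer whose greedy CTC prediction has already been produced by enough earlier layers. The function name and the toy per-layer outputs are illustrative, not part of the repository.

    import torch

    def should_exit(history, new_pred, count):
        # Simplified version of S2TTransformerEncoder.early_exit_or_not:
        # exit once `count` earlier layers already produced the same prediction.
        history.append(new_pred)
        if count == 0 or len(history) < count:
            return False
        hits = sum(
            1 for p in history[:-1]
            if p.size() == new_pred.size() and bool((p == new_pred).all())
        )
        return hits >= count

    history = []
    layer_predictions = [torch.tensor([3, 7, 7, 5])] * 8  # toy per-layer greedy CTC outputs
    for layer_idx, pred in enumerate(layer_predictions):
        if should_exit(history, pred, count=6):
            print("early exit at layer", layer_idx)
            break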
fairseq/modules/speech_to_text/ctc.py

@@ -74,6 +74,19 @@ class CTC(nn.Module):
     def argmax(self, x):
         return torch.argmax(self.ctc_projection(x), dim=-1)

+    def predict(self, logits, padding):
+        input_lengths = (~padding).sum(-1)
+        logits = logits.transpose(0, 1).float().contiguous()
+        predicts = []
+        for logit, inp_l in zip(logits, input_lengths):
+            toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
+            pred_units_arr = toks[toks != self.dictionary.bos()]
+            # pred_units_arr = logit[:inp_l].argmax(dim=-1)
+            predicts.append(pred_units_arr)
+        return predicts
+
     def infer(self, logits_or_probs, lengths, tag=None):
         for lp, inp_l in zip(
             logits_or_probs,
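As a standalone illustration of what predict computes per utterance (greedy CTC decoding: per-frame argmax, collapse of repeated labels, removal of the blank index, which this codebase takes from the dictionary's bos position), with a blank index of 0 assumed and the frame values made up:

    import torch

    frame_argmax = torch.tensor([0, 5, 5, 0, 0, 7, 7, 7, 0, 2])  # per-frame argmax ids
    blank = 0  # CTC.predict uses self.dictionary.bos() for this
    collapsed = frame_argmax.unique_consecutive()   # tensor([0, 5, 0, 7, 0, 2])
    prediction = collapsed[collapsed != blank]      # tensor([5, 7, 2])
    print(prediction)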
fairseq/tasks/fairseq_task.py

@@ -130,6 +130,9 @@ class FairseqTask(object):
         """
         return cls(cfg, **kwargs)

+    def sharded_data_load(self):
+        return getattr(self.cfg, "sharded_data_load", False)
+
     def has_sharded_data(self, split):
         return os.pathsep in getattr(self.cfg, "data", "")

@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
         """
         return cls(args, **kwargs)

+    def sharded_data_load(self):
+        return getattr(self.args, "sharded_data_load", False)
+
     def has_sharded_data(self, split):
         return os.pathsep in getattr(self.args, "data", "")
fairseq/trainer.py

@@ -521,16 +521,25 @@ class Trainer(object):
         disable_iterator_cache=False,
     ):
         """Return an EpochBatchIterator over the training set for a given epoch."""
+        if self.task.sharded_data_load():
+            datasets = self.cfg.dataset.train_subset.split(",")
+            curr_dataset = datasets[(epoch - 1) % len(datasets)]
+            logger.info("sharded loading the training subset {}".format(curr_dataset))
+        else:
+            curr_dataset = self.cfg.dataset.train_subset
+
+        load_dataset = load_dataset or self.task.sharded_data_load()
+        disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
+
         if load_dataset:
             logger.info("loading train data for epoch {}".format(epoch))
             self.task.load_dataset(
-                self.cfg.dataset.train_subset,
+                curr_dataset,
                 epoch=epoch,
                 combine=combine,
                 data_selector=data_selector,
             )
         batch_iterator = self.task.get_batch_iterator(
-            dataset=self.task.dataset(self.cfg.dataset.train_subset),
+            dataset=self.task.dataset(curr_dataset),
             max_tokens=self.cfg.dataset.max_tokens,
             max_sentences=self.cfg.dataset.batch_size,
             max_positions=utils.resolve_max_positions(
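A quick illustration of the shard rotation used above when sharded_data_load is enabled: the comma-separated train_subset is treated as a list of shards and one shard is (re)loaded per epoch, round-robin. The subset names below are made up.

    # Hypothetical value of cfg.dataset.train_subset when the data is pre-split into shards.
    train_subset = "train_shard0,train_shard1,train_shard2"
    datasets = train_subset.split(",")

    for epoch in range(1, 7):
        curr_dataset = datasets[(epoch - 1) % len(datasets)]
        print(epoch, curr_dataset)
    # 1 train_shard0, 2 train_shard1, 3 train_shard2, 4 train_shard0, ...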
fairseq/utils.py

@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
     freeze_module_name = freeze_module_name.split(",")
     for name in freeze_module_name:
         freeze_module_params_by_name(module, name)
+
+
+def distribution_soft_to_hard(logit_or_prob):
+    argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
+    hard_distribution = (
+        (
+            argmax_prob
+            == torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(0)
+        ).to(logit_or_prob.dtype)
+    )
+    return hard_distribution
\ No newline at end of file
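A brief usage sketch of the new helper (the input values are illustrative): it turns arbitrary logits or probabilities into a one-hot distribution at the argmax position, keeping the input dtype and device.

    import torch

    logits = torch.tensor([[0.2, 1.5, -0.3],
                           [2.0, 0.1,  0.4]])
    argmax = torch.argmax(logits, dim=-1, keepdim=True)  # [[1], [0]]
    hard = (argmax == torch.arange(logits.size(-1)).unsqueeze(0)).to(logits.dtype)
    print(hard)  # tensor([[0., 1., 0.], [1., 0., 0.]])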
fairseq_cli/generate.py

@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
     for model in models:
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
             model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                        src_dict, tgt_dict, translation_path)
+                                        src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
             model.encoder.set_flag(
                 cal_localness=cfg.generation.cal_localness,

@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
                 mixup_infer=cfg.generation.mixup_infer,
                 gather_cos_sim=cfg.generation.gather_cos_sim,
                 gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
+                early_exit_layer=cfg.generation.early_exit_layer,
             )
         if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
             model.decoder.set_flag(

@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
             # Remove padding
             if "src_tokens" in sample["net_input"]:
+                if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
                     src_tokens = utils.strip_pad(
-                        sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+                        sample["net_input"]["src_tokens"][i, :], src_dict.pad()
                     )
+                elif "transcript" in sample:
+                    src_tokens = utils.strip_pad(
+                        sample["transcript"]["tokens"][i, :], src_dict.pad()
+                    )
+                else:
+                    src_tokens = None
             else:
                 src_tokens = None