xuchen / S2T — Commit cb2f2bcb, authored Nov 28, 2023 by xuchen
2023.11
parent 51395037
Showing 15 changed files with 459 additions and 84 deletions.
egs/librispeech/asr/run.sh  +7 -5
examples/speech_to_text/prep_audio_data.py  +112 -12
fairseq/criterions/ctc.py  +138 -44
fairseq/criterions/label_smoothed_cross_entropy.py  +4 -4
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py  +3 -3
fairseq/data/audio/speech_to_text_dataset.py  +1 -1
fairseq/dataclass/configs.py  +14 -0
fairseq/models/speech_to_text/pdss2t_transformer.py  +2 -1
fairseq/models/speech_to_text/s2t_ctc.py  +5 -4
fairseq/models/speech_to_text/s2t_transformer.py  +115 -3
fairseq/modules/speech_to_text/ctc.py  +13 -0
fairseq/tasks/fairseq_task.py  +6 -0
fairseq/trainer.py  +11 -2
fairseq/utils.py  +16 -0
fairseq_cli/generate.py  +12 -5
egs/librispeech/asr/run.sh

@@ -83,10 +83,12 @@ epoch_ensemble=0
 best_ensemble=1
 infer_debug=0
 infer_score=0
-infer_tag=
-infer_parameters=
-#infer_parameters="--early-exit-layer 12"
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+infer_tag=ee6
+infer_parameter="--early-exit-count 6"
+#infer_parameter="--early-exit-layer 12"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
 data_config=config.yaml

@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         cmd="${cmd}
         --score-reference"
     fi
-    if [[ -n ${infer_parameters} ]]; then
+    if [[ -n ${infer_parameter} ]]; then
         cmd="${cmd}
-        ${infer_parameters}"
+        ${infer_parameter}"
     fi
     echo -e "\033[34mRun command:\n${cmd}\033[0m"
examples/speech_to_text/prep_audio_data.py

@@ -37,7 +37,7 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)

-MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]

 class AudioDataset(Dataset):

@@ -398,6 +398,7 @@ def process(args):
             if args.add_src and src_utt is not None:
                 manifest["src_text"].append(src_utt)
             manifest["tgt_text"].append(tgt_utt)
+            manifest["tgt_lang"].append(tgt_lang)
         if is_train_split:
             if args.task == "st" and args.add_src and args.share:

@@ -454,15 +455,17 @@ def process(args):
         # if task == "st" and args.add_src and args.share:
         if args.add_src and args.share:
             for e in reader:
-                src_utt = dict(e)["src_text"]
+                if "src_text" in dict(e):
+                    src_utt = dict(e)["src_text"]
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in punctuation_str:
+                            src_utt = src_utt.replace(w, "")
+                    src_utt = " ".join(src_utt.split(" "))
+                    train_text.append(src_utt)
                 tgt_utt = dict(e)["tgt_text"]
-                if args.lowercase_src:
-                    src_utt = src_utt.lower()
-                if args.rm_punc_src:
-                    for w in punctuation_str:
-                        src_utt = src_utt.replace(w, "")
-                src_utt = " ".join(src_utt.split(" "))
-                train_text.append(src_utt)
                 train_text.append(tgt_utt)
         else:
             tgt_text = [(dict(e))["tgt_text"] for e in reader]

@@ -471,11 +474,16 @@ def process(args):
         with NamedTemporaryFile(mode="w") as f:
             for t in train_text:
                 f.write(t + "\n")
+            special_symbols = None
+            if args.add_syms:
+                special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
             gen_vocab(
                 Path(f.name),
                 output_root / spm_filename_prefix,
                 args.vocab_type,
                 args.vocab_size,
+                special_symbols=special_symbols
             )
     # Generate config YAML

@@ -491,18 +499,107 @@ def process(args):
         cmvn_type=args.cmvn_type,
         gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
         asr_spm_filename=asr_spm_filename,
-        share_src_and_tgt=True if task == "asr" else False,
+        share_src_and_tgt=True if task == "asr" and not args.add_src else False,
+        prepend_tgt_lang_tag=(args.add_syms),
     )
+
+def process_joint(args):
+    cur_root = Path(args.data_root).absolute()
+    task = args.task
+    languages = args.languages.split(",")
+    assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
+        "do not have downloaded data available for all languages"
+    if args.output_root is None:
+        output_root = cur_root
+    else:
+        output_root = Path(args.output_root).absolute()
+
+    # Generate vocab
+    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    asr_spm_filename = None
+    if args.add_src:
+        if args.share:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
+            asr_spm_filename = spm_filename_prefix + ".model"
+        else:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            assert args.asr_prefix is not None
+            asr_spm_filename = args.asr_prefix + ".model"
+    elif task == "asr":
+        if args.asr_prefix is not None:
+            spm_filename_prefix = args.asr_prefix
+
+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
+    with NamedTemporaryFile(mode="w") as f:
+        for lang in languages:
+            tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+            df = load_df_from_tsv(tsv_path)
+            for t in df["tgt_text"]:
+                f.write(t + "\n")
+            if args.add_src:
+                for src_utt in df["src_text"]:
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in punctuation_str:
+                            src_utt = src_utt.replace(w, "")
+                    src_utt = " ".join(src_utt.split(" "))
+                    f.write(src_utt + "\n")
+        special_symbols = None
+        if args.task == 'st':
+            special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
+        gen_vocab(
+            Path(f.name),
+            output_root / spm_filename_prefix,
+            args.vocab_type,
+            args.vocab_size,
+            special_symbols=special_symbols
+        )
+
+    # Generate config YAML
+    yaml_filename = f"config.yaml"
+    if task == "st" and args.add_src and args.share:
+        yaml_filename = f"config_share.yaml"
+    gen_config_yaml(
+        output_root,
+        spm_filename_prefix + ".model",
+        yaml_filename=yaml_filename,
+        specaugment_policy="ld2",
+        asr_spm_filename=asr_spm_filename,
+        share_src_and_tgt=True if task == "asr" else False,
+        prepend_tgt_lang_tag=(args.task == "st"),
+    )
+
+    # Make symbolic links to manifests
+    for lang in languages:
+        for split in args.splits.split(","):
+            src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
+            desc_path = output_root / f"{split}_{lang}.tsv"
+            if not desc_path.is_symlink():
+                shutil.copy(src_path, desc_path)

 def main():
     parser = argparse.ArgumentParser()
     # general setting
     parser.add_argument("--data-root", "-d", required=True, type=str)
     parser.add_argument("--output-root", "-o", default=None, type=str)
     parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
-    parser.add_argument("--src-lang", type=str, required=True, help="source language")
+    parser.add_argument("--joint", action="store_true", help="")
+    parser.add_argument("--add-syms", action="store_true", help="")
+    parser.add_argument("--src-lang", type=str, help="source language")
     parser.add_argument("--tgt-lang", type=str, help="target language")
+    parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
+    parser.add_argument("--languages", type=str, help="languages for multilingual training")
     parser.add_argument(
         "--splits", type=str, default="train,dev,test", help="dataset splits"
     )

@@ -569,7 +666,10 @@ def main():
     args = parser.parse_args()

-    process(args)
+    if args.joint:
+        process_joint(args)
+    else:
+        process(args)

 if __name__ == "__main__":
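Note: the new process_joint path pools training text from every language pair and registers one <lang:xx> tag per target language as a user-defined SentencePiece symbol, so that with prepend_tgt_lang_tag the target sentence is decoded after its language tag. A minimal sketch of that tag construction, independent of fairseq (the gen_vocab/SentencePiece call itself is omitted; function names here are illustrative only):

def build_lang_tags(languages):
    # "languages" holds pair names such as "en-de"; the target code is the
    # part after the dash, mirroring lang.split("-")[1] in the diff above.
    return [f'<lang:{lang.split("-")[1]}>' for lang in languages]

def prepend_tag(tgt_text, tgt_lang):
    # With prepend_tgt_lang_tag enabled, each target sentence is preceded
    # by its language tag, e.g. "<lang:de> hallo welt".
    return f"<lang:{tgt_lang}> {tgt_text}"

if __name__ == "__main__":
    print(build_lang_tags(["en-de", "en-fr"]))   # ['<lang:de>', '<lang:fr>']
    print(prepend_tag("hallo welt", "de"))       # '<lang:de> hallo welt'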
fairseq/criterions/ctc.py

@@ -125,6 +125,18 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0,
         metadata={"help": "consistent regularization for inter CTC loss in mixup"},
     )
+    xctc_mixup_consistent_weight: float = field(
+        default=0,
+        metadata={"help": "consistent regularization for XCTC loss in mixup"},
+    )
+    inter_xctc_mixup_consistent_weight: float = field(
+        default=0,
+        metadata={"help": "consistent regularization for Inter XCTC loss in mixup"},
+    )
+    ctc_mixup_consistent_hard_target: bool = field(
+        default=False,
+        metadata={"help": "use hard distribution during mixup consistent learning"},
+    )
     wer_kenlm_model: Optional[str] = field(
         default=None,

@@ -156,7 +168,11 @@ class CtcCriterionConfig(FairseqDataclass):
 @register_criterion("ctc", dataclass=CtcCriterionConfig)
 class CtcCriterion(FairseqCriterion):
     def __init__(
-        self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0, save_dir=None
+        self,
+        cfg: CtcCriterionConfig,
+        task: FairseqTask,
+        ctc_weight=1.0,
+        save_dir=None,
+        mixup_no_hard_loss=False,
     ):
         super().__init__(task)

@@ -224,6 +240,10 @@ class CtcCriterion(FairseqCriterion):
         self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
         self.inter_ctc_mixup_consistent_weight = cfg.inter_ctc_mixup_consistent_weight
+        self.xctc_mixup_consistent_weight = cfg.xctc_mixup_consistent_weight
+        self.inter_xctc_mixup_consistent_weight = cfg.inter_xctc_mixup_consistent_weight
+        self.mixup_no_hard_loss = mixup_no_hard_loss
+        self.ctc_mixup_consistent_hard_target = cfg.ctc_mixup_consistent_hard_target
         self.all_ctc_weight = (
             self.ctc_weight

@@ -441,6 +461,7 @@ class CtcCriterion(FairseqCriterion):
         target_lengths,
         loss_coef,
         force_emit=None,
+        loss_mask_flag=None
     ):
         lprobs = model.get_normalized_probs(
             [ctc_logit], log_probs=True

@@ -470,6 +491,9 @@ class CtcCriterion(FairseqCriterion):
                 input_lengths,
                 item_target_lengths,
             )
+            if loss_mask_flag is not None:
+                item_loss = item_loss * loss_mask_flag
             loss += (item_loss * item_coef).sum()
         return loss, lprobs

@@ -518,6 +542,7 @@ class CtcCriterion(FairseqCriterion):
             target_tokens != self.eos_idx
         )
+        mixup_flag = None
         if "mixup" in net_output and net_output["mixup"] is not None:
             mixup_coef = net_output["mixup"]["coef"]
             mixup_idx1 = net_output["mixup"]["index1"]

@@ -532,12 +557,14 @@ class CtcCriterion(FairseqCriterion):
             target_tokens = [target_tokens1, target_tokens2]
             target_lengths = [target_lengths1, target_lengths2]
             loss_coef = [mixup_coef, 1 - mixup_coef]
+            if self.mixup_no_hard_loss:
+                mixup_flag = ~net_output["mixup"]["mixup_flag"]
         else:
             target_tokens = [target_tokens.masked_select(target_pad_mask)]
             target_lengths = [target_pad_mask.sum(-1)]
             loss_coef = [1]

-        return target_tokens, target_lengths, loss_coef
+        return target_tokens, target_lengths, loss_coef, mixup_flag

     def compute_ctc_loss(self, model, sample, net_output, logging_output):
         if "transcript" in sample:

@@ -557,7 +584,7 @@ class CtcCriterion(FairseqCriterion):
         nfeatures = input_lengths.sum().item()
         logging_output["nfeatures"] = nfeatures

-        transcripts, transcript_lengths, loss_coef = self.get_targets_for_ctc_loss(tokens, net_output)
+        transcripts, transcript_lengths, loss_coef, mixup_flag = self.get_targets_for_ctc_loss(tokens, net_output)

         all_ctc_logits = dict()
         self.ctc_names = []

@@ -570,17 +597,17 @@ class CtcCriterion(FairseqCriterion):
         if "inter_ctc_logits" in net_output:
             inter_ctc_num = len(net_output["inter_ctc_logits"])

-        # calculate the inter CTC loss
+        # calculate the Inter CTC loss
         if self.inter_ctc_weight > 0 and inter_ctc_num > 0:
             logits = net_output["inter_ctc_logits"]
             for i in range(inter_ctc_num):
-                inter_transcripts, inter_transcript_lengths, inter_loss_coef = transcripts, transcript_lengths, loss_coef
+                inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = transcripts, transcript_lengths, loss_coef, mixup_flag
                 if self.inter_ctc_mlo is not None:
                     order = self.inter_ctc_mlo[i]
                     tokens_key = "transcript%s" % order
                     if sample.get(tokens_key, None):
                         inter_tokens = sample[tokens_key]["tokens"]
-                        inter_transcripts, inter_transcript_lengths, inter_loss_coef = self.get_targets_for_ctc_loss(inter_tokens, net_output)
+                        inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = self.get_targets_for_ctc_loss(inter_tokens, net_output)
                 logit = logits[i]
                 force_emit = None

@@ -625,6 +652,7 @@ class CtcCriterion(FairseqCriterion):
                     inter_transcript_lengths,
                     inter_loss_coef,
                     force_emit,
+                    inter_mixup_flag
                 )
                 inter_ctc_loss += inter_loss
                 lprobs = inter_lprobs

@@ -641,7 +669,6 @@ class CtcCriterion(FairseqCriterion):
         ):
             use_ctc = True
             logit = net_output["ctc_logit"][0]
-            # all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
             force_emit = None
             if type(logit) == list:

@@ -664,6 +691,7 @@ class CtcCriterion(FairseqCriterion):
                 transcript_lengths,
                 loss_coef,
                 force_emit,
+                mixup_flag
             )
             if self.ctc_entropy_weight > 0:

@@ -687,7 +715,7 @@ class CtcCriterion(FairseqCriterion):
         if self.use_axctc:
             aligned_target_tokens = self.get_aligned_target_text(sample)
-            target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss(
+            target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
                 aligned_target_tokens, net_output
             )

@@ -711,7 +739,6 @@ class CtcCriterion(FairseqCriterion):
                     inter_axctc_logit = logit
                     inter_input_lengths = input_lengths
-                # all_ctc_logits["inter_axctc_logit%d" % i] = [inter_axctc_logit, inter_input_lengths]
                 inter_loss, target_inter_lprobs = self.get_ctc_loss(
                     model,
                     inter_axctc_logit,

@@ -720,6 +747,7 @@ class CtcCriterion(FairseqCriterion):
                     target_lengths,
                     loss_coef,
                     force_emit,
+                    target_mixup_flag
                 )
                 inter_axctc_loss += inter_loss
                 target_lprobs = target_inter_lprobs

@@ -730,7 +758,6 @@ class CtcCriterion(FairseqCriterion):
         if self.axctc_weight > 0:
             assert "axctc_logit" in net_output
             logit = net_output["axctc_logit"][0]
-            # all_ctc_logits["axctc_logit"] = [axctc_logit, input_lengths]
             force_emit = None
             if type(logit) == list:

@@ -753,6 +780,7 @@ class CtcCriterion(FairseqCriterion):
                 target_lengths,
                 loss_coef,
                 force_emit,
+                target_mixup_flag
             )
             logging_output["axctc_loss"] = utils.item(axctc_loss.data)

@@ -762,7 +790,7 @@ class CtcCriterion(FairseqCriterion):
         if self.use_xctc:
             ctc_target_tokens = self.get_ctc_target_text(sample)
-            target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss(
+            target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
                 ctc_target_tokens, net_output
             )

@@ -787,7 +815,6 @@ class CtcCriterion(FairseqCriterion):
                     inter_xctc_logit = logit
                     inter_input_lengths = input_lengths
-                # all_ctc_logits["inter_xctc_logit%d" % i] = [inter_xctc_logit, inter_input_lengths]
                 inter_loss, target_inter_lprobs = self.get_ctc_loss(
                     model,
                     inter_xctc_logit,

@@ -796,6 +823,7 @@ class CtcCriterion(FairseqCriterion):
                     target_lengths,
                     loss_coef,
                     force_emit,
+                    target_mixup_flag
                 )
                 inter_xctc_loss += inter_loss
                 target_lprobs = target_inter_lprobs

@@ -819,7 +847,6 @@ class CtcCriterion(FairseqCriterion):
                 force_emit = logit[2]
             else:
                 xctc_logit = logit
-            # all_ctc_logits["xctc_logit"] = [xctc_logit, input_lengths]
             xctc_loss, target_lprobs = self.get_ctc_loss(
                 model,

@@ -829,6 +856,7 @@ class CtcCriterion(FairseqCriterion):
                 target_lengths,
                 loss_coef,
                 force_emit,
+                target_mixup_flag
             )
             logging_output["xctc_loss"] = utils.item(xctc_loss.data)

@@ -928,8 +956,46 @@ class CtcCriterion(FairseqCriterion):
                     xctc_self_distill_loss * self.xctc_self_distill_weight
                 )

+        # calculate KD loss for interpolation augmentation
+        def get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2):
+            mixup_consistent_loss = 0
+            mixup_real_logit = ctc_logit[:, mixup_pos, :]
+            no_mixup_logit = ctc_logit[:, ~mixup_pos, :]
+            mixup_target_logit = [
+                no_mixup_logit[:, mixup_real_idx1, :],
+                no_mixup_logit[:, mixup_real_idx2, :],
+            ]
+            mixup_target_pad_mask = [
+                non_padding_mask[mixup_real_idx1],
+                non_padding_mask[mixup_real_idx2],
+            ]
+            for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
+                if self.ctc_mixup_consistent_hard_target:
+                    loss = F.kl_div(
+                        F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
+                        utils.distribution_soft_to_hard(logit.detach()).to(torch.float32),
+                        log_target=False,
+                        reduction="none",
+                    )
+                else:
+                    loss = F.kl_div(
+                        F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
+                        F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
+                        log_target=True,
+                        reduction="none",
+                    )
+                mixup_consistent_loss += (
+                    loss.sum(-1).transpose(0, 1).masked_fill(~pad, 0.0).sum(-1) * coef
+                ).sum()
+            return mixup_consistent_loss
+
         ctc_mixup_consistent_loss = 0
         inter_ctc_mixup_consistent_loss = 0
+        xctc_mixup_consistent_loss = 0
+        inter_xctc_mixup_consistent_loss = 0
         if use_ctc and mixup is True:
             mixup_coef = net_output["mixup"]["coef"]
             mixup_idx1 = net_output["mixup"]["index1"]

@@ -941,41 +1007,18 @@ class CtcCriterion(FairseqCriterion):
             mixup_real_idx1 = mixup_idx1[mixup_pos]
             mixup_real_idx2 = mixup_idx2[mixup_pos]

-            def get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask):
-                mixup_consistent_loss = 0
-                mixup_real_logit = ctc_logit[:, mixup_pos, :]
-                no_mixup_logit = ctc_logit[:, ~mixup_pos, :]
-                mixup_target_logit = [
-                    no_mixup_logit[:, mixup_real_idx1, :],
-                    no_mixup_logit[:, mixup_real_idx2, :],
-                ]
-                mixup_target_pad_mask = [
-                    non_padding_mask[mixup_real_idx1],
-                    non_padding_mask[mixup_real_idx2],
-                ]
-                for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
-                    loss = F.kl_div(
-                        F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
-                        # F.log_softmax(logit, dim=-1, dtype=torch.float32),
-                        F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
-                        log_target=True,
-                        reduction="none",
-                    )
-                    mixup_consistent_loss += (
-                        loss.sum(-1).transpose(0, 1).masked_fill(~pad, 0.0).sum(-1) * coef
-                    ).sum()
-                return mixup_consistent_loss
-
             if self.ctc_mixup_consistent_weight > 0:
                 ctc_logit = net_output["ctc_logit"][0]
-                ctc_mixup_consistent_loss = get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask)
+                ctc_mixup_consistent_loss = get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
                 logging_output["ctc_mixup_consistent_loss"] = utils.item(
                     ctc_mixup_consistent_loss.data
                 )
+            if self.xctc_mixup_consistent_weight > 0:
+                xctc_logit = net_output["xctc_logit"][0]
+                xctc_mixup_consistent_loss = get_mixup_consistent_loss(xctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
+                logging_output["xctc_mixup_consistent_loss"] = utils.item(xctc_mixup_consistent_loss.data)

             if self.inter_ctc_mixup_consistent_weight > 0:
                 if inter_ctc_num > 0:

@@ -989,12 +1032,40 @@ class CtcCriterion(FairseqCriterion):
                             inter_ctc_logit = logit
                             inter_non_padding_mask = non_padding_mask
-                        inter_ctc_mixup_consistent_loss += get_ctc_mixup_consistent_loss(inter_ctc_logit, inter_non_padding_mask)
+                        inter_ctc_mixup_consistent_loss += get_mixup_consistent_loss(inter_ctc_logit, inter_non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
                         logging_output["inter_ctc_mixup_consistent_loss"] = utils.item(
                             inter_ctc_mixup_consistent_loss.data
                         )
+            if self.inter_xctc_mixup_consistent_weight > 0:
+                if inter_xctc_num > 0:
+                    logits = net_output["inter_xctc_logits"]
+                    for i in range(inter_xctc_num):
+                        logit = logits[i]
+                        if type(logit) == list:
+                            inter_xctc_logit = logit[0]
+                            inter_non_padding_mask = ~logit[1] if logit[1] is not None else non_padding_mask
+                        else:
+                            inter_xctc_logit = logit
+                            inter_non_padding_mask = non_padding_mask
+                        inter_xctc_mixup_consistent_loss += get_mixup_consistent_loss(inter_xctc_logit, inter_non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
+                        logging_output["inter_xctc_mixup_consistent_loss"] = utils.item(inter_xctc_mixup_consistent_loss.data)

         if len(ctc_entropy) != 0:
             ctc_entropy = sum(ctc_entropy) / len(ctc_entropy)
             logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)

@@ -1012,6 +1083,8 @@ class CtcCriterion(FairseqCriterion):
             + self.ctc_entropy_weight * ctc_entropy
             + self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
             + self.inter_ctc_mixup_consistent_weight * inter_ctc_mixup_consistent_loss
+            + self.xctc_mixup_consistent_weight * xctc_mixup_consistent_loss
+            + self.inter_xctc_mixup_consistent_weight * inter_xctc_mixup_consistent_loss
         )
         if loss != 0:

@@ -1137,6 +1210,13 @@ class CtcCriterion(FairseqCriterion):
         inter_ctc_mixup_consistent_loss = utils.item(
             sum(log.get("inter_ctc_mixup_consistent_loss", 0) for log in logging_outputs)
         )
+        xctc_mixup_consistent_loss = utils.item(
+            sum(log.get("xctc_mixup_consistent_loss", 0) for log in logging_outputs)
+        )
+        inter_xctc_mixup_consistent_loss = utils.item(
+            sum(log.get("inter_xctc_mixup_consistent_loss", 0) for log in logging_outputs)
+        )
         all_ctc_loss_sum = utils.item(
             sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
         )

@@ -1245,6 +1325,20 @@ class CtcCriterion(FairseqCriterion):
                 sample_size,
                 round=3,
             )
+        if xctc_mixup_consistent_loss > 0:
+            metrics.log_scalar(
+                "xctc_mixup_consistent_loss",
+                xctc_mixup_consistent_loss / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
+        if inter_xctc_mixup_consistent_loss > 0:
+            metrics.log_scalar(
+                "inter_xctc_mixup_consistent_loss",
+                inter_xctc_mixup_consistent_loss / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
         metrics.log_scalar("ntokens", ntokens)
         metrics.log_scalar("nsentences", nsentences)
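Note: the additions above wire a mixup consistency term into the CTC criterion: frames produced by interpolating two utterances are pushed (via KL divergence) toward the CTC distributions of the two original utterances, weighted by the mixing coefficient, and optionally against a hard (argmax) target when ctc_mixup_consistent_hard_target is set. The sketch below reproduces that idea in isolation with plain PyTorch; tensor names, shapes, and the scalar coefficient are assumptions chosen for clarity, not the exact fairseq interfaces.

import torch
import torch.nn.functional as F

def soft_to_hard(logits):
    # One-hot distribution at the argmax class, mirroring the
    # utils.distribution_soft_to_hard helper added in fairseq/utils.py.
    idx = logits.argmax(dim=-1, keepdim=True)
    return (idx == torch.arange(logits.size(-1), device=logits.device)).to(logits.dtype)

def mixup_consistent_loss(mixed_logit, orig_logit1, orig_logit2,
                          non_padding1, non_padding2, coef, hard_target=False):
    """KL(mixed || original), summed over classes and unpadded frames.

    mixed_logit:   (T, B, C) CTC logits of the interpolated utterances
    orig_logit1/2: (T, B, C) logits of the two source utterances (detached below)
    non_padding1/2:(B, T) True where frames are real (not padding)
    coef:          scalar mixing weight lambda; the second term uses 1 - lambda
    """
    loss = 0.0
    log_p_mix = F.log_softmax(mixed_logit, dim=-1, dtype=torch.float32)
    for target_logit, pad, c in zip((orig_logit1, orig_logit2),
                                    (non_padding1, non_padding2),
                                    (coef, 1.0 - coef)):
        if hard_target:
            target = soft_to_hard(target_logit.detach()).to(torch.float32)
            kl = F.kl_div(log_p_mix, target, log_target=False, reduction="none")
        else:
            target = F.log_softmax(target_logit.detach(), dim=-1, dtype=torch.float32)
            kl = F.kl_div(log_p_mix, target, log_target=True, reduction="none")
        # sum over classes -> (T, B); zero padded frames; sum over time; weight; sum over batch
        loss = loss + (kl.sum(-1).transpose(0, 1).masked_fill(~pad, 0.0).sum(-1) * c).sum()
    return loss

if __name__ == "__main__":
    T, B, C = 6, 2, 5
    mixed = torch.randn(T, B, C)
    o1, o2 = torch.randn(T, B, C), torch.randn(T, B, C)
    pad = torch.ones(B, T, dtype=torch.bool)
    print(mixup_consistent_loss(mixed, o1, o2, pad, pad, coef=0.3))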
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "the weight for consistency regularization of mixup"},
     )
-    cal_mixup_loss: bool = field(
+    mixup_no_hard_loss: bool = field(
         default=False,
         metadata={"help": "calculate the loss for the mixed samples"},
     )

@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         label_smoothing,
         ignore_prefix_size=0,
         report_accuracy=False,
-        cal_mixup_loss=True,
+        mixup_no_hard_loss=False,
         mixup_consistent_weight=0.0,
     ):
         super().__init__(task)

@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.eps = float(label_smoothing)
         self.ignore_prefix_size = ignore_prefix_size
         self.report_accuracy = report_accuracy
-        self.cal_mixup_loss = cal_mixup_loss
+        self.mixup_no_hard_loss = mixup_no_hard_loss
         self.mixup_consistent_weight = mixup_consistent_weight

     def forward(self, model, sample, reduce=True):

@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
             loss_coef = [mixup_coef, 1 - mixup_coef]
-            if self.cal_mixup_loss:
+            if not self.mixup_no_hard_loss:
                 for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
                     batch_size = item_target.size(0)
                     item_loss, item_nll_loss = label_smoothed_nll_loss(
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py

@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         cfg: CtcCriterionConfig,
         ctc_weight=0.0,
         save_dir=None,
-        cal_mixup_loss=True,
+        mixup_no_hard_loss=False,
         mixup_consistent_weight=0.0,
         only_train_enc_prob=0.0,
         get_oracle_when_only_train_enc=False
     ):
         super().__init__(task, sentence_avg, label_smoothing,
                          report_accuracy=True,
-                         cal_mixup_loss=cal_mixup_loss,
+                         mixup_no_hard_loss=mixup_no_hard_loss,
                          mixup_consistent_weight=mixup_consistent_weight)
         self.report_accuracy = True
         self.ctc_weight = ctc_weight
-        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
+        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
         self.save_dir = save_dir
         self.only_train_enc_prob = only_train_enc_prob
fairseq/data/audio/speech_to_text_dataset.py

@@ -358,7 +358,7 @@ class SpeechToTextDataset(FairseqDataset):
     def check_tgt_lang_tag(self):
         if self.data_cfg.prepend_tgt_lang_tag:
             assert self.tgt_langs is not None and self.tgt_dict is not None
             tgt_lang_tags = [
                 self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
             ]
             assert all(t in self.tgt_dict for t in tgt_lang_tags)
fairseq/dataclass/configs.py

@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
             "help": "training steps in each epoch"
         }
     )
+    sharded_data_load: bool = field(
+        default=False,
+        metadata={"help": "Use sharded data for efficient data load"},
+    )

 @dataclass

@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight for ctc probs for lm fusion"},
     )
+    early_exit_count: int = field(
+        default=0,
+        metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
+    )
+    early_exit_layer: int = field(
+        default=0,
+        metadata={"help": "early exit during decoding at layer n"},
+    )
     # arguments for iterative refinement generator
     iter_decode_eos_penalty: float = field(
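Note: fairseq exposes dataclass fields like these as command-line options (underscores become dashes, metadata["help"] becomes the option help text), which is how run.sh above can pass "--early-exit-count 6" through infer_parameter. The sketch below imitates that mapping with plain argparse rather than fairseq's own config machinery; the GenerationOpts class and add_dataclass_args helper are illustrative only.

import argparse
from dataclasses import dataclass, field, fields

@dataclass
class GenerationOpts:
    # Mirrors the two new GenerationConfig fields from the diff above.
    early_exit_count: int = field(
        default=0,
        metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
    )
    early_exit_layer: int = field(
        default=0,
        metadata={"help": "early exit during decoding at layer n"},
    )

def add_dataclass_args(parser, dc):
    # Turn each field into a "--with-dashes" option carrying its default and help text.
    for f in fields(dc):
        flag = "--" + f.name.replace("_", "-")
        parser.add_argument(flag, type=f.type, default=f.default, help=f.metadata.get("help", ""))

parser = argparse.ArgumentParser()
add_dataclass_args(parser, GenerationOpts)
args = parser.parse_args(["--early-exit-count", "6"])
print(args.early_exit_count)  # 6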
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
     ):
         if hasattr(self, "ctc"):
+            import os
             assert src_dict is not None
             self.ctc.set_infer(
                 ctc_infer,
                 post_process,
                 src_dict,
-                path=path + ".ctc" if path is not None else None,
+                path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
             )

     def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
fairseq/models/speech_to_text/s2t_ctc.py

@@ -250,7 +250,6 @@ class CTCDecoder(object):
             logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
             print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-
             from torchprofile import profile_macs
             macs = profile_macs(self.model, [src_tokens, src_lengths])
             gmacs = macs / 1e9

@@ -269,20 +268,22 @@ class CTCDecoder(object):
         inter_logits = encoder_outs.get("inter_xctc_logits", [])
         if ctc_logit is None:
             ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
-        if len(inter_logits) > 0:
+        if len(inter_logits) == 0:
             inter_logits = encoder_outs.get("inter_ctc_logits", [])
         inter_logits_num = len(inter_logits)

         encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]

         if self.ctc_inter_logit != 0:
-            assert inter_logits_num >= self.ctc_inter_logit
             if inter_logits_num != 0:
+                assert self.ctc_inter_logit <= inter_logits_num
                 ctc_logit_item = inter_logits[-self.ctc_inter_logit]
                 if isinstance(ctc_logit_item, list):
                     ctc_logit = ctc_logit_item[0].transpose(0, 1)
                     if len(ctc_logit_item) >= 2:
                         encoder_padding_mask = ctc_logit_item[1]
+                else:
+                    ctc_logit = ctc_logit_item.transpose(0, 1)

         logit_length = (~encoder_padding_mask).long().sum(-1)
         finalized = []

@@ -318,7 +319,7 @@ class CTCDecoder(object):
                 else:
                     logit = inter_logits[i]
-                inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1)
+                inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
                 ctc_probs += inter_logits_prob

         topk_prob, topk_index = ctc_probs.topk(1, dim=2)
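Note: the decoder changes above make the intermediate-logit handling explicit: when no XCTC intermediate logits exist the decoder falls back to the regular inter-CTC logits, and each intermediate layer's log-probabilities are added to the final layer's before taking the per-frame argmax. A small, self-contained sketch of that log-probability ensembling; shapes and names are assumptions, not the decoder's real interface.

import torch
import torch.nn.functional as F

def ensemble_ctc_log_probs(final_logit, inter_logits):
    """Sum log-softmax over the final CTC logit and any intermediate ones.

    final_logit:  (B, T, C)
    inter_logits: list of (B, T, C) tensors from intermediate layers (may be empty)
    Returns the per-frame argmax indices after ensembling, shape (B, T).
    """
    ctc_probs = F.log_softmax(final_logit, dim=-1)
    for logit in inter_logits:
        ctc_probs = ctc_probs + F.log_softmax(logit, dim=-1)
    topk_prob, topk_index = ctc_probs.topk(1, dim=2)
    return topk_index.squeeze(2)

if __name__ == "__main__":
    B, T, C = 2, 7, 10
    final = torch.randn(B, T, C)
    inters = [torch.randn(B, T, C) for _ in range(2)]
    print(ensemble_ctc_log_probs(final, inters).shape)  # torch.Size([2, 7])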
fairseq/models/speech_to_text/s2t_transformer.py

@@ -888,6 +888,8 @@ class S2TTransformerEncoder(FairseqEncoder):
         super().__init__(None)

         dim = args.encoder_embed_dim
+        self.source_dictionary = task.source_dictionary
+        self.target_dictionary = task.target_dictionary
         layer_num = args.encoder_layers
         self.dropout_module = FairseqDropout(
             p=args.dropout, module_name=self.__class__.__name__

@@ -1027,6 +1029,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                     )
                 ),
                 dropout=args.dropout,
+                dictionary=task.source_dictionary
             )
             setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)
             # inter_layer_norm = LayerNorm(dim)

@@ -1038,6 +1041,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dim,
                 dictionary_size=len(task.source_dictionary),
                 dropout=args.dropout,
+                dictionary=task.source_dictionary,
             )
             if (
                 getattr(args, "share_ctc_and_embed", False)

@@ -1116,6 +1120,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 else len(task.target_dictionary),
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_xctc else False,
+                dictionary=task.target_dictionary,
             )
             if (

@@ -1374,6 +1379,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.save_rep = False
         self.mixup_infer = False
         self.rep_dict = dict()
+        self.early_exit_count = 0
+        self.early_exit_layer_record = []
+        self.early_exit_layer = 0

     @staticmethod
     def build_encoder_layer(args):

@@ -1400,6 +1409,10 @@ class S2TTransformerEncoder(FairseqEncoder):
                     layer, "dump"
                 ) else None
+            print("Early exit layer.", file=fstream)
+            if self.early_exit_count != 0:
+                print("\n".join([str(l) for l in self.early_exit_layer_record]), file=fstream)
             if self.gather_cos_sim:
                 print(
                     "\nCosine similarity of distance %d" % self.gather_cos_sim_dis,

@@ -1540,13 +1553,17 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.mixup_infer = kwargs.get("mixup_infer", False)
         self.gather_cos_sim = kwargs.get("gather_cos_sim", False)
         self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2)
+        self.early_exit_layer = kwargs.get("early_exit_layer", 0)
+        if self.early_exit_layer != 0:
+            logger.info("Using the logit in layer %d to infer." % self.early_exit_layer)
         if self.mixup_infer:
             self.mixup_keep_org = True

     def set_ctc_infer(
-        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
+        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None, early_exit_count=0
     ):
+        self.early_exit_count = early_exit_count
         if hasattr(self, "ctc"):
             assert src_dict is not None
             self.ctc.set_infer(

@@ -1711,6 +1728,27 @@ class S2TTransformerEncoder(FairseqEncoder):
                 org_x = x[:, ~flag, :].mean(0)
                 rep_dict[layer_idx].append(org_x)

+    def early_exit_or_not(self, history, new_logit, count):
+        history.append(new_logit)
+        length = len(history)
+        if count == 0 or length < count:
+            return False
+        else:
+            # for logit in history[length - count: length - 1]:
+            #     if new_logit.size() != logit.size() or not (new_logit == logit).all():
+            #         return False
+            # return True
+            hit = 0
+            for logit in history[:length - 1]:
+                if new_logit.size() == logit.size() and (new_logit == logit).all():
+                    hit += 1
+            if hit >= count:
+                return True
+            else:
+                return False
+
     def forward(self, src_tokens, src_lengths=None, **kwargs):
         layer_idx = -1

@@ -1727,6 +1765,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         # (B, T, D) -> (T, B, D)
         x = src_tokens.transpose(0, 1)
         input_lengths = src_lengths
+        org_bsz = x.size(1)

         if (
             (self.training or self.mixup_infer)

@@ -1821,6 +1860,16 @@ class S2TTransformerEncoder(FairseqEncoder):
         xctc_oracle_mask = None
         xctc_force_emit = None

+        # Infer early exit
+        batch_idx_dict = dict()
+        inter_ctc_logits_history = dict()
+        final_ctc_logits = dict()
+        final_encoder_padding_mask = dict()
+        early_exit_layer = dict()
+        for i in range(x.size(1)):
+            inter_ctc_logits_history[i] = []
+            batch_idx_dict[i] = i
+
         for layer in self.layers:
             if self.history is not None:
                 x = self.history.pop()

@@ -1879,7 +1928,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             # Inter CTC
             if layer_idx in self.inter_ctc_layers:
-                if self.inter_ctc_drop_prob > 0:
+                if self.training and self.inter_ctc_drop_prob > 0:
                     p = torch.rand(1).uniform_()
                     if p < self.inter_ctc_drop_prob:
                         break

@@ -1942,9 +1991,47 @@ class S2TTransformerEncoder(FairseqEncoder):
                         [pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
                     )
                     self.show_debug(x, "x after pae")

                 inter_ctc_logits.append(inter_logit)
+
+                if not self.training and self.early_exit_layer == layer_idx:
+                    ctc_logit = inter_logit[0]
+                    break
+
+                if not self.training and self.early_exit_count != 0:
+                    predicts = inter_ctc.predict(inter_logit[0], encoder_padding_mask)
+                    if len(inter_ctc_logits) < self.early_exit_count:
+                        for i in range(x.size(1)):
+                            inter_ctc_logits_history[i].append(predicts[i])
+                    else:
+                        if org_bsz == 1:
+                            early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
+                            if early_exit_flag:
+                                ctc_logit = inter_logit[0]
+                                self.early_exit_layer_record.append(layer_idx)
+                                break
+                        else:
+                            idx = 0
+                            keep_idx = []
+                            new_batch_idx_dict = dict()
+                            for i in range(x.size(1)):
+                                real_idx = batch_idx_dict[i]
+                                early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[real_idx], predicts[i], self.early_exit_count)
+                                if early_exit_flag:
+                                    final_ctc_logits[real_idx] = inter_logit[0][:, i, :]
+                                    final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
+                                    early_exit_layer[real_idx] = layer_idx
+                                else:
+                                    keep_idx.append(i)
+                                    new_batch_idx_dict[idx] = real_idx
+                                    idx += 1
+                            if idx == 0:
+                                break
+                            if idx < x.size(1):
+                                batch_idx_dict = new_batch_idx_dict
+                                x = x[:, keep_idx, :].contiguous()
+                                encoder_padding_mask = encoder_padding_mask[keep_idx, :].contiguous()

             if layer_idx in self.compression_layers:
                 ctc_prob = utils.softmax(logit, dim=-1)  # (T B C)
                 blank_prob = ctc_prob[:, :, 0]

@@ -2133,6 +2220,25 @@ class S2TTransformerEncoder(FairseqEncoder):
             )
             self.show_debug(x, "x after xctc")

+        if not self.training and self.early_exit_count != 0 and org_bsz != 1:
+            if layer_idx == len(self.layers) + 1:
+                for i in range(x.size(1)):
+                    real_idx = batch_idx_dict[i]
+                    final_ctc_logits[real_idx] = ctc_logit[:, i, :]
+                    final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
+                    early_exit_layer[real_idx] = layer_idx - 1
+            output_logits = []
+            output_encoder_padding_mask = []
+            output_layers = []
+            for i in range(len(final_ctc_logits)):
+                output_logits.append(final_ctc_logits[i])
+                output_encoder_padding_mask.append(final_encoder_padding_mask[i])
+                output_layers.append(early_exit_layer[i])
+            ctc_logit = torch.stack(output_logits, dim=0).transpose(0, 1)
+            encoder_padding_mask = torch.stack(output_encoder_padding_mask, dim=0)
+            self.early_exit_layer_record.extend(output_layers)
+
         if ctc_force_emit is not None:
             ctc_logit = [ctc_logit, None, ctc_force_emit]

@@ -2174,6 +2280,11 @@ class S2TTransformerEncoder(FairseqEncoder):
             if len(encoder_out["xctc_logit"]) == 0
             else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]]
         )
+        new_inter_ctc_logits = (
+            []
+            if len(encoder_out["inter_ctc_logits"]) == 0
+            else [[x[0].index_select(1, new_order)].extend(x[1:]) if isinstance(x, list) else x.index_select(1, new_order) for x in encoder_out["inter_ctc_logits"] if x is not None]
+        )
         new_encoder_padding_mask = (
             []
             if len(encoder_out["encoder_padding_mask"]) == 0

@@ -2200,6 +2311,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             "encoder_out": new_encoder_out,  # T x B x C
             "ctc_logit": new_ctc_logit,  # T x B x C
             "xctc_logit": new_xctc_logit,
+            "inter_ctc_logits": new_inter_ctc_logits,
             "encoder_padding_mask": new_encoder_padding_mask,  # B x T
             "encoder_embedding": new_encoder_embedding,  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
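Note: the core of the new inference-time early exit is early_exit_or_not: after each intermediate CTC layer, the collapsed prediction is appended to a per-utterance history, and the encoder stops once the newest prediction agrees with at least `count` earlier ones. A standalone sketch of that stopping rule follows (the per-batch bookkeeping that shrinks the batch as utterances exit is omitted):

import torch

def early_exit_or_not(history, new_pred, count):
    """Return True when new_pred agrees with at least `count` earlier predictions.

    history:  list of 1-D tensors of token ids from previous layers (mutated in place)
    new_pred: 1-D tensor of token ids from the current layer
    count:    0 disables early exit
    """
    history.append(new_pred)
    if count == 0 or len(history) < count:
        return False
    hits = 0
    for pred in history[:-1]:
        if new_pred.size() == pred.size() and bool((new_pred == pred).all()):
            hits += 1
    return hits >= count

if __name__ == "__main__":
    hist = []
    same = torch.tensor([5, 3, 3, 9])
    for layer in range(6):
        if early_exit_or_not(hist, same.clone(), count=3):
            print("exit at layer", layer)  # exits once 3 earlier layers agree
            break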
fairseq/modules/speech_to_text/ctc.py

@@ -74,6 +74,19 @@ class CTC(nn.Module):
     def argmax(self, x):
         return torch.argmax(self.ctc_projection(x), dim=-1)

+    def predict(self, logits, padding):
+        input_lengths = (~padding).sum(-1)
+        logits = logits.transpose(0, 1).float().contiguous()
+        predicts = []
+        for logit, inp_l in zip(logits, input_lengths):
+            toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
+            pred_units_arr = toks[toks != self.dictionary.bos()]
+            # pred_units_arr = logit[:inp_l].argmax(dim=-1)
+            predicts.append(pred_units_arr)
+        return predicts
+
     def infer(self, logits_or_probs, lengths, tag=None):
         for lp, inp_l in zip(
             logits_or_probs,
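Note: the new CTC.predict helper performs greedy CTC collapse per utterance: argmax over classes within the true length, merge repeated labels with unique_consecutive, then drop the blank index. The early-exit logic above compares these collapsed sequences across layers. A minimal stand-alone version; the blank index 0 is an assumption here, whereas in the module it comes from self.dictionary.bos().

import torch

def greedy_ctc_predict(logits, padding_mask, blank=0):
    """Collapse per-frame CTC logits into label sequences.

    logits:       (B, T, C) raw CTC logits
    padding_mask: (B, T) True at padded frames
    Returns a list of 1-D tensors, one collapsed sequence per utterance.
    """
    input_lengths = (~padding_mask).sum(-1)
    predicts = []
    for logit, inp_l in zip(logits, input_lengths):
        toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
        predicts.append(toks[toks != blank])
    return predicts

if __name__ == "__main__":
    B, T, C = 2, 8, 6
    logits = torch.randn(B, T, C)
    pad = torch.zeros(B, T, dtype=torch.bool)
    pad[1, 6:] = True  # second utterance is shorter
    for seq in greedy_ctc_predict(logits, pad):
        print(seq)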
fairseq/tasks/fairseq_task.py

@@ -130,6 +130,9 @@ class FairseqTask(object):
         """
         return cls(cfg, **kwargs)

+    def sharded_data_load(self):
+        return getattr(self.cfg, "sharded_data_load", False)
+
     def has_sharded_data(self, split):
         return os.pathsep in getattr(self.cfg, "data", "")

@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
         """
         return cls(args, **kwargs)

+    def sharded_data_load(self):
+        return getattr(self.args, "sharded_data_load", False)
+
     def has_sharded_data(self, split):
         return os.pathsep in getattr(self.args, "data", "")
fairseq/trainer.py

@@ -521,16 +521,25 @@ class Trainer(object):
         disable_iterator_cache=False,
     ):
         """Return an EpochBatchIterator over the training set for a given epoch."""
+        if self.task.sharded_data_load():
+            datasets = self.cfg.dataset.train_subset.split(",")
+            curr_dataset = datasets[(epoch - 1) % len(datasets)]
+            logger.info("sharded loading the training subset {}".format(curr_dataset))
+        else:
+            curr_dataset = self.cfg.dataset.train_subset
+        load_dataset = load_dataset or self.task.sharded_data_load()
+        disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
         if load_dataset:
             logger.info("loading train data for epoch {}".format(epoch))
             self.task.load_dataset(
-                self.cfg.dataset.train_subset,
+                curr_dataset,
                 epoch=epoch,
                 combine=combine,
                 data_selector=data_selector,
             )
         batch_iterator = self.task.get_batch_iterator(
-            dataset=self.task.dataset(self.cfg.dataset.train_subset),
+            dataset=self.task.dataset(curr_dataset),
             max_tokens=self.cfg.dataset.max_tokens,
             max_sentences=self.cfg.dataset.batch_size,
             max_positions=utils.resolve_max_positions(
查看文件 @
cb2f2bcb
...
@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
...
@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
freeze_module_name
=
freeze_module_name
.
split
(
","
)
freeze_module_name
=
freeze_module_name
.
split
(
","
)
for
name
in
freeze_module_name
:
for
name
in
freeze_module_name
:
freeze_module_params_by_name
(
module
,
name
)
freeze_module_params_by_name
(
module
,
name
)
def
distribution_soft_to_hard
(
logit_or_prob
):
argmax_prob
=
torch
.
argmax
(
logit_or_prob
,
dim
=-
1
,
keepdim
=
True
)
hard_distribution
=
(
(
argmax_prob
==
torch
.
arange
(
logit_or_prob
.
size
(
-
1
),
device
=
logit_or_prob
.
device
)
.
unsqueeze
(
0
)
)
.
to
(
logit_or_prob
.
dtype
)
)
return
hard_distribution
\ No newline at end of file
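Note: the new distribution_soft_to_hard turns a batch of logits (or probabilities) into one-hot rows at the argmax class, which the criterion above uses as a hard target for the mixup consistency KL term. A quick check of that behaviour against torch.nn.functional.one_hot, assuming the same argmax construction as the diff:

import torch
import torch.nn.functional as F

def distribution_soft_to_hard(logit_or_prob):
    # Same construction as the added utility: compare the argmax index
    # against an arange over the class dimension to build a one-hot row.
    argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
    hard = (argmax_prob == torch.arange(logit_or_prob.size(-1),
                                        device=logit_or_prob.device).unsqueeze(0)
            ).to(logit_or_prob.dtype)
    return hard

if __name__ == "__main__":
    x = torch.randn(4, 7)
    hard = distribution_soft_to_hard(x)
    ref = F.one_hot(x.argmax(-1), num_classes=7).to(x.dtype)
    print(torch.equal(hard, ref))  # True
    print(hard.sum(-1))            # each row sums to 1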
fairseq_cli/generate.py

@@ -108,8 +108,8 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
     for model in models:
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
             model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                        src_dict, tgt_dict, translation_path)
+                                        src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
             model.encoder.set_flag(
                 cal_localness=cfg.generation.cal_localness,
                 localness_window=cfg.generation.localness_window,

@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
                 mixup_infer=cfg.generation.mixup_infer,
                 gather_cos_sim=cfg.generation.gather_cos_sim,
                 gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
+                early_exit_layer=cfg.generation.early_exit_layer,
             )
         if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
             model.decoder.set_flag(

@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
             # Remove padding
             if "src_tokens" in sample["net_input"]:
-                src_tokens = utils.strip_pad(
-                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
-                )
+                if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
+                    src_tokens = utils.strip_pad(
+                        sample["net_input"]["src_tokens"][i, :], src_dict.pad()
+                    )
+                elif "transcript" in sample:
+                    src_tokens = utils.strip_pad(
+                        sample["transcript"]["tokens"][i, :], src_dict.pad()
+                    )
             else:
                 src_tokens = None