xuchen / S2T / Commits / cb2f2bcb

Commit cb2f2bcb, authored Nov 28, 2023 by xuchen
2023.11
Parent: 51395037
Showing 15 changed files with 308 additions and 27 deletions.
egs/librispeech/asr/run.sh                                    +7    -5
examples/speech_to_text/prep_audio_data.py                    +104  -4
fairseq/criterions/ctc.py                                     +0    -0
fairseq/criterions/label_smoothed_cross_entropy.py            +4    -4
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py   +3    -3
fairseq/data/audio/speech_to_text_dataset.py                  +0    -0
fairseq/dataclass/configs.py                                  +14   -0
fairseq/models/speech_to_text/pdss2t_transformer.py           +2    -1
fairseq/models/speech_to_text/s2t_ctc.py                      +5    -4
fairseq/models/speech_to_text/s2t_transformer.py              +114  -2
fairseq/modules/speech_to_text/ctc.py                         +13   -0
fairseq/tasks/fairseq_task.py                                 +6    -0
fairseq/trainer.py                                            +11   -2
fairseq/utils.py                                              +16   -0
fairseq_cli/generate.py                                       +9    -2
egs/librispeech/asr/run.sh

@@ -83,10 +83,12 @@ epoch_ensemble=0
 best_ensemble=1
 infer_debug=0
 infer_score=0
+infer_tag=
+infer_parameter=
 infer_tag=ee6
-infer_parameters="--early-exit-count 6"
-#infer_parameters="--early-exit-layer 12"
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+infer_parameter="--early-exit-count 6"
+#infer_parameter="--early-exit-layer 12"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
 data_config=config.yaml

@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         cmd="${cmd}
        --score-reference"
     fi
-    if [[ -n ${infer_parameters} ]]; then
+    if [[ -n ${infer_parameter} ]]; then
         cmd="${cmd}
-        ${infer_parameters}"
+        ${infer_parameter}"
     fi
     echo -e "\033[34mRun command: \n${cmd}\033[0m"
examples/speech_to_text/prep_audio_data.py

@@ -37,7 +37,7 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)

-MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]

 class AudioDataset(Dataset):

@@ -398,6 +398,7 @@ def process(args):
             if args.add_src and src_utt is not None:
                 manifest["src_text"].append(src_utt)
             manifest["tgt_text"].append(tgt_utt)
+            manifest["tgt_lang"].append(tgt_lang)
         if is_train_split:
             if args.task == "st" and args.add_src and args.share:

@@ -454,8 +455,8 @@ def process(args):
         # if task == "st" and args.add_src and args.share:
         if args.add_src and args.share:
             for e in reader:
-                src_utt = dict(e)["src_text"]
-                tgt_utt = dict(e)["tgt_text"]
-                if args.lowercase_src:
-                    src_utt = src_utt.lower()
-                if args.rm_punc_src:
+                if "src_text" in dict(e):
+                    src_utt = dict(e)["src_text"]
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:

@@ -463,6 +464,8 @@ def process(args):
-                        src_utt = src_utt.replace(w, "")
-                    src_utt = " ".join(src_utt.split(" "))
-                train_text.append(src_utt)
+                            src_utt = src_utt.replace(w, "")
+                        src_utt = " ".join(src_utt.split(" "))
+                    train_text.append(src_utt)
+                tgt_utt = dict(e)["tgt_text"]
                 train_text.append(tgt_utt)
         else:
             tgt_text = [(dict(e))["tgt_text"] for e in reader]

@@ -471,11 +474,16 @@ def process(args):
        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
+           special_symbols = None
+           if args.add_syms:
+               special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
            gen_vocab(
                Path(f.name),
                output_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
+               special_symbols=special_symbols
            )
        # Generate config YAML

@@ -491,9 +499,94 @@ def process(args):
        cmvn_type=args.cmvn_type,
        gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
        asr_spm_filename=asr_spm_filename,
-       share_src_and_tgt=True if task == "asr" else False,
+       share_src_and_tgt=True if task == "asr" and not args.add_src else False,
+       prepend_tgt_lang_tag=(args.add_syms),
    )
+
+
+def process_joint(args):
+    cur_root = Path(args.data_root).absolute()
+    task = args.task
+    languages = args.languages.split(",")
+    assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
+        "do not have downloaded data available for all languages"
+
+    if args.output_root is None:
+        output_root = cur_root
+    else:
+        output_root = Path(args.output_root).absolute()
+
+    # Generate vocab
+    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    asr_spm_filename = None
+
+    if args.add_src:
+        if args.share:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
+            asr_spm_filename = spm_filename_prefix + ".model"
+        else:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            assert args.asr_prefix is not None
+            asr_spm_filename = args.asr_prefix + ".model"
+    elif task == "asr":
+        if args.asr_prefix is not None:
+            spm_filename_prefix = args.asr_prefix
+
+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
+    with NamedTemporaryFile(mode="w") as f:
+        for lang in languages:
+            tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+            df = load_df_from_tsv(tsv_path)
+            for t in df["tgt_text"]:
+                f.write(t + "\n")
+            if args.add_src:
+                for src_utt in df["src_text"]:
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in punctuation_str:
+                            src_utt = src_utt.replace(w, "")
+                        src_utt = " ".join(src_utt.split(" "))
+                    f.write(src_utt + "\n")
+
+        special_symbols = None
+        if args.task == 'st':
+            special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
+
+        gen_vocab(
+            Path(f.name),
+            output_root / spm_filename_prefix,
+            args.vocab_type,
+            args.vocab_size,
+            special_symbols=special_symbols
+        )
+
+    # Generate config YAML
+    yaml_filename = f"config.yaml"
+    if task == "st" and args.add_src and args.share:
+        yaml_filename = f"config_share.yaml"
+
+    gen_config_yaml(
+        output_root,
+        spm_filename_prefix + ".model",
+        yaml_filename=yaml_filename,
+        specaugment_policy="ld2",
+        asr_spm_filename=asr_spm_filename,
+        share_src_and_tgt=True if task == "asr" else False,
+        prepend_tgt_lang_tag=(args.task == "st"),
+    )
+
+    # Make symbolic links to manifests
+    for lang in languages:
+        for split in args.splits.split(","):
+            src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
+            desc_path = output_root / f"{split}_{lang}.tsv"
+            if not desc_path.is_symlink():
+                shutil.copy(src_path, desc_path)


 def main():
     parser = argparse.ArgumentParser()

@@ -501,8 +594,12 @@ def main():
     parser.add_argument("--data-root", "-d", required=True, type=str)
     parser.add_argument("--output-root", "-o", default=None, type=str)
     parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
-    parser.add_argument("--src-lang", type=str, required=True, help="source language")
+    parser.add_argument("--joint", action="store_true", help="")
+    parser.add_argument("--add-syms", action="store_true", help="")
+    parser.add_argument("--src-lang", type=str, help="source language")
     parser.add_argument("--tgt-lang", type=str, help="target language")
+    parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
+    parser.add_argument("--languages", type=str, help="languages for multilingual training")
     parser.add_argument(
         "--splits", type=str, default="train,dev,test", help="dataset splits"
     )

@@ -569,6 +666,9 @@ def main():
     args = parser.parse_args()

-    process(args)
+    if args.joint:
+        process_joint(args)
+    else:
+        process(args)
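Note: the new tgt_lang manifest column and the --joint / --add-syms / --tgt-langs options extend the data preparation toward multilingual training. A minimal sketch of what this implies, with made-up ids, paths, and language codes (nothing here is taken from the repository's actual data):

    # Illustrative manifest row under the extended MANIFEST_COLUMNS
    # ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]; all values are hypothetical.
    row = {
        "id": "ted_00001_0001",
        "audio": "fbank80.zip:1024:5120",
        "n_frames": 563,
        "tgt_text": "danke schön",
        "tgt_lang": "de",
    }

    # With --add-syms, one <lang:xx> token per target language is reserved in the
    # SentencePiece vocabulary, mirroring the special_symbols list added above.
    tgt_langs = "de,fr,es"
    special_symbols = [f"<lang:{lang}>" for lang in tgt_langs.split(",")]
    print(special_symbols)  # ['<lang:de>', '<lang:fr>', '<lang:es>']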
fairseq/criterions/ctc.py

Diff collapsed; not expanded in this view.
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
        default=0.0,
        metadata={"help": "the weight for consistency regularization of mixup"},
    )
-   cal_mixup_loss: bool = field(
+   mixup_no_hard_loss: bool = field(
        default=False,
        metadata={"help": "calculate the loss for the mixed samples"},
    )

@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
-       cal_mixup_loss=True,
+       mixup_no_hard_loss=False,
        mixup_consistent_weight=0.0,
    ):
        super().__init__(task)

@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
        self.eps = float(label_smoothing)
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy
-       self.cal_mixup_loss = cal_mixup_loss
+       self.mixup_no_hard_loss = mixup_no_hard_loss
        self.mixup_consistent_weight = mixup_consistent_weight

    def forward(self, model, sample, reduce=True):

@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
            mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
            loss_coef = [mixup_coef, 1 - mixup_coef]

-           if self.cal_mixup_loss:
+           if not self.mixup_no_hard_loss:
                for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
                    batch_size = item_target.size(0)
                    item_loss, item_nll_loss = label_smoothed_nll_loss(
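Note: renaming cal_mixup_loss to mixup_no_hard_loss flips the flag's sense — the hard (label-smoothed) loss on the mixed samples is now computed unless the new flag is set. A minimal sketch of the coefficient-weighted loss this block computes, with illustrative shapes and a plain NLL standing in for label_smoothed_nll_loss:

    import torch

    def mixed_hard_loss(lprobs_pair, target_pair, coef, mixup_no_hard_loss=False):
        # Weight the losses of the two original targets by the mixing coefficient
        # (loss_coef = [coef, 1 - coef] in the diff above); skip entirely when disabled.
        if mixup_no_hard_loss:
            return torch.zeros(())
        loss = torch.zeros(())
        for lprobs, target, c in zip(lprobs_pair, target_pair, [coef, 1 - coef]):
            nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
            loss = loss + (c * nll).sum()
        return loss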
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py

@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
        cfg: CtcCriterionConfig,
        ctc_weight=0.0,
        save_dir=None,
-       cal_mixup_loss=True,
+       mixup_no_hard_loss=False,
        mixup_consistent_weight=0.0,
        only_train_enc_prob=0.0,
        get_oracle_when_only_train_enc=False
    ):
        super().__init__(task, sentence_avg, label_smoothing,
                         report_accuracy=True,
-                        cal_mixup_loss=cal_mixup_loss,
+                        mixup_no_hard_loss=mixup_no_hard_loss,
                         mixup_consistent_weight=mixup_consistent_weight)
        self.report_accuracy = True
        self.ctc_weight = ctc_weight
-       self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
+       self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
        self.save_dir = save_dir
        self.only_train_enc_prob = only_train_enc_prob
fairseq/data/audio/speech_to_text_dataset.py

No textual changes shown (+0, -0).
fairseq/dataclass/configs.py

@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
            "help": "training steps in each epoch"
        }
    )
+   sharded_data_load: bool = field(
+       default=False,
+       metadata={"help": "Use sharded data for efficient data load"},
+   )


 @dataclass

@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
        default=0.0,
        metadata={"help": "weight for ctc probs for lm fusion"},
    )
+   early_exit_count: int = field(
+       default=0,
+       metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
+   )
+   early_exit_layer: int = field(
+       default=0,
+       metadata={"help": "early exit during decoding at layer n"},
+   )
    # arguments for iterative refinement generator
    iter_decode_eos_penalty: float = field(
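Note: through fairseq's dataclass-to-CLI mapping these fields surface as --sharded-data-load, --early-exit-count, and --early-exit-layer (run.sh above already passes "--early-exit-count 6" via infer_parameter). A minimal sketch of how generation code can read the two new options, assuming cfg is the parsed configuration:

    def early_exit_settings(cfg):
        # 0 disables each mechanism, matching the defaults declared above.
        count = getattr(cfg.generation, "early_exit_count", 0)  # stop after n identical consecutive predictions
        layer = getattr(cfg.generation, "early_exit_layer", 0)  # always exit at a fixed layer
        return count, layer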
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
    ):
        if hasattr(self, "ctc"):
+           import os
            assert src_dict is not None
            self.ctc.set_infer(
                ctc_infer,
                post_process,
                src_dict,
-               path=path + ".ctc" if path is not None else None,
+               path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
            )

    def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
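Note: the os.path.splitext change affects where the CTC transcript is written — the .ctc suffix now replaces the original extension instead of being appended after it. A small illustration with a made-up path:

    import os

    path = "checkpoints/mustc/translation.txt"     # hypothetical output path
    path + ".ctc"                                  # old: 'checkpoints/mustc/translation.txt.ctc'
    os.path.splitext(path)[0] + ".ctc"             # new: 'checkpoints/mustc/translation.ctc'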
fairseq/models/speech_to_text/s2t_ctc.py

@@ -250,7 +250,6 @@ class CTCDecoder(object):
        logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-       print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))

        from torchprofile import profile_macs
        macs = profile_macs(self.model, [src_tokens, src_lengths])
        gmacs = macs / 1e9

@@ -269,20 +268,22 @@ class CTCDecoder(object):
        inter_logits = encoder_outs.get("inter_xctc_logits", [])
        if ctc_logit is None:
            ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
-       if len(inter_logits) > 0:
+       if len(inter_logits) == 0:
            inter_logits = encoder_outs.get("inter_ctc_logits", [])
        inter_logits_num = len(inter_logits)

        encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]

        if self.ctc_inter_logit != 0:
-           assert inter_logits_num >= self.ctc_inter_logit
            if inter_logits_num != 0:
+               assert self.ctc_inter_logit <= inter_logits_num
                ctc_logit_item = inter_logits[-self.ctc_inter_logit]
                if isinstance(ctc_logit_item, list):
                    ctc_logit = ctc_logit_item[0].transpose(0, 1)
                    if len(ctc_logit_item) >= 2:
                        encoder_padding_mask = ctc_logit_item[1]
+               else:
+                   ctc_logit = ctc_logit_item.transpose(0, 1)

        logit_length = (~encoder_padding_mask).long().sum(-1)
        finalized = []

@@ -318,7 +319,7 @@ class CTCDecoder(object):
                else:
                    logit = inter_logits[i]
-               inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1)
+               inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
                ctc_probs += inter_logits_prob

        topk_prob, topk_index = ctc_probs.topk(1, dim=2)
fairseq/models/speech_to_text/s2t_transformer.py

Diff collapsed; not expanded in this view.
fairseq/modules/speech_to_text/ctc.py

@@ -74,6 +74,19 @@ class CTC(nn.Module):
    def argmax(self, x):
        return torch.argmax(self.ctc_projection(x), dim=-1)

+   def predict(self, logits, padding):
+       input_lengths = (~padding).sum(-1)
+       logits = logits.transpose(0, 1).float().contiguous()
+
+       predicts = []
+       for logit, inp_l in zip(logits, input_lengths):
+           toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
+           pred_units_arr = toks[toks != self.dictionary.bos()]
+           # pred_units_arr = logit[:inp_l].argmax(dim=-1)
+           predicts.append(pred_units_arr)
+       return predicts
+
    def infer(self, logits_or_probs, lengths, tag=None):
        for lp, inp_l in zip(
            logits_or_probs,
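Note: CTC.predict applies the standard greedy CTC collapse — per-frame argmax, merge consecutive repeats, drop the blank index (dictionary.bos() here). A self-contained sketch of the same collapse on a toy tensor; the blank index and shapes are illustrative, not the module's actual configuration:

    import torch

    def greedy_ctc_collapse(frame_logits, input_length, blank_index=0):
        # Per-frame argmax, merge repeats, then remove blanks -- the same
        # steps CTC.predict performs for each utterance in the batch.
        toks = frame_logits[:input_length].argmax(dim=-1).unique_consecutive()
        return toks[toks != blank_index]

    # Toy example: 6 frames over a 4-symbol vocabulary (index 0 is the blank).
    logits = torch.randn(6, 4)
    logits[:, 2] += 5.0  # make symbol 2 win on every frame
    print(greedy_ctc_collapse(logits, input_length=6))  # tensor([2])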
fairseq/tasks/fairseq_task.py

@@ -130,6 +130,9 @@ class FairseqTask(object):
        """
        return cls(cfg, **kwargs)

+   def sharded_data_load(self):
+       return getattr(self.cfg, "sharded_data_load", False)
+
    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.cfg, "data", "")

@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
        """
        return cls(args, **kwargs)

+   def sharded_data_load(self):
+       return getattr(self.args, "sharded_data_load", False)
+
    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.args, "data", "")
fairseq/trainer.py

@@ -521,16 +521,25 @@ class Trainer(object):
        disable_iterator_cache=False,
    ):
        """Return an EpochBatchIterator over the training set for a given epoch."""
+       if self.task.sharded_data_load():
+           datasets = self.cfg.dataset.train_subset.split(",")
+           curr_dataset = datasets[(epoch - 1) % len(datasets)]
+           logger.info("sharded loading the training subset {}".format(curr_dataset))
+       else:
+           curr_dataset = self.cfg.dataset.train_subset
+
+       load_dataset = load_dataset or self.task.sharded_data_load()
+       disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
+
        if load_dataset:
            logger.info("loading train data for epoch {}".format(epoch))
            self.task.load_dataset(
-               self.cfg.dataset.train_subset,
+               curr_dataset,
                epoch=epoch,
                combine=combine,
                data_selector=data_selector,
            )
        batch_iterator = self.task.get_batch_iterator(
-           dataset=self.task.dataset(self.cfg.dataset.train_subset),
+           dataset=self.task.dataset(curr_dataset),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
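Note: with sharded_data_load enabled, --train-subset is treated as a comma-separated list of shards and one shard is reloaded per epoch, with the iterator cache disabled so the next shard actually gets loaded. A standalone sketch of the epoch-to-shard selection used above (subset names are made up):

    def shard_for_epoch(train_subset: str, epoch: int) -> str:
        # Cycle through the comma-separated subsets, one per (1-indexed) epoch,
        # exactly as the trainer change above does.
        shards = train_subset.split(",")
        return shards[(epoch - 1) % len(shards)]

    # Hypothetical three-shard setup: epochs 1..4 -> shard1, shard2, shard3, shard1.
    for epoch in range(1, 5):
        print(epoch, shard_for_epoch("train_shard1,train_shard2,train_shard3", epoch))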
fairseq/utils.py

@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
    freeze_module_name = freeze_module_name.split(",")
    for name in freeze_module_name:
        freeze_module_params_by_name(module, name)
+
+
+def distribution_soft_to_hard(logit_or_prob):
+   argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
+   hard_distribution = (
+       (
+           argmax_prob
+           == torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(0)
+       )
+       .to(logit_or_prob.dtype)
+   )
+   return hard_distribution
\ No newline at end of file
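Note: distribution_soft_to_hard turns a soft distribution into a one-hot distribution over the same vocabulary by comparing each class index against the argmax. A quick usage sketch:

    import torch

    probs = torch.tensor([[0.1, 0.7, 0.1, 0.1],
                          [0.2, 0.2, 0.5, 0.1]])
    argmax = torch.argmax(probs, dim=-1, keepdim=True)               # [[1], [2]]
    hard = (argmax == torch.arange(probs.size(-1))).to(probs.dtype)  # broadcast compare
    print(hard)  # tensor([[0., 1., 0., 0.], [0., 0., 1., 0.]])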
fairseq_cli/generate.py

@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
    for model in models:
        if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
            model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                       src_dict, tgt_dict, translation_path)
+                                       src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
        if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
            model.encoder.set_flag(
                cal_localness=cfg.generation.cal_localness,

@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
                mixup_infer=cfg.generation.mixup_infer,
                gather_cos_sim=cfg.generation.gather_cos_sim,
                gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
+               early_exit_layer=cfg.generation.early_exit_layer,
            )
        if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
            model.decoder.set_flag(

@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
            # Remove padding
            if "src_tokens" in sample["net_input"]:
-               src_tokens = utils.strip_pad(
-                   sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
-               )
+               if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
+                   src_tokens = utils.strip_pad(
+                       sample["net_input"]["src_tokens"][i, :], src_dict.pad()
+                   )
+           elif "transcript" in sample:
+               src_tokens = utils.strip_pad(
+                   sample["transcript"]["tokens"][i, :], src_dict.pad()
+               )
            else:
                src_tokens = None
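Note: early_exit_count is threaded from the generation config into set_ctc_infer here; the early-exit logic itself lives in s2t_transformer.py, whose diff is collapsed above. As a hedged sketch of the idea the option's help text describes (stop once n consecutive per-layer predictions agree), not the repository's implementation:

    def should_exit_early(layer_predictions, early_exit_count):
        # True once the last `early_exit_count` per-layer predictions are identical.
        if early_exit_count <= 0 or len(layer_predictions) < early_exit_count:
            return False
        tail = layer_predictions[-early_exit_count:]
        return all(p == tail[0] for p in tail)

    # Hypothetical per-layer greedy CTC outputs for one utterance.
    preds = ["a b c", "a b c d", "a b c d", "a b c d"]
    print(should_exit_early(preds, early_exit_count=3))  # True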