xuchen / Fairseq-S2T / Commits / f9987d03

Commit f9987d03 authored Oct 08, 2022 by xuchen

report bleu and wer during validation

parent 8f45faa2
Showing 4 changed files with 259 additions and 9 deletions (+259 −9)
egs/mustc/asr/conf/basis.yaml                      +8   −2
egs/mustc/st/conf/basis.yaml                       +11  −2
fairseq/models/speech_to_text/s2t_transformer.py   +1   −0
fairseq/tasks/speech_to_text.py                    +239 −5
egs/mustc/asr/conf/basis.yaml  View file @ f9987d03
...
...
@@ -4,10 +4,16 @@ valid-subset: dev
 max-epoch: 100
 max-update: 100000
 
 patience: 20
-best_checkpoint_metric: loss
-maximize_best_checkpoint_metric: False
 post-process: sentencepiece
 
+#best_checkpoint_metric: loss
+#maximize_best_checkpoint_metric: False
+eval-wer: True
+maximize_best_checkpoint_metric: False
+eval-wer-args: {"beam": 5, "lenpen": 1.0}
+eval-wer-tok-args: {"wer_remove_punct": true, "wer_lowercase": true, "wer_char_level": false}
+
 no-epoch-checkpoints: True
 #keep-last-epochs: 10
+keep-best-checkpoints: 10
...
...
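As a rough guide to what these eval-wer options report: the decoded hypothesis and the reference are normalized with fairseq's EvaluationTokenizer using the wer_* keys from eval-wer-tok-args, and WER is the edit distance divided by the reference length, as in the _cal_wer helper added to fairseq/tasks/speech_to_text.py below. A minimal sketch with toy strings (assuming editdistance and fairseq are installed; exact EvaluationTokenizer construction may vary across fairseq/sacrebleu versions):

    import editdistance
    from fairseq.scoring.tokenizer import EvaluationTokenizer

    # Keys mirror eval-wer-tok-args above.
    tok = EvaluationTokenizer(
        tokenizer_type="none",         # wer_tokenizer (the task's default)
        lowercase=True,                # wer_lowercase
        punctuation_removal=True,      # wer_remove_punct
        character_tokenization=False,  # wer_char_level
    )

    hyps = ["the cat sat on the mat"]  # toy decoder outputs
    refs = ["The cat sat on a mat"]    # toy references

    distance, ref_length = 0, 0
    for hyp, ref in zip(hyps, refs):
        ref_items = tok.tokenize(ref).split()
        hyp_items = tok.tokenize(hyp).split()
        distance += editdistance.eval(ref_items, hyp_items)
        ref_length += len(ref_items)

    # WER in percent: 1 substitution over 6 reference words ≈ 16.67
    print(100.0 * distance / ref_length)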
egs/mustc/st/conf/basis.yaml  View file @ f9987d03
...
...
@@ -4,10 +4,19 @@ valid-subset: dev
 max-epoch: 100
 max-update: 100000
 
 patience: 20
-best_checkpoint_metric: loss
-maximize_best_checkpoint_metric: False
 post-process: sentencepiece
 
+#best_checkpoint_metric: loss
+#maximize_best_checkpoint_metric: False
+eval-bleu: True
+eval-bleu-args: {"beam": 5, "lenpen": 1.0}
+eval-bleu-detok: moses
+eval-bleu-remove-bpe: sentencepiece
+eval-bleu-print-samples: True
+best_checkpoint_metric: bleu
+maximize_best_checkpoint_metric: True
+#fp16-scale-tolerance: 0.25
+
 no-epoch-checkpoints: True
 #keep-last-epochs: 10
...
...
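For the st recipe, eval-bleu reports detokenized BLEU during validation: hypotheses and references have the sentencepiece BPE stripped (eval-bleu-remove-bpe), are detokenized with the moses detokenizer (eval-bleu-detok), and are then scored with sacrebleu, as in the _cal_bleu helper below. A minimal sketch of that final scoring step with toy strings (assuming sacrebleu is installed):

    import sacrebleu

    hyps = ["the cat sat on the mat"]  # detokenized hypotheses
    refs = ["the cat sat on a mat"]    # detokenized references

    # Default sacrebleu scoring, i.e. the untokenized path taken when
    # --eval-tokenized-bleu is not set.
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    print(round(bleu.score, 2))

With best_checkpoint_metric: bleu and maximize_best_checkpoint_metric: True, checkpoint selection then follows this validation BLEU instead of the loss.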
fairseq/models/speech_to_text/s2t_transformer.py  View file @ f9987d03
...
...
@@ -931,6 +931,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
+            self.show_debug(positions, "position embedding")
             x += positions
             positions = None
 
         self.show_debug(x, "x after position embedding")
...
...
fairseq/tasks/speech_to_text.py  View file @ f9987d03
...
...
@@ -4,9 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import json
 import os.path as op
+import numpy as np
 
 from argparse import Namespace
+from fairseq import metrics, utils
 from fairseq.data import Dictionary, encoders
 from fairseq.data.audio.speech_to_text_dataset import (
     S2TDataConfig,
...
...
@@ -14,8 +17,12 @@ from fairseq.data.audio.speech_to_text_dataset import (
     SpeechToTextDatasetCreator,
     get_features_or_waveform
 )
+from fairseq.scoring.tokenizer import EvaluationTokenizer
 from fairseq.tasks import LegacyFairseqTask, register_task
 
+EVAL_BLEU_ORDER = 4
+
 
 logger = logging.getLogger(__name__)
...
...
@@ -57,6 +64,77 @@ class SpeechToTextTask(LegacyFairseqTask):
             help="use aligned text for loss",
         )
+
+        # options for reporting BLEU during validation
+        parser.add_argument(
+            "--eval-bleu",
+            default=False,
+            action="store_true",
+            help="evaluation with BLEU scores",
+        )
+        parser.add_argument(
+            "--eval-bleu-args",
+            default="{}",
+            type=str,
+            help='generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string',
+        )
+        parser.add_argument(
+            "--eval-bleu-detok",
+            default="space",
+            type=str,
+            help="detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; "
+            "use 'space' to disable detokenization; see fairseq.data.encoders for other options",
+        )
+        parser.add_argument(
+            "--eval-bleu-detok-args",
+            default="{}",
+            type=str,
+            help="args for building the tokenizer, if needed, as JSON string",
+        )
+        parser.add_argument(
+            "--eval-tokenized-bleu",
+            default=False,
+            action="store_true",
+            help="compute tokenized BLEU instead of sacrebleu",
+        )
+        parser.add_argument(
+            "--eval-bleu-remove-bpe",
+            default="@@ ",
+            type=str,
+            help="remove BPE before computing BLEU",
+        )
+        parser.add_argument(
+            "--eval-bleu-print-samples",
+            default=False,
+            action="store_true",
+            help="print sample generations during validation",
+        )
+
+        # options for reporting WER during validation
+        parser.add_argument(
+            "--eval-wer",
+            default=False,
+            action="store_true",
+            help="evaluation with WER scores",
+        )
+        parser.add_argument(
+            "--eval-wer-args",
+            default="{}",
+            type=str,
+            help='generation args for WER scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string',
+        )
+        parser.add_argument(
+            "--eval-wer-tok-args",
+            default="{}",
+            type=str,
+            help='tokenizer args for WER scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string',
+        )
+        parser.add_argument(
+            "--eval-wer-detok-args",
+            default="{}",
+            type=str,
+            help="args for building the tokenizer, if needed, as JSON string",
+        )
 
     def __init__(self, args, tgt_dict, src_dict=None):
         super().__init__(args)
         self.src_dict = src_dict
...
...
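The JSON-string options above (--eval-bleu-args, --eval-wer-args, --eval-wer-tok-args, ...) are parsed with json.loads and wrapped in a Namespace before being handed to build_generator, as build_model does further down; the YAML keys in the basis.yaml files above mirror these flag names. A minimal sketch of that pattern with hypothetical values:

    import json
    from argparse import Namespace

    eval_wer_args = '{"beam": 5, "lenpen": 1.0}'   # e.g. the value of --eval-wer-args
    gen_args = Namespace(**json.loads(eval_wer_args))
    print(gen_args.beam, gen_args.lenpen)          # 5 1.0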
@@ -114,10 +192,6 @@ class SpeechToTextTask(LegacyFairseqTask):
             src_bpe_tokenizer = self.build_src_bpe(self.args)
         else:
             src_bpe_tokenizer = bpe_tokenizer
-            # if self.data_cfg.share_src_and_tgt:
-            #     src_bpe_tokenizer = bpe_tokenizer
-            # else:
-            #     src_bpe_tokenizer = None
 
         if self.use_aligned_text:
             from fairseq.data.audio.aligned_speech_to_text_dataset import SpeechToTextDatasetCreator as Creator
         else:
...
...
@@ -150,7 +224,116 @@ class SpeechToTextTask(LegacyFairseqTask):
     def build_model(self, args):
         args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
         args.input_channels = self.data_cfg.input_channels
-        return super(SpeechToTextTask, self).build_model(args)
+        model = super(SpeechToTextTask, self).build_model(args)
+
+        if self.args.eval_bleu:
+            detok_args = json.loads(self.args.eval_bleu_detok_args)
+            self.tokenizer = encoders.build_tokenizer(
+                Namespace(tokenizer=self.args.eval_bleu_detok, **detok_args)
+            )
+            gen_args = json.loads(self.args.eval_bleu_args)
+            self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
+
+        if self.args.eval_wer:
+            try:
+                import editdistance as ed
+            except ImportError:
+                raise ImportError("Please install editdistance to use WER scorer")
+            self.ed = ed
+
+            detok_args = json.loads(self.args.eval_wer_detok_args)
+            self.tokenizer = encoders.build_tokenizer(
+                Namespace(tokenizer=self.args.eval_bleu_detok, **detok_args)
+            )
+            wer_tok_args = json.loads(self.args.eval_wer_tok_args)
+            self.wer_tokenizer = EvaluationTokenizer(
+                tokenizer_type=wer_tok_args.get("wer_tokenizer", "none"),
+                lowercase=wer_tok_args.get("wer_lowercase", False),
+                punctuation_removal=wer_tok_args.get("wer_remove_punct", False),
+                character_tokenization=wer_tok_args.get("wer_char_level", False),
+            )
+            wer_gen_args = json.loads(self.args.eval_wer_args)
+            self.wer_sequence_generator = self.build_generator([model], Namespace(**wer_gen_args))
+
+        return model
+
+    def valid_step(self, sample, model, criterion):
+        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+
+        if self.args.eval_bleu:
+            hyps, refs = self._inference(self.sequence_generator, sample, model)
+            bleu = self._cal_bleu(hyps, refs)
+            logging_output["_bleu_sys_len"] = bleu.sys_len
+            logging_output["_bleu_ref_len"] = bleu.ref_len
+            # we split counts into separate entries so that they can be
+            # summed efficiently across workers using fast-stat-sync
+            assert len(bleu.counts) == EVAL_BLEU_ORDER
+            for i in range(EVAL_BLEU_ORDER):
+                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
+                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
+
+        if self.args.eval_wer:
+            hyps, refs = self._inference(self.wer_sequence_generator, sample, model)
+            distance, ref_length = self._cal_wer(hyps, refs)
+            logging_output["_wer_distance"] = distance
+            logging_output["_wer_ref_length"] = ref_length
+
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+
+        if self.args.eval_wer:
+            distance = sum(log.get("distance", 0) for log in logging_outputs)
+            ref_length = sum(log.get("ref_length", 0) for log in logging_outputs)
+            if ref_length > 0:
+                metrics.log_scalar("wer", 100.0 * distance / ref_length)
+
+        if self.args.eval_bleu:
+
+            def sum_logs(key):
+                import torch
+
+                result = sum(log.get(key, 0) for log in logging_outputs)
+                if torch.is_tensor(result):
+                    result = result.cpu()
+                return result
+
+            counts, totals = [], []
+            for i in range(EVAL_BLEU_ORDER):
+                counts.append(sum_logs("_bleu_counts_" + str(i)))
+                totals.append(sum_logs("_bleu_totals_" + str(i)))
+
+            if max(totals) > 0:
+                # log counts as numpy arrays -- log_scalar will sum them correctly
+                metrics.log_scalar("_bleu_counts", np.array(counts))
+                metrics.log_scalar("_bleu_totals", np.array(totals))
+                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
+                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
+
+                def compute_bleu(meters):
+                    import inspect
+                    import sacrebleu
+
+                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
+                    if "smooth_method" in fn_sig:
+                        smooth = {"smooth_method": "exp"}
+                    else:
+                        smooth = {"smooth": "exp"}
+                    bleu = sacrebleu.compute_bleu(
+                        correct=meters["_bleu_counts"].sum,
+                        total=meters["_bleu_totals"].sum,
+                        sys_len=meters["_bleu_sys_len"].sum,
+                        ref_len=meters["_bleu_ref_len"].sum,
+                        **smooth,
+                    )
+                    return round(bleu.score, 2)
+
+                metrics.log_derived("bleu", compute_bleu)
 
     def build_generator(
         self,
...
...
@@ -200,3 +383,54 @@ class SpeechToTextTask(LegacyFairseqTask):
         return SpeechToTextDataset(
             "interactive", False, self.data_cfg, src_tokens, src_lengths
         )
+
+    def _inference(self, generator, sample, model):
+        def decode(toks, escape_unk=False):
+            s = self.tgt_dict.string(
+                toks.int().cpu(),
+                self.args.eval_bleu_remove_bpe,
+                # The default unknown string in fairseq is `<unk>`, but
+                # this is tokenized by sacrebleu as `< unk >`, inflating
+                # BLEU scores. Instead, we use a somewhat more verbose
+                # alternative that is unlikely to appear in the real
+                # reference, but doesn't get split into multiple tokens.
+                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
+            )
+            if self.tokenizer:
+                s = self.tokenizer.decode(s)
+            return s
+
+        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
+        hyps, refs = [], []
+        for i in range(len(gen_out)):
+            hyps.append(decode(gen_out[i][0]["tokens"]))
+            refs.append(
+                decode(
+                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
+                    escape_unk=True,  # don't count <unk> as matches to the hypo
+                )
+            )
+        return hyps, refs
+
+    def _cal_bleu(self, hyps, refs):
+        import sacrebleu
+
+        if self.args.eval_bleu_print_samples:
+            logger.info("example hypothesis: " + hyps[0])
+            logger.info("example reference: " + refs[0])
+        if self.args.eval_tokenized_bleu:
+            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
+        else:
+            return sacrebleu.corpus_bleu(hyps, [refs])
+
+    def _cal_wer(self, hyps, refs):
+        distance = 0
+        ref_length = 0
+        for hyp, ref in zip(hyps, refs):
+            ref = ref.replace("<<unk>>", "@")
+            hyp = hyp.replace("<<unk>>", "@")
+            ref_items = self.wer_tokenizer.tokenize(ref).split()
+            hyp_items = self.wer_tokenizer.tokenize(hyp).split()
+            distance += self.ed.eval(ref_items, hyp_items)
+            ref_length += len(ref_items)
+        return distance, ref_length
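One detail to watch when reading reduce_metrics above: valid_step stores the WER statistics under "_wer_distance" / "_wer_ref_length", while reduce_metrics sums log.get("distance", 0) / log.get("ref_length", 0), so the aggregated "wer" scalar is only logged if those keys agree. A minimal illustration of the intended corpus-level aggregation with toy numbers (a sketch, not the repository's code), using the keys valid_step writes:

    def corpus_wer(logging_outputs):
        # Sum edit distances and reference lengths across validation batches
        # (and, in distributed training, across workers), then divide once.
        distance = sum(log.get("_wer_distance", 0) for log in logging_outputs)
        ref_length = sum(log.get("_wer_ref_length", 0) for log in logging_outputs)
        return 100.0 * distance / ref_length if ref_length > 0 else None

    # Toy logging outputs from two validation batches:
    print(corpus_wer([
        {"_wer_distance": 3, "_wer_ref_length": 20},
        {"_wer_distance": 1, "_wer_ref_length": 15},
    ]))  # 100 * 4 / 35 ≈ 11.43

BLEU takes the analogous route with sufficient statistics: the per-order n-gram counts and totals plus sys_len/ref_len are summed across batches and workers, and sacrebleu.compute_bleu recomputes the corpus score once from those sums, which is why valid_step logs them as separate scalar entries.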