Commit f9987d03 by xuchen

report bleu and wer during validation

parent 8f45faa2
@@ -4,10 +4,16 @@ valid-subset: dev
max-epoch: 100
max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
#best_checkpoint_metric: loss
#maximize_best_checkpoint_metric: False
eval-wer: True
maximize_best_checkpoint_metric: False
eval-wer-args: {"beam": 5, "lenpen": 1.0}
eval-wer-tok-args: {"wer_remove_punct": true, "wer_lowercase": true, "wer_char_level": false}
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
......
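Note: the JSON-valued options above (eval-wer-args, eval-wer-tok-args) are parsed with json.loads inside the task (see build_model later in this commit), so they must use JSON literals (true/false, double quotes) rather than Python syntax. A minimal standalone check:

import json

# values copied from the config above; json.loads rejects Python-style True/False
wer_gen_args = json.loads('{"beam": 5, "lenpen": 1.0}')
wer_tok_args = json.loads(
    '{"wer_remove_punct": true, "wer_lowercase": true, "wer_char_level": false}'
)
print(wer_gen_args["beam"], wer_tok_args["wer_lowercase"])  # -> 5 True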
@@ -4,10 +4,19 @@ valid-subset: dev
max-epoch: 100
max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
#best_checkpoint_metric: loss
#maximize_best_checkpoint_metric: False
eval-bleu: True
eval-bleu-args: {"beam": 5, "lenpen": 1.0}
eval-bleu-detok: moses
eval-bleu-remove-bpe: sentencepiece
eval-bleu-print-samples: True
best_checkpoint_metric: bleu
maximize_best_checkpoint_metric: True
#fp16-scale-tolerance: 0.25
no-epoch-checkpoints: True
#keep-last-epochs: 10
......
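Note: with eval-bleu set to True, validation hypotheses are detokenized and scored with sacrebleu (see _cal_bleu further down in this commit). A minimal sketch of that scoring step, with made-up sentences:

import sacrebleu

hyps = ["the cat sat on the mat"]
refs = ["the cat sat on a mat"]
# corpus_bleu takes hypothesis strings and a list of reference streams
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(round(bleu.score, 2))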
@@ -931,6 +931,7 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
self.show_debug(positions, "position embedding")
x += positions
positions = None
self.show_debug(x, "x after position embedding")
......
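Note on the one-line encoder change above: in this branch the new "x += positions" line applies the computed positional embeddings to the encoder features before the reference is released. Schematically, with tensors shaped (seq_len, batch, dim) as in the surrounding code:

import torch

seq_len, batch, dim = 10, 2, 8
x = torch.randn(seq_len, batch, dim)          # encoder input features
positions = torch.randn(seq_len, batch, dim)  # positional embeddings, same shape
x += positions                                # the step added by this commit
positions = None                              # then the reference is released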
@@ -4,9 +4,12 @@
# LICENSE file in the root directory of this source tree.
import logging
import json
import os.path as op
import numpy as np
from argparse import Namespace
from fairseq import metrics, utils
from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
@@ -14,8 +17,12 @@ from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDatasetCreator,
get_features_or_waveform
)
from fairseq.scoring.tokenizer import EvaluationTokenizer
from fairseq.tasks import LegacyFairseqTask, register_task
EVAL_BLEU_ORDER = 4
logger = logging.getLogger(__name__)
@@ -57,6 +64,77 @@ class SpeechToTextTask(LegacyFairseqTask):
help="use aligned text for loss",
)
# options for reporting BLEU during validation
parser.add_argument(
"--eval-bleu",
default=False,
action="store_true",
help="evaluation with BLEU scores",
)
parser.add_argument(
"--eval-bleu-args",
default="{}",
type=str,
help='generation args for BLEU scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string',
)
parser.add_argument(
"--eval-bleu-detok",
default="space",
type=str,
help="detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; "
"use 'space' to disable detokenization; see fairseq.data.encoders for other options",
)
parser.add_argument(
"--eval-bleu-detok-args",
default="{}",
type=str,
help="args for building the tokenizer, if needed, as JSON string",
)
parser.add_argument(
"--eval-tokenized-bleu",
default=False,
action="store_true",
help="compute tokenized BLEU instead of sacrebleu",
)
parser.add_argument(
"--eval-bleu-remove-bpe",
default="@@ ",
type=str,
help="remove BPE before computing BLEU",
)
parser.add_argument(
"--eval-bleu-print-samples",
default=False,
action="store_true",
help="print sample generations during validation",
)
# options for reporting WER during validation
parser.add_argument(
"--eval-wer",
default=False,
action="store_true",
help="evaluation with WER scores",
)
parser.add_argument(
"--eval-wer-args",
default="{}",
type=str,
help='generation args for WER scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string',
)
parser.add_argument(
"--eval-wer-tok-args",
default="{}",
type=str,
help='tokenizer args for WER scoring, e.g., \'{"wer_lowercase": true, "wer_remove_punct": true}\', as JSON string',
)
parser.add_argument(
"--eval-wer-detok-args",
default="{}",
type=str,
help="args for building the tokenizer, if needed, as JSON string",
)
def __init__(self, args, tgt_dict, src_dict=None):
super().__init__(args)
self.src_dict = src_dict
@@ -114,10 +192,6 @@ class SpeechToTextTask(LegacyFairseqTask):
src_bpe_tokenizer = self.build_src_bpe(self.args)
else:
src_bpe_tokenizer = bpe_tokenizer
# if self.data_cfg.share_src_and_tgt:
# src_bpe_tokenizer = bpe_tokenizer
# else:
# src_bpe_tokenizer = None
if self.use_aligned_text:
from fairseq.data.audio.aligned_speech_to_text_dataset import SpeechToTextDatasetCreator as Creator
else:
@@ -150,7 +224,116 @@ class SpeechToTextTask(LegacyFairseqTask):
def build_model(self, args):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_channels
return super(SpeechToTextTask, self).build_model(args)
model = super(SpeechToTextTask, self).build_model(args)
if self.args.eval_bleu:
detok_args = json.loads(self.args.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.args.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(self.args.eval_bleu_args)
self.sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
if self.args.eval_wer:
try:
import editdistance as ed
except ImportError:
raise ImportError("Please install editdistance to use WER scorer")
self.ed = ed
detok_args = json.loads(self.args.eval_wer_detok_args)
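# NOTE: there is no separate --eval-wer-detok flag in this commit, so the
# --eval-bleu-detok tokenizer type is reused below; this also overwrites the
# self.tokenizer built in the eval_bleu branch when both options are enabled.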
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.args.eval_bleu_detok, **detok_args)
)
wer_tok_args = json.loads(self.args.eval_wer_tok_args)
self.wer_tokenizer = EvaluationTokenizer(
tokenizer_type=wer_tok_args.get("wer_tokenizer", "none"),
lowercase=wer_tok_args.get("wer_lowercase", False),
punctuation_removal=wer_tok_args.get("wer_remove_punct", False),
character_tokenization=wer_tok_args.get("wer_char_level", False),
)
wer_gen_args = json.loads(self.args.eval_wer_args)
self.wer_sequence_generator = self.build_generator(
[model], Namespace(**wer_gen_args)
)
return model
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.args.eval_bleu:
hyps, refs = self._inference(self.sequence_generator, sample, model)
bleu = self._cal_bleu(hyps, refs)
logging_output["_bleu_sys_len"] = bleu.sys_len
logging_output["_bleu_ref_len"] = bleu.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(bleu.counts) == EVAL_BLEU_ORDER
for i in range(EVAL_BLEU_ORDER):
logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
if self.args.eval_wer:
hyps, refs = self._inference(self.wer_sequence_generator, sample, model)
distance, ref_length = self._cal_wer(hyps, refs)
logging_output["_wer_distance"] = distance
logging_output["_wer_ref_length"] = ref_length
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.args.eval_wer:
distance = sum(log.get("_wer_distance", 0) for log in logging_outputs)
ref_length = sum(log.get("_wer_ref_length", 0) for log in logging_outputs)
if ref_length > 0:
metrics.log_scalar("wer", 100.0 * distance / ref_length)
if self.args.eval_bleu:
def sum_logs(key):
import torch
result = sum(log.get(key, 0) for log in logging_outputs)
if torch.is_tensor(result):
result = result.cpu()
return result
counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
counts.append(sum_logs("_bleu_counts_" + str(i)))
totals.append(sum_logs("_bleu_totals_" + str(i)))
if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
metrics.log_scalar("_bleu_counts", np.array(counts))
metrics.log_scalar("_bleu_totals", np.array(totals))
metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
def compute_bleu(meters):
import inspect
import sacrebleu
fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
if "smooth_method" in fn_sig:
smooth = {"smooth_method": "exp"}
else:
smooth = {"smooth": "exp"}
bleu = sacrebleu.compute_bleu(
correct=meters["_bleu_counts"].sum,
total=meters["_bleu_totals"].sum,
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
**smooth
)
return round(bleu.score, 2)
metrics.log_derived("bleu", compute_bleu)
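Note: the 1- to 4-gram match counts are kept as sufficient statistics rather than as a per-batch BLEU, so they can be summed across workers and validation batches and only turned into a corpus-level score at the end; storing them as numpy arrays keeps that accumulation elementwise. A toy illustration with made-up counts:

import numpy as np

# hypothetical 1..4-gram match counts coming from two workers
worker_a = np.array([10, 7, 5, 3])
worker_b = np.array([12, 8, 4, 2])
print(worker_a + worker_b)  # -> [22 15  9  5], the aggregated counts compute_bleu consumes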
def build_generator(
self,
@@ -200,3 +383,54 @@ class SpeechToTextTask(LegacyFairseqTask):
return SpeechToTextDataset(
"interactive", False, self.data_cfg, src_tokens, src_lengths
)
def _inference(self, generator, sample, model):
def decode(toks, escape_unk=False):
s = self.tgt_dict.string(
toks.int().cpu(),
self.args.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"]))
refs.append(
decode(
utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
escape_unk=True, # don't count <unk> as matches to the hypo
)
)
return hyps, refs
def _cal_bleu(self, hyps, refs):
import sacrebleu
if self.args.eval_bleu_print_samples:
logger.info("example hypothesis: " + hyps[0])
logger.info("example reference: " + refs[0])
if self.args.eval_tokenized_bleu:
return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
else:
return sacrebleu.corpus_bleu(hyps, [refs])
def _cal_wer(self, hyps, refs):
distance = 0
ref_length = 0
for hyp, ref in zip(hyps, refs):
ref = ref.replace("<<unk>>", "@")
hyp = hyp.replace("<<unk>>", "@")
ref_items = self.wer_tokenizer.tokenize(ref).split()
hyp_items = self.wer_tokenizer.tokenize(hyp).split()
distance += self.ed.eval(ref_items, hyp_items)
ref_length += len(ref_items)
return distance, ref_length
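Note: combining the WER tokenizer built in build_model with _cal_wer above, validation WER reduces to Levenshtein distance over tokenized words. A minimal standalone sketch (assumes the editdistance package is installed, as build_model requires; sentences are made up):

import editdistance
from fairseq.scoring.tokenizer import EvaluationTokenizer

tok = EvaluationTokenizer(
    tokenizer_type="none",          # "wer_tokenizer"
    lowercase=True,                 # "wer_lowercase"
    punctuation_removal=True,       # "wer_remove_punct"
    character_tokenization=False,   # "wer_char_level"
)
ref_items = tok.tokenize("the cat sat on the mat").split()
hyp_items = tok.tokenize("the cat sat on a mat").split()
distance = editdistance.eval(ref_items, hyp_items)
print(100.0 * distance / len(ref_items))  # word error rate, in percent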