Commit 9fe8cd1e by xuchen

fix some bugs

parent a201a883
@@ -44,6 +44,10 @@ class CtcCriterionConfig(FairseqDataclass):
             "See fairseq.data.data_utils.post_process() for full list of options"
         },
     )
+    ctc_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight of CTC loss"},
+    )
     ctc_entropy: float = field(
         default=0.0,
         metadata={"help": "weight of CTC entropy"},
@@ -312,7 +316,8 @@ class CtcCriterion(FairseqCriterion):
         loss = F.kl_div(
             F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
-            F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
+            F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
+            # F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
             log_target=True,
             reduction="none",
         )
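
Note on the hunk above: the distillation term is a frame-level KL divergence that pulls each interleaved (student) CTC distribution towards the top-layer (teacher) CTC distribution, and this commit stops detaching the teacher logits, so gradients now also flow into the teacher branch. A minimal standalone sketch of the same computation (the tensor shapes are illustrative assumptions, not taken from the repository):

    import torch
    import torch.nn.functional as F

    # Assumed CTC logit shape: (time, batch, vocab).
    student_logit = torch.randn(50, 4, 100)   # an interleaved (intermediate) CTC layer
    teacher_logit = torch.randn(50, 4, 100)   # the top CTC layer

    # F.kl_div treats its second argument as the target distribution, so this is the
    # per-element KL(teacher || student), with both inputs given as log-probabilities.
    loss = F.kl_div(
        F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
        F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),  # no .detach() after this commit
        log_target=True,
        reduction="none",
    )
    # A padding mask would normally be applied before reducing; here we simply average.
    print(loss.sum(-1).mean())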
@@ -491,7 +496,8 @@ class CtcCriterion(FairseqCriterion):
         ctc_self_distill_num = 0
         non_padding = non_padding_mask
-        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        # if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        if self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
             teacher_logit = ctc_logit
             student_logits = net_output["interleaved_ctc_logits"]
             ctc_self_distill_num = interleaved_ctc_num
@@ -550,15 +556,17 @@ class CtcCriterion(FairseqCriterion):
         if self.target_ctc_weight != 0:
             logger.warning("Target CTC loss %f!" % target_ctc_loss)
+        # CER is not completely accurate and is for reference only.
         if not model.training:
-            if hasattr(model.encoder, "ctc_valid"):
+            encoder = model.encoder.encoder if hasattr(model.encoder, "encoder") else model.encoder
+            if hasattr(encoder, "ctc_valid"):
                 if lprobs is not None:
                     lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
                     if mixup:
                         idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
                         tokens = tokens[idx]
-                    c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
+                    c_err, c_len, w_errs, w_len, wv_errs = encoder.ctc_valid(
                         lprobs_t, tokens, input_lengths, self.task.source_dictionary, lang="source")
                     logging_output["wv_errors"] = wv_errs
...
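
For clarity, the last hunk resolves the real encoder before calling ctc_valid, since the S2TCTC model wraps its actual encoder in an inner `encoder` attribute (see the S2TCTCEncoder changes below). A tiny illustration of the unwrapping pattern, with placeholder classes that are not from the repository:

    class InnerEncoder:
        def ctc_valid(self):
            return "called on the real encoder"

    class WrapperEncoder:
        # mirrors S2TCTCEncoder, which keeps the real encoder as self.encoder
        def __init__(self):
            self.encoder = InnerEncoder()

    def resolve_encoder(encoder):
        # same hasattr test as in the hunk above: unwrap one level if needed
        return encoder.encoder if hasattr(encoder, "encoder") else encoder

    print(resolve_encoder(WrapperEncoder()).ctc_valid())  # unwrapped first
    print(resolve_encoder(InnerEncoder()).ctc_valid())    # already the real encoder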
@@ -854,9 +854,18 @@ class GenerationConfig(FairseqDataclass):
         default=False,
         metadata={"help": "if set, dont use seed for initializing random generators"},
     )
+    # CTC inference
     ctc_infer: bool = field(
         default=False,
-        metadata={"help": "generate CTC decoding results during inference"}
+        metadata={"help": "generate CTC results during inference"}
+    )
+    ctc_self_ensemble: bool = field(
+        default=False,
+        metadata={"help": "ensemble the top representation and intermediate representations for decoding"}
+    )
+    ctc_inter_logit: int = field(
+        default=0,
+        metadata={"help": "use the specific logit (from top to bottom, 0 is the top layer) for inference"}
     )
...
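
The three fields above are ordinary fairseq dataclass options; with fairseq's usual dataclass-to-CLI mapping they should surface as --ctc-infer, --ctc-self-ensemble and --ctc-inter-logit (an assumption, not verified against this fork). The CTC decoder below reads them back with getattr defaults, roughly as in this sketch:

    from argparse import Namespace

    # Stand-in for the parsed generation config.
    args = Namespace(ctc_infer=True, ctc_self_ensemble=False, ctc_inter_logit=1)

    ctc_self_ensemble = getattr(args, "ctc_self_ensemble", False)
    ctc_inter_logit = getattr(args, "ctc_inter_logit", 0)

    if ctc_inter_logit != 0:
        print("decode from the %d-th interleaved CTC logit, counted from the top" % ctc_inter_logit)
    if ctc_self_ensemble:
        print("sum the top and intermediate CTC log-probabilities before decoding")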
@@ -5,8 +5,8 @@
 from .berard import * # noqa
 from .convtransformer import * # noqa
-from .s2t_ctc import *
 from .s2t_transformer import * # noqa
 from .pdss2t_transformer import * # noqa
 from .s2t_sate import * # noqa
 from .s2t_dual import * # noqa
+from .s2t_ctc import *
@@ -11,8 +11,17 @@ from fairseq.models import (
     register_model_architecture,
 )
-from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
-from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
+# from .s2t_sate import S2TSATEModel, S2TSATEEncoder
+# from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
+# from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
+from fairseq.models.speech_to_text import (
+    S2TTransformerModel,
+    S2TTransformerEncoder,
+    PDSS2TTransformerModel,
+    PDSS2TTransformerEncoder,
+    S2TSATEModel,
+    S2TSATEEncoder,
+)
 
 from torch import Tensor
@@ -30,6 +39,7 @@ class S2TCTCModel(FairseqEncoderModel):
         """Add model-specific arguments to the parser."""
         S2TTransformerModel.add_args(parser)
         PDSS2TTransformerModel.add_specific_args(parser)
+        S2TSATEModel.add_specific_args(parser)
 
         # encoder
         parser.add_argument(
@@ -49,7 +59,7 @@ class S2TCTCModel(FairseqEncoderModel):
                 f"{args.load_pretrained_encoder_from}"
             )
             encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+                component=encoder.encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
             )
         return encoder
@@ -108,11 +118,16 @@ class S2TCTCEncoder(FairseqEncoder):
             self.encoder = S2TTransformerEncoder(args, task)
         elif encoder_type == "pds":
             self.encoder = PDSS2TTransformerEncoder(args, task)
+        elif encoder_type == "sate":
+            self.encoder = S2TSATEEncoder(args, task)
         else:
             logger.error("Unsupported architecture: %s." % encoder_type)
             return
 
+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
+        self.encoder.set_ctc_infer(ctc_infer, post_process, src_dict=src_dict, tgt_dict=tgt_dict, path=path)
+
     def forward(self, src_tokens, src_lengths, **kwargs):
         return self.encoder(src_tokens, src_lengths, **kwargs)
@@ -132,6 +147,16 @@ class CTCDecoder(object):
         self.unk = dictionary.unk()
         self.eos = dictionary.eos()
 
+        self.ctc_self_ensemble = getattr(args, "ctc_self_ensemble", False)
+        self.ctc_inter_logit = getattr(args, "ctc_inter_logit", 0)
+        assert not (self.ctc_self_ensemble is True and self.ctc_inter_logit is True), \
+            "Self ensemble and inference by intermediate logit can not be True at the same time."
+
+        if self.ctc_self_ensemble:
+            logger.info("Using self ensemble for CTC inference")
+        if self.ctc_inter_logit != 0:
+            logger.info("Using intermediate logit %d for CTC inference" % self.ctc_inter_logit)
+
         self.vocab_size = len(dictionary)
         self.beam_size = args.beam
         # the max beam size is the dictionary size - 1, since we never select pad
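
As written, the assertion above compares an integer option to True with `is`, so it can never fire; the intent is that self-ensembling and decoding from an intermediate logit are mutually exclusive. A standalone version of the intended guard (a sketch, not code from the commit):

    def check_ctc_decode_options(ctc_self_ensemble: bool, ctc_inter_logit: int) -> None:
        # ctc_inter_logit is an int (0 means "use the top logit"), so test truthiness,
        # not identity with True.
        if ctc_self_ensemble and ctc_inter_logit != 0:
            raise ValueError(
                "Self ensemble and decoding from an intermediate logit "
                "cannot be enabled at the same time."
            )

    check_ctc_decode_options(ctc_self_ensemble=False, ctc_inter_logit=1)  # passes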
@@ -150,19 +175,25 @@ class CTCDecoder(object):
         if self.lm_model is not None:
             self.lm_model.eval()
 
-        from ctcdecode import CTCBeamDecoder
-        self.ctc_decoder = CTCBeamDecoder(
-            dictionary.symbols,
-            model_path=self.lm_model,
-            alpha=self.lm_weight,
-            beta=0,
-            cutoff_top_n=40,
-            cutoff_prob=1.0,
-            beam_width=self.beam_size,
-            num_processes=20,
-            blank_id=self.blank,
-            log_probs_input=False
-        )
+        self.infer = "greedy"
+        if self.beam_size > 1:
+            try:
+                from ctcdecode import CTCBeamDecoder
+                self.infer = "beam"
+                self.ctc_decoder = CTCBeamDecoder(
+                    dictionary.symbols,
+                    model_path=self.lm_model,
+                    alpha=self.lm_weight,
+                    beta=0,
+                    cutoff_top_n=40,
+                    cutoff_prob=1.0,
+                    beam_width=self.beam_size,
+                    num_processes=20,
+                    blank_id=self.blank,
+                    log_probs_input=False
+                )
+            except ImportError:
+                logger.warning("Cannot import the CTCBeamDecoder library. We use the greedy search for CTC decoding.")
 
     def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
@@ -173,31 +204,76 @@ class CTCDecoder(object):
         src_tokens = net_input["src_tokens"]
         src_lengths = net_input["src_lengths"]
         bsz, src_len = src_tokens.size()[:2]
-        beam_size = self.beam_size
 
         encoder_outs = self.model(src_tokens=src_tokens,
                                   src_lengths=src_lengths)
         ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
-        logit_length = (~encoder_outs["encoder_padding_mask"][0]).long().sum(-1)
-        beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode(
-            utils.softmax(ctc_logit, -1), logit_length
-        )
+        inter_logits = encoder_outs.get("interleaved_ctc_logits", [])
+        inter_logits_num = len(inter_logits)
+
+        if self.ctc_inter_logit != 0:
+            if inter_logits_num != 0:
+                assert self.ctc_inter_logit <= inter_logits_num
+                ctc_logit = inter_logits[-self.ctc_inter_logit].transpose(0, 1)
+
+        logit_length = (~encoder_outs["encoder_padding_mask"][0]).long().sum(-1)
 
         finalized = []
-        for idx in range(bsz):
-            hypos = []
-            for beam_idx in range(beam_size):
-                hypo = dict()
-                length = out_lens[idx][beam_idx]
-                scores = beam_scores[idx, beam_idx]
-                hypo["tokens"] = beam_results[idx, beam_idx, : length]
-                hypo["score"] = scores
-                hypo["attention"] = None
-                hypo["alignment"] = None
-                hypo["positional_scores"] = torch.Tensor([scores / length] * length)
-                hypos.append(hypo)
-            finalized.append(hypos)
+        if self.infer == "beam":
+            beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode(
+                utils.softmax(ctc_logit, -1), logit_length
+            )
+
+            for idx in range(bsz):
+                hypos = []
+                #for beam_idx in range(beam_size):
+                for beam_idx in range(1):
+                    hypo = dict()
+                    length = out_lens[idx][beam_idx]
+                    scores = beam_scores[idx, beam_idx]
+                    hypo["tokens"] = beam_results[idx, beam_idx, : length]
+                    hypo["score"] = scores
+                    hypo["attention"] = None
+                    hypo["alignment"] = None
+                    hypo["positional_scores"] = torch.Tensor([scores / length] * length)
+                    hypos.append(hypo)
+                finalized.append(hypos)
+
+        # elif self.infer == "greedy":
+        else:
+            ctc_probs = utils.log_softmax(ctc_logit, -1)
+            if self.ctc_self_ensemble:
+                if inter_logits_num != 0:
+                    for i in range(inter_logits_num):
+                        inter_logits_prob = utils.log_softmax(inter_logits[i].transpose(0, 1), -1)
+                        ctc_probs += inter_logits_prob
+
+            topk_prob, topk_index = ctc_probs.topk(1, dim=2)
+            topk_prob = topk_prob.squeeze(-1)
+            topk_index = topk_index.squeeze(-1)
+
+            real_indexs = topk_index.masked_fill(encoder_outs["encoder_padding_mask"][0], self.blank).cpu()
+            real_probs = topk_prob.masked_fill(topk_index == self.blank, self.blank)
+            scores = -real_probs.sum(-1, keepdim=True).cpu()
+
+            for idx in range(bsz):
+                hypos = []
+                hypo = dict()
+
+                hyp = real_indexs[idx].unique_consecutive()
+                hyp = hyp[hyp != self.blank]
+                length = len(hyp)
+                hypo["tokens"] = hyp
+                hypo["score"] = scores[idx]
+                hypo["attention"] = None
+                hypo["alignment"] = None
+                hypo["positional_scores"] = torch.Tensor([hypo["score"] / length] * length)
+                hypos.append(hypo)
+                finalized.append(hypos)
+
         return finalized
...
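
The greedy branch above is standard CTC best-path decoding: take the argmax per frame, merge consecutive repeats, then drop blanks. A self-contained sketch of the collapse step (the frame predictions and blank index are made-up values):

    import torch

    blank = 0
    # Assumed per-frame argmax indices for one utterance, with padded frames
    # already masked to the blank index as in the hunk above.
    frame_preds = torch.tensor([0, 7, 7, 0, 0, 3, 3, 3, 0, 5, 0, 0])

    hyp = frame_preds.unique_consecutive()  # merge repeats -> [0, 7, 0, 3, 0, 5, 0]
    hyp = hyp[hyp != blank]                 # drop blanks   -> [7, 3, 5]
    print(hyp.tolist())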
 import logging
 import math
+import os
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from fairseq import checkpoint_utils, utils
 from fairseq.models.transformer import Embedding
 from fairseq.models import (
...
@@ -39,6 +39,9 @@ class CTC(nn.Module):
         self.post_process = "sentencepiece"
         self.blank_idx = 0
 
+        self.path = None
+        self.save_stream = None
+
     def set_infer(self, is_infer, text_post_process, dictionary, path):
         self.infer_decoding = is_infer
         self.post_process = text_post_process
...
@@ -46,6 +46,8 @@ class WerScorer(BaseScorer):
         self.ref_length = 0
 
     def add_string(self, ref, pred):
+        ref = ref.replace("<<unk>>", "@")
+        pred = pred.replace("<<unk>>", "@")
         ref_items = self.tokenizer.tokenize(ref).split()
         pred_items = self.tokenizer.tokenize(pred).split()
         self.distance += self.ed.eval(ref_items, pred_items)
...
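
The replacement above keeps the unknown token from being split by the WER tokenizer and counted as several word errors: "<<unk>>" can be broken into multiple pieces by punctuation-aware tokenization, while a single placeholder character stays one token on both sides. A small illustration (the strings are made up; the scorer's real tokenizer is configurable):

    ref = "ich habe <<unk>> gesehen"
    pred = "ich habe <<unk>> gesehen"

    # map the unknown marker to one placeholder token before tokenization
    ref = ref.replace("<<unk>>", "@")
    pred = pred.replace("<<unk>>", "@")

    print(ref.split())   # ['ich', 'habe', '@', 'gesehen']
    print(pred.split())  # ['ich', 'habe', '@', 'gesehen']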
@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
     for model in models:
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
             model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                        src_dict, tgt_dict, translation_path) # os.path.dirname(translation_path))
+                                        src_dict, tgt_dict, translation_path)
 
     # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
     task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
...