Commit f5fb7d7d by xuchen

add the joint decoding by ctc and attn

parent 26dfcb2f
@@ -808,6 +808,10 @@ class GenerationConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight for lm probs for lm fusion"},
)
infer_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight for ctc probs for lm fusion"},
)
# arguments for iterative refinement generator
iter_decode_eos_penalty: float = field(
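The new field only affects inference: it is the interpolation weight between attention and CTC scores during decoding. Below is a small sketch of how it would be consumed from a parsed config; the command-line spelling --infer-ctc-weight is an assumption based on fairseq's usual field-to-flag mapping and does not appear in this diff.

# Sketch only: reading the new option at inference time (cfg is a parsed fairseq config).
ctc_weight = getattr(cfg.generation, "infer_ctc_weight", 0.0)
# It interpolates attention and CTC scores, so values are expected in [0, 1];
# the default 0.0 keeps pure attention decoding.
assert 0.0 <= ctc_weight <= 1.0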
@@ -13,7 +13,10 @@ from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
from torch import Tensor
from fairseq.ngram_repeat_block import NGramRepeatBlock
from espnet.nets.ctc_prefix_score import CTCPrefixScore
import numpy
CTC_SCORING_RATIO = 1.5
class SequenceGenerator(nn.Module):
def __init__(
@@ -35,6 +38,7 @@ class SequenceGenerator(nn.Module):
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
ctc_weight=0.0,
):
"""Generates translations of a given source sentence.
@@ -67,6 +71,8 @@ class SequenceGenerator(nn.Module):
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos() if eos is None else eos
self.blank = tgt_dict.index(self.blank_symbol) if hasattr(self, 'blank_symbol') else 0
self.symbols_to_strip_from_output = (
symbols_to_strip_from_output.union({self.eos})
if symbols_to_strip_from_output is not None
@@ -107,6 +113,7 @@ class SequenceGenerator(nn.Module):
self.lm_model = lm_model
self.lm_weight = lm_weight
self.ctc_weight = ctc_weight
if self.lm_model is not None:
self.lm_model.eval()
@@ -245,6 +252,22 @@ class SequenceGenerator(nn.Module):
# compute the encoder output for each beam
encoder_outs = self.model.forward_encoder(net_input)
# Get CTC lprobs and prep ctc_scorer
if self.ctc_weight > 0:
if encoder_outs[0].get("xctc_logit", None) and len(encoder_outs[0]["xctc_logit"]) > 0:
ctc_logit = encoder_outs[0]["xctc_logit"][0]
else:
ctc_logit = encoder_outs[0]["ctc_logit"][0]
input_length = ctc_logit.size(0)
ctc_lprobs = utils.log_softmax(ctc_logit, dim=-1).contiguous().transpose(0, 1)
hyp = {}
ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
hyp["ctc_score_prev"] = 0.0
ctc_beam = min(ctc_lprobs.shape[-1], int(beam_size * CTC_SCORING_RATIO))
ctc_hyps = {str(self.eos): hyp}
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
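The hunk above wires ESPnet's CTCPrefixScore into beam search; note that it scores ctc_lprobs[0] only, so joint CTC decoding appears to assume inference with batch size 1. For reference, a minimal self-contained sketch of how that scorer is driven (toy sizes and ids; variable names here are illustrative, not from the commit):

import numpy as np
from espnet.nets.ctc_prefix_score import CTCPrefixScore

T, V, blank_id, eos_id = 50, 1000, 0, 2                    # toy sizes and ids
lprobs = np.log(np.random.dirichlet(np.ones(V), size=T))   # (T, V) CTC log-probs for one utterance
scorer = CTCPrefixScore(lprobs, blank_id, eos_id, np)      # same positional arguments as above
state = scorer.initial_state()
prefix = np.array([eos_id])                                # decoded prefix, starting from eos/sos
cands = np.array([5, 7, 11])                               # candidate next tokens to score
scores, states = scorer(prefix, cands, state)
# scores[j] is the cumulative CTC prefix log-probability of prefix + [cands[j]];
# the generator stores states[j] as "ctc_state_prev" and uses
# scores[j] - previous cumulative score as the per-step CTC contribution.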
@@ -329,6 +352,41 @@ class SequenceGenerator(nn.Module):
self.temperature,
)
if self.ctc_weight > 0 and step <= input_length:
ctc_lprobs = lprobs.clone()
ctc_lprobs[:, self.blank] = -math.inf # never select blank
local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
for b in range(tokens.size(0)):
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
# Add the missing beam
if hyp_key not in ctc_hyps:
prev_hyp_key = " ".join(str(x) for x in tokens[b, : step].tolist())
local_best_ids[b] = tokens[b, step]
ctc_scores, ctc_states = ctc_prefix_score(
tokens[b, : step].cpu(), local_best_ids[b].cpu(), ctc_hyps[prev_hyp_key]["ctc_state_prev"]
)
for j in range(len(local_best_ids[b])):
new_key = prev_hyp_key + " " + str(local_best_ids[b][j].item())
ctc_hyps[new_key] = {}
ctc_hyps[new_key]["ctc_score_prev"] = ctc_scores[j]
ctc_hyps[new_key]["ctc_state_prev"] = ctc_states[j]
ctc_scores, ctc_states = ctc_prefix_score(
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
)
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
                    ).to(lprobs)  # follow lprobs' device/dtype instead of hardcoding CUDA
for j in range(len(local_best_ids[b])):
if step == 0 and b != 0:
continue
else:
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
if self.lm_model is not None:
lm_out = self.lm_model(tokens[:, : step + 1])
probs = self.lm_model.get_normalized_probs(
@@ -340,6 +398,7 @@ class SequenceGenerator(nn.Module):
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
            lprobs[:, self.blank] = -math.inf  # never select blank
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle max length constraint
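Taken together, the selection rule applied to each of the ctc_beam best candidates can be summarised as the sketch below (an illustrative helper, not part of the commit; in the code above the interpolation is written in place into lprobs, and LM fusion remains additive exactly as before):

def joint_score(attn_lprob, ctc_prefix_new, ctc_prefix_prev,
                ctc_weight, lm_lprob=0.0, lm_weight=0.0):
    # incremental CTC prefix score gained by appending this candidate token
    ctc_step = ctc_prefix_new - ctc_prefix_prev
    return ((1.0 - ctc_weight) * attn_lprob
            + ctc_weight * ctc_step
            + lm_weight * lm_lprob)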
@@ -512,7 +512,10 @@ class SpeechToTextTask(LegacyFairseqTask):
for s, i in self.tgt_dict.indices.items()
if SpeechToTextDataset.is_lang_tag(s)
}
extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids}
if extra_gen_cls_kwargs is None:
extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids}
else:
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
return super().build_generator(
models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
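The change above merges the language tags into whatever extra_gen_cls_kwargs the caller already passed (lm_model, lm_weight, and now ctc_weight) instead of overwriting the dict. An essentially equivalent, more compact form would be:

extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids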
@@ -194,7 +194,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
# Initialize generator
gen_timer = StopwatchMeter()
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight, "ctc_weight": cfg.generation.infer_ctc_weight}
generator = task.build_generator(
models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
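A slightly more defensive variant of this hook (a sketch, not part of the commit) would forward the new keyword only when joint decoding is requested, so that generator classes which do not accept ctc_weight keep working when the weight is left at its default:

extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
if getattr(cfg.generation, "infer_ctc_weight", 0.0) > 0:
    extra_gen_cls_kwargs["ctc_weight"] = cfg.generation.infer_ctc_weight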