......@@ -12,6 +12,8 @@ from typing import Optional
import numpy as np
import logging
import editdistance
import os
import sys
import torch
import torch.nn.functional as F
......@@ -62,14 +64,28 @@ class CtcCriterionConfig(FairseqDataclass):
metadata={"help": "weight of interleaved CTC loss for target sentence"},
cal_all_ctc: bool = field(
metadata={"help": "calculate all ctc results"},
ctc_self_distill_weight: float = field(
metadata={"help": "weight of the self distillation CTC loss"},
target_ctc_self_distill_weight: float = field(
metadata={"help": "weight of the self distillation CTC loss for target sentence"},
ctc_self_distill_prob: float = field(
metadata={"help": "probability to use distillation loss"},
ctc_self_distill_temperature: float = field(
metadata={"help": "temperature for ctc self distillation"},
wer_kenlm_model: Optional[str] = field(
......@@ -100,7 +116,7 @@ class CtcCriterionConfig(FairseqDataclass):
@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0, save_dir=None):
if cfg.wer_args is not None:
......@@ -136,29 +152,45 @@ class CtcCriterion(FairseqCriterion):
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
self.sentence_avg = cfg.sentence_avg
self.save_dir = save_dir
self.cal_all_ctc = cfg.cal_all_ctc
self.ctc_weight = ctc_weight
self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_self_distill_prob = cfg.ctc_self_distill_prob
self.target_ctc_self_distill_weight = float(cfg.target_ctc_self_distill_weight)
self.ctc_self_distill_prob = float(cfg.ctc_self_distill_prob)
self.ctc_self_distill_temperature = float(cfg.ctc_self_distill_temperature)
self.ctc_entropy = cfg.ctc_entropy
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \
if self.all_ctc_weight > 0:
# assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
self.ctc_names = []
self.use_ctc = (self.ctc_weight + self.interleaved_ctc_weight > 0)
self.use_target_ctc = (self.target_ctc_weight + self.target_interleaved_ctc_weight > 0)
self.use_source_distill = self.use_target_distill = False
if self.ctc_self_distill_prob > 0:
if self.ctc_self_distill_weight:
self.use_source_distill = True
if self.target_ctc_self_distill_weight > 0:
self.use_target_distill = True
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
ntokens = sample["ntokens"]
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
......@@ -168,15 +200,119 @@ class CtcCriterion(FairseqCriterion):
loss, logging_output = self.compute_ctc_loss(model, sample, net_output, logging_output)
return loss, sample_size, logging_output
def get_loss(self, lprobs, targets_flat, input_lengths, transcript_lengths):
def get_ground_truth_alignment(self, model, sample):
ctc_alignment_oracle = dict()
from fairseq.torch_imputer import best_alignment, imputer_loss
except ImportError:
logger.warning("Imputer is not available.")
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
with torch.no_grad():
encoder_out = model.encoder(src_tokens, src_lengths)
ctc_logit = None
if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
ctc_logit = encoder_out["ctc_logit"][0]
elif "interleaved_ctc_logits" in encoder_out and len(encoder_out["interleaved_ctc_logits"]) != 0:
ctc_logit = encoder_out["interleaved_ctc_logits"][-1]
ctc_alignment_oracle["source"] = None
if ctc_logit is not None:
if "transcript" in sample:
tokens = sample["transcript"]["tokens"]
tokens = sample["target"]
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
target_lengths = pad_mask.sum(-1)
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
best_aligns = best_alignment(ctc_logit.float(), tokens, input_lengths, target_lengths,
self.pad_idx, zero_infinity=True)
best_aligns_pad = torch.tensor([a + [0] * (ctc_logit.size(0) - len(a)) for a in best_aligns],
device=ctc_logit.device, dtype=tokens.dtype)
oracle_pos = torch.div(best_aligns_pad, 2, rounding_mode='floor').clip(max=tokens.shape[1] - 1)
oracle = tokens.gather(-1, oracle_pos)
source_oracle = oracle.masked_fill(best_aligns_pad % 2 == 0, self.blank_idx)
ctc_alignment_oracle["source"] = [source_oracle, best_aligns_pad]
ctc_alignment_oracle["target"] = None
target_ctc_logit = None
if "target_ctc_logit" in encoder_out and len(encoder_out["target_ctc_logit"]) != 0:
target_ctc_logit = encoder_out["ctc_logit"][0]
elif "target_interleaved_ctc_logits" in encoder_out and len(
encoder_out["target_interleaved_ctc_logits"]) != 0:
target_ctc_logit = encoder_out["target_interleaved_ctc_logits"][-1]
if target_ctc_logit is not None:
target_tokens = sample["target"]
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_lengths = target_pad_mask.sum(-1)
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
best_aligns = best_alignment(target_ctc_logit.float(), target_tokens, input_lengths, target_lengths,
self.pad_idx, zero_infinity=True)
best_aligns_pad = torch.tensor([a + [0] * (ctc_logit.size(0) - len(a)) for a in best_aligns],
device=target_ctc_logit.device, dtype=target_tokens.dtype)
oracle_pos = (best_aligns_pad // 2).clip(max=tokens.shape[1] - 1)
oracle = tokens.gather(-1, oracle_pos)
target_oracle = oracle.masked_fill(best_aligns_pad % 2 == 0, self.blank_idx)
ctc_alignment_oracle["target"] = [target_oracle, best_aligns_pad]
return ctc_alignment_oracle
def get_ctc_loss(self, model, ctc_logit, targets, input_lengths, target_lengths, loss_coef):
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
loss = 0
with torch.backends.cudnn.flags(enabled=False):
ctc_loss = self.ctc_loss(
for item_targets, item_target_lengths, item_coef in zip(targets, target_lengths, loss_coef):
loss += self.ctc_loss(
) * item_coef
return loss, lprobs
def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask):
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
for i in range(distill_num):
logit = student_logits[i]
if type(logit) == list:
student_logit = logit[0]
non_padding_mask = ~logit[1]
student_logit = logit
if student_logit.size() != teacher_logit.size():
loss = F.kl_div(
F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
return ctc_loss
ctc_self_distill_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
ctc_self_distill_num += 1
return ctc_self_distill_num, ctc_self_distill_loss
def compute_ctc_loss(self, model, sample, net_output, logging_output):
if "transcript" in sample:
......@@ -187,62 +323,36 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["ctc_padding_mask"][0]
non_padding_mask = ~net_output["encoder_padding_mask"][0]
# non_padding_mask = ~net_output["encoder_padding_mask"][0]
mixup = False
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
if "mixup" in net_output and net_output["mixup"] is not None:
mixup = True
mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (tokens != self.pad_idx) & (
tokens != self.eos_idx
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
transcript_flat1 = tokens[[mixup_idx1]].masked_select(mask1)
transcript_flat2 = tokens[mixup_idx2].masked_select(mask2)
transcripts1 = tokens[[mixup_idx1]].masked_select(mask1)
transcripts2 = tokens[mixup_idx2].masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
transcript_flat = [transcript_flat1, transcript_flat2]
transcripts = [transcripts1, transcripts2]
transcript_lengths = [transcript_lengths1, transcript_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
transcript_flat = [tokens.masked_select(pad_mask)]
mixup = False
transcripts = [tokens.masked_select(pad_mask)]
transcript_lengths = [pad_mask.sum(-1)]
loss_coef = [1]
ctc_loss = 0
ctc_entropy = 0
all_ctc_logits = dict()
self.ctc_names = []
lprobs = None
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
ctc_logit = net_output["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
ctc_loss += self.get_loss(lprobs, flat, input_lengths, lengths) * coef
if self.ctc_entropy > 0:
if self.ctc_entropy_cutoff != 0:
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(
logging_output["ctc_loss"] = utils.item(
target_lprobs = None
interleaved_ctc_num = 0
interleaved_ctc_loss = 0
......@@ -251,128 +361,170 @@ class CtcCriterion(FairseqCriterion):
# calculate the interleaved CTC loss
if self.interleaved_ctc_weight > 0 and interleaved_ctc_num > 0:
logits = net_output["interleaved_ctc_logits"]
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
inter_input_lengths = padding.long().sum(-1)
logit = logits[i]
if type(logit) == list:
inter_ctc_logit = logit[0]
inter_input_lengths = (~logit[1]).long().sum(-1)
inter_ctc_logit = out
inter_ctc_logit = logit
inter_input_lengths = input_lengths
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
interleaved_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
all_ctc_logits["interleaved_ctc_logit%d" % i] = [inter_ctc_logit, inter_input_lengths]
inter_loss, inter_lprobs = self.get_ctc_loss(
model, inter_ctc_logit, transcripts, inter_input_lengths, transcript_lengths, loss_coef)
interleaved_ctc_loss += inter_loss
lprobs = inter_lprobs
interleaved_ctc_loss /= interleaved_ctc_num
logging_output["interleaved_ctc_loss"] = utils.item(
if lprobs is None:
lprobs = inter_lprobs
ctc_loss = 0
ctc_entropy = 0
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
ctc_logit = net_output["ctc_logit"][0]
all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
target_ctc_loss = 0
target_interleaved_ctc_loss = 0
ctc_loss, lprobs = self.get_ctc_loss(
model, ctc_logit, transcripts, input_lengths, transcript_lengths, loss_coef)
if self.ctc_entropy > 0:
if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(
logging_output["ctc_loss"] = utils.item(
# calculate the target CTC loss
if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
target_ctc_loss = 0
target_interleaved_ctc_loss = 0
target_interleaved_ctc_num = 0
if self.use_target_ctc:
target_tokens = sample["target"]
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_no_padding_mask = ~target_pad_mask
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
target_flat1 = target.masked_select(mask1)
target_flat2 = target.masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
target_flat = [target_flat1, target_flat2]
target_length = [transcript_lengths1, transcript_lengths2]
mask1 = target_pad_mask[mixup_idx1]
mask2 = target_pad_mask[mixup_idx2]
target_tokens1 = target_tokens.masked_select(mask1)
target_tokens2 = target_tokens.masked_select(mask2)
target_lengths1 = mask1.sum(-1)
target_lengths2 = mask2.sum(-1)
target_tokens = [target_tokens1, target_tokens2]
target_lengths = [target_lengths1, target_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
target_flat = [target.masked_select(pad_mask)]
target_length = [pad_mask.sum(-1)]
target_tokens = [target_tokens.masked_select(target_pad_mask)]
target_lengths = [target_pad_mask.sum(-1)]
loss_coef = [1]
if self.target_ctc_weight > 0:
assert "target_ctc_logit" in net_output
target_ctc_logit = net_output["target_ctc_logit"]
tgt_lprobs = model.get_normalized_probs(
[target_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
tgt_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
target_interleaved_ctc_num = 0
if "target_interleaved_ctc_logits" in net_output:
target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
if target_interleaved_ctc_num != 0 and self.target_interleaved_ctc_weight > 0:
logits = net_output["target_interleaved_ctc_logits"]
for i in range(target_interleaved_ctc_num):
out = net_output["target_interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
tgt_input_lengths = padding.long().sum(-1)
logit = logits[i]
if type(logit) == list:
target_inter_ctc_logit = logit[0]
padding = ~logit[1]
inter_input_lengths = padding.long().sum(-1)
inter_ctc_logit = out
tgt_input_lengths = input_lengths
tgt_inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
tgt_inter_lprobs.batch_first = False
target_inter_ctc_logit = logit
inter_input_lengths = input_lengths
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths,
lengths) * coef
all_ctc_logits["target_interleaved_ctc_logit%d" % i] = [target_inter_ctc_logit, inter_input_lengths]
inter_loss, target_inter_lprobs = self.get_ctc_loss(
model, target_inter_ctc_logit, target_tokens, inter_input_lengths, target_lengths, loss_coef)
target_interleaved_ctc_loss += inter_loss
target_lprobs = target_inter_lprobs
target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_interleaved_ctc_loss"] = utils.item(
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
torch.rand(1).uniform_() < self.ctc_self_distill_prob:
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
non_padding_mask = ~out[1]
inter_ctc_logit = out
if self.target_ctc_weight > 0:
assert "target_ctc_logit" in net_output
target_ctc_logit = net_output["target_ctc_logit"][0]
all_ctc_logits["target_ctc_logit"] = [target_ctc_logit, input_lengths]
if inter_ctc_logit.size() != ctc_logit.size():
target_ctc_loss, target_lprobs = self.get_ctc_loss(
model, target_ctc_logit, target_tokens, input_lengths, target_lengths, loss_coef)
logging_output["target_ctc_loss"] = utils.item(
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_num += 1
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
if self.use_source_distill or self.use_target_distill:
ctc_self_distill_choice = torch.rand(1).uniform_()
if ctc_self_distill_num != 0:
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(
cal_source_distill = cal_target_distill = False
if not
cal_source_distill = True if self.use_source_distill else False
cal_target_distill = True if self.use_target_distill else False
if ctc_self_distill_choice <= self.ctc_self_distill_prob:
if self.use_source_distill and self.use_target_distill:
cal_source_distill = True if ctc_self_distill_choice > self.ctc_self_distill_prob / 2 else False
cal_target_distill = not cal_source_distill
cal_source_distill = self.use_source_distill
cal_target_distill = self.use_target_distill
# source self distillation
if cal_source_distill:
ctc_self_distill_num = 0
non_padding = non_padding_mask
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
teacher_logit = ctc_logit
student_logits = net_output["interleaved_ctc_logits"]
ctc_self_distill_num = interleaved_ctc_num
elif self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 1:
teacher_logit = net_output["interleaved_ctc_logits"][-1]
student_logits = net_output["interleaved_ctc_logits"][:-1]
ctc_self_distill_num = interleaved_ctc_num - 1
if ctc_self_distill_num != 0:
ctc_self_distill_num, source_ctc_self_distill_loss = \
ctc_self_distill_num, teacher_logit, student_logits, non_padding)
source_ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(
ctc_self_distill_loss += source_ctc_self_distill_loss * self.ctc_self_distill_weight
# target self distillation
if cal_target_distill:
ctc_self_distill_num = 0
non_padding = non_padding_mask
if self.target_ctc_weight > 0 and self.target_ctc_self_distill_weight > 0 and target_interleaved_ctc_num > 0:
teacher_logit = target_ctc_logit
student_logits = net_output["target_interleaved_ctc_logits"]
ctc_self_distill_num = target_interleaved_ctc_num
elif self.target_ctc_self_distill_weight > 0 and target_interleaved_ctc_num > 1:
teacher_logit = net_output["target_interleaved_ctc_logits"][-1]
student_logits = net_output["target_interleaved_ctc_logits"][:-1]
ctc_self_distill_num = target_interleaved_ctc_num - 1
if ctc_self_distill_num != 0:
ctc_self_distill_num, target_ctc_self_distill_loss = \
ctc_self_distill_num, teacher_logit, student_logits, non_padding)
target_ctc_self_distill_loss /= ctc_self_distill_num
logging_output["target_ctc_self_distill_loss"] = utils.item(
ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
loss = \
self.ctc_weight * ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
logging_output["all_ctc_loss"] = utils.item(
......@@ -386,74 +538,58 @@ class CtcCriterion(FairseqCriterion):
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
if not and self.ctc_weight + self.interleaved_ctc_weight > 0:
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
target = tokens
if mixup:
idx = mixup_idx1
if mixup_coef < 0.5:
idx = mixup_idx2
target = target[idx]
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
if not
if hasattr(model.encoder, "ctc_valid"):
if lprobs is not None:
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
tokens = tokens[idx]
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
lprobs_t, tokens, input_lengths, self.task.source_dictionary, lang="source")
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
if target_lprobs is not None:
target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu()
target_tokens = sample["target"]
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
target_tokens = target_tokens[idx]
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
target_lprobs_t, target_tokens, input_lengths, self.task.target_dictionary, lang="target")
logging_output["target_wv_errors"] = wv_errs
logging_output["target_w_errors"] = w_errs
logging_output["target_w_total"] = w_len
logging_output["target_c_errors"] = c_err
logging_output["target_c_total"] = c_len
if self.cal_all_ctc:
logging_output["save_dir"] = self.save_dir
for name, items in all_ctc_logits.items():
logit, lengths = items
if "target" in name:
dictionary = self.task.target_dictionary
ctc_tokens = target_tokens
lang = "target"
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dictionary = self.task.source_dictionary
ctc_tokens = tokens
lang = "source"
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
logit, ctc_tokens, lengths, dictionary, lang)
cer = c_err * 100 / c_len
wer = w_errs * 100 / w_len
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
logging_output["dump_%s_cer" % name] = cer
logging_output["dump_%s_wer" % name] = wer
return loss, logging_output
......@@ -479,6 +615,9 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
target_ctc_self_distill_loss_sum = utils.item(
sum(log.get("target_ctc_self_distill_loss", 0) for log in logging_outputs)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
......@@ -552,6 +691,13 @@ class CtcCriterion(FairseqCriterion):
if target_ctc_self_distill_loss_sum > 0:
target_ctc_self_distill_loss_sum / sample_size / math.log(2),
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
......@@ -589,14 +735,48 @@ class CtcCriterion(FairseqCriterion):
if meters["_w_total"].sum > 0
else float("nan"),
# metrics.log_derived(
# "raw_wer",
# lambda meters: safe_round(
# meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
# )
# if meters["_w_total"].sum > 0
# else float("nan"),
# )
target_c_errors = sum(log.get("target_c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_target_c_errors", target_c_errors)
target_c_total = sum(log.get("target_c_total", 0) for log in logging_outputs)
metrics.log_scalar("_target_c_total", target_c_total)
target_w_errors = sum(log.get("target_w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_target_w_errors", target_w_errors)
target_w_total = sum(log.get("target_w_total", 0) for log in logging_outputs)
metrics.log_scalar("_target_w_total", target_w_total)
if target_c_total > 0:
lambda meters: safe_round(
meters["_target_c_errors"].sum * 100.0 / meters["_target_c_total"].sum, 3
if meters["_target_c_total"].sum > 0
else float("nan"),
if target_w_total > 0:
lambda meters: safe_round(
meters["_target_w_errors"].sum * 100.0 / meters["_target_w_total"].sum, 3
if meters["_target_w_total"].sum > 0
else float("nan"),
# save_dir = logging_outputs.get("save_dir", None)
# if save_dir is not None and os.path.exists(save_dir):
# out = open(os.path.join(save_dir, "ctc_results"), "a")
# else:
# out = sys.stdout
# for key in logging_outputs:
# if key.startswith("dump"):
# print("%s: %.2f" % (key, logging_outputs[key]), end="\t", file=out)
# print("", file=out)
# out.close()
# out = sys.stdout
def logging_outputs_can_be_summed() -> bool:
......@@ -23,12 +23,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
def __init__(self, task, label_smoothing,
cfg: CtcCriterionConfig,
super().__init__(task, sentence_avg, label_smoothing)
self.report_accuracy = True
self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
self.save_dir = save_dir
def add_args(parser):
......@@ -62,7 +64,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
encoder_out = model.encoder(src_tokens, src_lengths,
text_src_tokens, text_src_lengths)
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
if and getattr(model.encoder, "sae_ground_truth_ratio", 0) != 0:
ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
encoder_out = model.encoder(src_tokens, src_lengths,
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
......@@ -2,7 +2,6 @@ import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.models import (
......@@ -138,7 +137,6 @@ class CTCDecoder(object):
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(self.beam_size, self.vocab_size - 1)
# from fairseq.sequence_generator import EnsembleModel
from fairseq.sequence_generator import EnsembleModel
if isinstance(models, EnsembleModel):
self.model = models
......@@ -240,8 +238,12 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
......@@ -276,7 +278,9 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
......@@ -310,8 +314,6 @@ def base_architecture(args):
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
@register_model_architecture("s2t_ctc", "s2t_ctc_s")
def s2t_ctc_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
......@@ -125,7 +125,7 @@ class S2TSATEModel(S2TTransformerModel):
# target CTC
help="ctc layer for target sentence",
......@@ -233,15 +233,15 @@ class S2TSATEModel(S2TTransformerModel):
return cls(encoder, decoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, **kwargs
return decoder_out
......@@ -286,7 +286,9 @@ class TextualEncoder(FairseqEncoder):
self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = getattr(args, "target_ctc_layer", layer_num)
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.ctc_layer == 0:
self.ctc_layer = layer_num
self.inter_ctc = True if self.ctc_layer != layer_num else False
if self.inter_ctc:"Target CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
......@@ -294,13 +296,16 @@ class TextualEncoder(FairseqEncoder):
need_layernorm=True if self.inter_ctc else False)
if embed_tokens is not None:
if embed_tokens is not None and args.share_target_ctc_and_embed and \
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if self.target_interleaved_ctc_layers is not None:
target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
for layer_idx in target_interleaved_ctc_layers:
......@@ -337,7 +342,7 @@ class TextualEncoder(FairseqEncoder):
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None):
def forward(self, x, encoder_padding_mask=None, history=None, **kwargs):
if self.encoder_embed_norm:
x = self.embed_ln(x)
......@@ -356,7 +361,7 @@ class TextualEncoder(FairseqEncoder):
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone())
target_ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Target Layer %d" % layer_idx)
if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
......@@ -365,11 +370,23 @@ class TextualEncoder(FairseqEncoder):
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["target"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["target"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
if history is not None:
......@@ -381,7 +398,7 @@ class TextualEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x)
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")
return x, target_ctc_logit, target_interleaved_ctc_logits
......@@ -435,6 +452,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
......@@ -443,11 +461,24 @@ class S2TSATEEncoder(FairseqEncoder):
self.history = None
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
if hasattr(self, "ctc"):
if hasattr(self.acoustic_encoder, "ctc"):
assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
if hasattr(self.textual_encoder, "ctc"):
assert tgt_dict is not None
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
if lang == "source":
if hasattr(self.acoustic_encoder, "ctc"):
return self.acoustic_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
logger.error("No ctc module in textual encoder")
if hasattr(self.textual_encoder, "ctc"):
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
return self.textual_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
logger.error("No ctc module in textual encoder")
def forward(self, src_tokens, src_lengths=None, **kwargs):
if self.history is not None:
......@@ -455,9 +486,9 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_acoustic_encoder:
with torch.no_grad():
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
......@@ -490,16 +521,16 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history, **kwargs)
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history, **kwargs)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []), # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C
"target_ctc_logit": [target_ctc_logit], # B x T x C
"target_interleaved_ctc_logits": target_interleaved_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
......@@ -447,6 +447,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
help="the ratio for ground truth in sae",
help="share the weight of ctc and sae",
......@@ -570,13 +576,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
lprobs.batch_first = True
return lprobs
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
......@@ -655,6 +661,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
......@@ -681,6 +688,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gt_ratio": self.sae_ground_truth_ratio,
"drop_prob": getattr(args, "sae_drop_prob", 0),
......@@ -717,6 +725,14 @@ class S2TTransformerEncoder(FairseqEncoder):
assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
logger.error("No ctc module in textual encoder")
# Store the given flag on this encoder as ``self.debug_var``.
# The method only records the value; it returns nothing.
def set_debug_var(self, debug_var_flag):
self.debug_var = debug_var_flag
......@@ -879,7 +895,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone(), encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Source Layer %d" % layer_idx)
# interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -889,15 +905,27 @@ class S2TTransformerEncoder(FairseqEncoder):
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask)
logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
self.show_debug(x, "x after sae")
# gather cosine similarity
......@@ -917,7 +945,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask)
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
self.show_debug(x, "x after ctc")
return {
......@@ -925,6 +953,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
# "oracle": [oracle, oracle_mask, force_emit],
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
......@@ -315,7 +315,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="upsampling ratio of the representation for CTC calculation",
......@@ -355,6 +355,24 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="share the weight of ctc and sae",
help="use the layer norm for embed output",
help="use the layer norm for final output",
help="the ratio for ground truth in sae",
# fmt: on
......@@ -625,6 +643,11 @@ class TransformerCTCEncoder(FairseqEncoder):"Interleaved CTC loss in layer %d" % layer_idx)
self.un_sample = torch.nn.Upsample(scale_factor=self.interleaved_ctc_upsampling_ratio, mode="linear",
self.down_sample = torch.nn.Upsample(scale_factor=1 / self.interleaved_ctc_upsampling_ratio, mode="linear",
if not self.use_ctc:
self.ctc = CTC(embed_dim,
......@@ -633,10 +656,14 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
strategy = {
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
"gt_ratio": self.sae_ground_truth_ratio,
self.sae = Adapter(embed_dim, args.sae_adapter,
......@@ -645,9 +672,9 @@ class TransformerCTCEncoder(FairseqEncoder):
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
if hasattr(self, "ctc"):
self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
# if hasattr(self, "ctc"):
# self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
# stride=self.interleaved_ctc_upsampling_ratio)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -679,6 +706,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
......@@ -706,7 +734,8 @@ class TransformerCTCEncoder(FairseqEncoder):
return self.forward_scriptable(src_tokens,
def upsampling(self, x, padding):
ratio = self.interleaved_ctc_upsampling_ratio
......@@ -714,12 +743,17 @@ class TransformerCTCEncoder(FairseqEncoder):
return x
if len(x.size()) == 3:
bsz, seq_len, dim = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
# bsz, seq_len, dim = x.size()
# up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
seq_len, bsz, dim = x.size()
x = x.permute(1, 2, 0)
up_x = self.un_sample(x)
up_x = up_x.permute(2, 0, 1)
bsz, seq_len = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)
# output_length = int(seq_len * ratio * 2/3)
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
......@@ -742,6 +776,14 @@ class TransformerCTCEncoder(FairseqEncoder):
assert tgt_dict is not None
self.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
logger.error("No ctc module in textual encoder")
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
......@@ -752,6 +794,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
......@@ -783,10 +826,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.history is not None:
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
# ctc_padding_mask = encoder_padding_mask
# if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
# src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
# ctc_padding_mask = encoder_padding_mask
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
......@@ -796,6 +839,8 @@ class TransformerCTCEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
encoder_states = []
......@@ -824,6 +869,7 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), ctc_padding_mask)
# Interleaved CTC
......@@ -833,18 +879,31 @@ class TransformerCTCEncoder(FairseqEncoder):
if p < self.interleaved_ctc_drop_prob:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, ctc_padding_mask)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, _ = self.sae([norm_x, prob])
# x = x.permute(1, 2, 0)
# x = self.pool(x)
# x = x.permute(2, 0, 1)
# encoder_padding_mask = org_encoder_padding_mask
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
x = x.permute(1, 2, 0)
# x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
x = self.down_sample(x)
x = x.permute(2, 0, 1)
if self.history is not None:
......@@ -98,7 +98,7 @@ class Adapter(nn.Module):
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)"CTC Compress Strategy: %s" % ctc_compress_strategy)
if "league" in self.adapter_type:
if self.cal_context:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
if self.distribution_cutoff is not None:
self.distribution_cutoff = int(self.distribution_cutoff)
......@@ -107,17 +107,21 @@ class Adapter(nn.Module):
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:"Adapter drop probability: %f" % self.drop_prob)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.out_norm = strategy.get("out_norm", False)
if self.out_norm:
self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None):
def forward(self, x, padding=None, oracle=None, oracle_mask=None):
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size()
org_distribution = distribution
distribution = distribution.contiguous().view(-1, distribution.size(-1))
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
linear_out = None
soft_out = None
......@@ -125,18 +129,32 @@ class Adapter(nn.Module):
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
# )
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
distribution = torch.where(
threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
distribution = distribution.view(-1, distribution.size(-1))
soft_out =, self.embed_adapter.weight).view(seq_len, bsz, -1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution = torch.where(
# threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
# )
# distribution = distribution.view(-1, vocab_size)
distribution[:, 0] = 0
distribution = distribution / distribution.sum(-1, keepdim=True)
if self.ground_truth_ratio > 0 and oracle is not None:
oracle = oracle.unsqueeze(-1)
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)).\
to(distribution.dtype).transpose(0, 1)
oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution
soft_out =, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
soft_out =, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm:
soft_out = self.embed_ln(soft_out)
......@@ -37,13 +37,14 @@ class CTC(nn.Module):
self.dictionary = dictionary
self.infer_decoding = False
self.post_process = "sentencepiece"
self.blank_idx = 0
# Configure inference-time decoding for this CTC module.
#   is_infer:          stored as ``self.infer_decoding``
#   text_post_process: stored as ``self.post_process``
#   dictionary:        stored as ``self.dictionary``
# Only the three attributes are overwritten; nothing is returned.
def set_infer(self, is_infer, text_post_process, dictionary):
self.infer_decoding = is_infer
self.post_process = text_post_process
self.dictionary = dictionary
def forward(self, x, padding=None):
def forward(self, x, padding=None, tag=None):
if self.need_layernorm:
x = self.LayerNorm(x)
......@@ -52,7 +53,7 @@ class CTC(nn.Module):
if not and self.infer_decoding:
assert self.dictionary is not None
input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)
return x
......@@ -65,7 +66,7 @@ class CTC(nn.Module):
# Project ``x`` with ``self.ctc_projection`` and return, for every position,
# the index of the largest value along the last dimension.
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
def infer(self, logits_or_probs, lengths):
def infer(self, logits_or_probs, lengths, tag=None):
for lp, inp_l in zip(
......@@ -78,41 +79,47 @@ class CTC(nn.Module):
pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()"\nCTC prediction: %s" % " ".join(pred_words_raw))
if tag is not None:"%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
else:"CTC prediction: %s" % (" ".join(pred_words_raw)))
def valid(self, logits_or_probs, target, lengths):
def valid(self, logits_or_probs, targets, input_lengths, dictionary):
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lp = lp[:inp_l].unsqueeze(0)
with torch.no_grad():
for lp, t, inp_l in zip(
lp = lp[:inp_l].unsqueeze(0)
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
p = (t != dictionary.pad()) & (t != dictionary.eos())
targ = t[p]
targ_units = dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
pred_units = dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
w_len += len(targ_words)
w_len += len(targ_words)
\ No newline at end of file
return c_err, c_len, w_errs, w_len, wv_errs
from .imputer import imputer_loss, ImputerLoss, best_alignment, ctc_decode
// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al:
// We use the equations from above link, but note that [1] has 1-based indexing
// and we (of course) use 0-based. Graves et al call the probabilities y, we use
// log_probs (also calling them inputs). A few optimizations (similar to those
// here, but also some I didn't take) are described in
// 2. Minmin Sun:
#include <ATen/TensorUtils.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <numeric>
#include <type_traits>
using namespace at;
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// i.e. it returns l'[idx]:
// - even idx -> BLANK
// - odd idx  -> target[offset + stride * (idx / 2)]
// where `offset` is added to every lookup and `stride` is the distance
// between consecutive labels in `target`.
// - note that no bound-checking is done
// - if the target length is 0, it is important to only call it with idx == 0
// - __restrict__ impact to be measured, see
template <typename target_t>
__device__ static inline int64_t
get_target_prime(const target_t *__restrict__ target, int64_t offset,
int64_t stride, int64_t idx, int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
return target[offset + stride * (idx / 2)];
// this kernel is a relatively straightforward implementation of the alpha
// calculation in the forward backward algorithm (section 4.1). A (minor) twist
// is that we are using log-calculations to enhance numerical stability
// (log_probs and log_alpha). In total it would be more efficient to compute the
// beta in the same kernel (e.g. cudnn does this). While the beta are not needed
// for the loss itself (just the grad), we can return log_alpha+log_beta (so
// same space as currently) and the overhead is small and the use-case for loss
// without grad is relatively limited. We parallelize by batch and target
// sequence. Empirically, it is faster to loop over the input (log probs)
// sequence and do target in parallel, even if it means more frequent
// __syncthreads. In contrast to the cuDNN implementation, we allow large target
// lengths. For this we need that all previous `s` have been computed when we
// start a new block_s. This is why we have our own for loop here.
// Best-path flavour of the alpha recursion: for every cell (t, s) the kernel
// stores max(predecessors) + log_prob (not the log-sum of the predecessors)
// in log_alpha_data, and records in paths_data which predecessor state
// (s, s - 1 or s - 2) supplied that maximum.
// Thread mapping: the y index selects the batch element b, the x index selects
// the augmented-target position s, processed in chunks of blockDim.x.
// Outputs written: log_alpha_data, paths_data and neg_log_likelihood_data[b].
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ log_alpha_data, int64_t *__restrict__ paths_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
scalar_t *__restrict__ neg_log_likelihood_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, const int64_t *__restrict__ tg_batch_offsets,
int64_t tg_target_stride, int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
// bookkeeping
// NOTE(review): input_lengths, target_lengths and tg_batch_offsets are read
// with index b before the b >= batch_size guard below; confirm this cannot
// read past the end of those arrays for the surplus threads of the last block.
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// guard for threads whose batch index lies past the end of the batch
if (b >= batch_size)
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
// s == 0: the leading blank
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
case 1:
// s == 1: the first target label, or -inf when the target is empty
la = target_length == 0
? neginf
: log_probs_data[lp_batch_offset +
lp_char_stride *
targets_data, tg_batch_offset,
tg_target_stride, 1, BLANK)];
la = neginf;
// only positions inside the allocated 2 * max_target_length + 1 columns
// are written
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset +
/* la_input_stride * 0 */ +la_target_stride * s] = la;
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
// the s - 2 predecessor is only allowed when it carries a different
// symbol than the current position
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
for (int64_t t = 1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch
if ((t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This follows equation (6) and (7), but keeps the
// best predecessor instead of summing: la1, la2, la3 are the three
// candidates at t - 1, lamax is their maximum and max_path is the state
// index (s, s - 1 or s - 2) that lamax came from.
scalar_t la1 =
log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * s];
scalar_t lamax = la1;
int64_t max_path = s;
scalar_t la2, la3;
if (s > 0) {
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 1)];
if (la2 > lamax) {
lamax = la2;
max_path = s - 1;
} else {
la2 = neginf;
if (have_three) {
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 2)];
if (la3 > lamax) {
lamax = la3;
max_path = s - 2;
} else {
la3 = neginf;
/*if (lamax == neginf) // when all are neginf. (then the whole thing is
// neginf, but we can pretend)
lamax = 0;*/
// log_alpha and paths share the same index because they are addressed
// with the same la_* strides
int64_t log_alpha_i =
la_batch_offset + la_input_stride * t + la_target_stride * s;
int64_t log_prob_i = lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_char;
log_alpha_data[log_alpha_i] = lamax + log_probs_data[log_prob_i];
paths_data[log_alpha_i] = max_path;
} else {
// otherwise we just set to neginf
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
__syncthreads(); // on cuda 9 we might use partial synchronization of only the
// threads within the same batch
// compute the loss (eq (8)): log-sum-exp of the two admissible final states
// at t = input_length - 1, namely the trailing blank (s = 2 * target_length)
// and, for a non-empty target, the last label (s = 2 * target_length - 1)
if (threadIdx.x == 0) {
scalar_t l1 =
log_alpha_data[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2)];
scalar_t l2 =
target_length > 0
? log_alpha_data[la_batch_offset +
la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
// when both final states are -inf, subtract 0 instead of -inf so that
// l1 - m and l2 - m stay -inf rather than becoming NaN
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_data[b] = -log_likelihood;
// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each
// row) and a 1-dimensional tensor where all targets are concatenated (and we
// use target_lengths to figure out where they begin). We return log_alpha
// (currently; this might change to log_alpha+log_beta) to be passed to the
// backward.
// Returns the tuple (neg_log_likelihood, log_alpha, paths).
// NOTE(review): this comment used to say "The dispatch function will only
// return the loss"; best_alignment_op is declared with the same three-tensor
// return type, so that sentence looks stale - confirm.
template <typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor, Tensor>
best_alignment_gpu_template(const Tensor &log_probs, const Tensor &targets,
IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK) {
// log_probs: input_len x batch_size x num_labels
// targets [int or int64, per target_scalar_type]: batch_size x target_length
// OR sum(target_lengths)
CheckedFrom c = "ctc_alignment_gpu";
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
auto targets_arg = TensorArg(targets, "targets", 2);
checkAllSameGPU(c, {log_probs_arg, targets_arg});
checkScalarType(c, targets_arg, target_scalar_type);
checkDim(c, log_probs_arg, 3);
checkDimRange(c, targets_arg, 1, 3);
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels),
"blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size,
"input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size,
"target_lengths must be of size batch_size");
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length = 0;
// per-batch start offsets into targets; filled on the CPU here and moved to
// the GPU further down
auto tg_batch_offsets =
at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
// each sequence starts where the previous one ended
int64_t pos = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
tg_target_stride = targets.stride(0);
// the concatenated tensor must hold exactly sum(target_lengths) labels
checkSize(c, targets_arg, 0, pos);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length,
" at dimension 1, but got size ", targets.size(1), " for ",
targets_arg, " (while checking arguments for ", c, ")");
int64_t max_input_length = log_probs.size(0);
for (int64_t b = 0; b < batch_size; b++) {
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected input_lengths to have value at most ",
max_input_length, ", but got value ", input_lengths[b],
" (while checking arguments for ", c, ")");
// length tensors are built with targets' options, dtype overridden to int64
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
// log_alpha: batch_size x input_len x (2 * max_target_length + 1)
Tensor log_alpha =
at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1},
// paths has log_alpha's shape, int64 dtype, and is pre-filled with -1
Tensor paths = at::full_like(log_alpha, -1, log_alpha.options().dtype(kLong));
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
// Very likely, we could be more clever here, e.g. learning (or generalizing
// and reusing) from
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 512
: 896; // we need 72 or so 32 bit registers for double
// halve the x-dimension thread count for as long as the halved count still
// covers all 2 * max_target_length + 1 augmented-target positions
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
// the rest of the thread budget goes to the batch (y) dimension
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
ctc_alignment_log_alpha_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_alpha.data_ptr<scalar_t>(), paths.data_ptr<int64_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(),
log_probs.size(0), targets.data_ptr<target_t>(),
target_lengths_t.data_ptr<int64_t>(), max_target_length,
neg_log_likelihood.data_ptr<scalar_t>(), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors
return std::make_tuple(neg_log_likelihood, log_alpha, paths);
// Entry point for the best-alignment computation on the GPU.
// Chooses the best_alignment_gpu_template instantiation from the dtype of
// log_probs (scalar_t) and from the dtype of targets: kLong targets use the
// kLong instantiation, anything else uses the kInt one. The template's result
// is returned as-is; zero_infinity is accepted but not used here.
std::tuple<Tensor, Tensor, Tensor>
best_alignment_op(const Tensor &log_probs, const Tensor &targets,
IntArrayRef input_lengths, IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity) {
(void)zero_infinity; // only used for backward
log_probs.scalar_type(), "ctc_alignment_cuda", [&] {
if (targets.scalar_type() == kLong) {
return best_alignment_gpu_template<scalar_t, kLong>(
log_probs, targets, input_lengths, target_lengths, BLANK);
} else {
return best_alignment_gpu_template<scalar_t, kInt>(
log_probs, targets, input_lengths, target_lengths, BLANK);
#include <torch/extension.h>
#include <tuple>
// Forward declarations of the ops called by the tensor-argument wrappers
// below. They take the sequence lengths as at::IntArrayRef rather than as
// tensors.

// Imputer loss forward; returns two tensors.
std::tuple<torch::Tensor, torch::Tensor>
imputer_loss_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, at::IntArrayRef input_lengths,
at::IntArrayRef target_lengths, int64_t BLANK,
bool zero_infinity);
// Imputer loss backward; takes the incoming gradient plus the
// neg_log_likelihood and log_alpha tensors and returns a single tensor.
torch::Tensor imputer_loss_backward_op(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity);
// Best CTC alignment; returns three tensors.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity);
// Tensor-argument wrapper around imputer_loss_op.
// The op expects the lengths as at::IntArrayRef, so il / tl are built as
// views over the data of the contiguous int64 tensors ilc / tlc.
// IntArrayRef does not own its data: ilc and tlc must stay alive until the
// op call has returned, which is why they are kept as named locals.
// Returns the op's result unchanged.
std::tuple<torch::Tensor, torch::Tensor> imputer_loss(
const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =, at::kLong).contiguous();
torch::Tensor tlc =, at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
imputer_loss_op(log_probs,, at::kLong),, at::kLong), il, tl,
BLANK, zero_infinity);
return res;
// Wrapper registered with the extension module as "imputer_loss_backward".
// Converts the tensor-valued lengths into at::IntArrayRef views and forwards to
// imputer_loss_backward_op, returning the tensor it produces.
torch::Tensor imputer_loss_backward(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
const torch::Tensor &input_lengths, const torch::Tensor &target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity) {
// NOTE(review): the right-hand sides below look truncated in this excerpt;
// presumably contiguous int64 CPU copies of the length tensors -- confirm.
torch::Tensor ilc =, at::kLong).contiguous();
torch::Tensor tlc =, at::kLong).contiguous();
// il / tl view the int64 storage of ilc / tlc (pointer + element count), so
// ilc / tlc must outlive every use of il / tl.
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
torch::Tensor res;
// NOTE(review): the two arguments after log_probs are truncated in this
// excerpt; presumably targets and force_emits converted to at::kLong.
res = imputer_loss_backward_op(
grad, log_probs,, at::kLong),, at::kLong), il, tl, neg_log_likelihood,
log_alpha, BLANK, zero_infinity);
return res;
// Wrapper registered with the extension module as "best_alignment".
// Converts the tensor-valued lengths into at::IntArrayRef views and forwards to
// best_alignment_op, returning its three-tensor result unchanged.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK,
bool zero_infinity) {
// NOTE(review): the right-hand sides below look truncated in this excerpt;
// presumably contiguous int64 CPU copies of the length tensors -- confirm.
torch::Tensor ilc =, at::kLong).contiguous();
torch::Tensor tlc =, at::kLong).contiguous();
// il / tl view the int64 storage of ilc / tlc (pointer + element count), so
// ilc / tlc must outlive every use of il / tl.
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
// NOTE(review): the argument after log_probs is truncated in this excerpt;
// presumably targets converted to at::kLong -- confirm.
auto res =
best_alignment_op(log_probs,, at::kLong),
il, tl, BLANK, zero_infinity);
return res;
// Module registrations: expose the three wrappers above under the same names.
// NOTE(review): the enclosing module-definition macro line is not visible in
// this excerpt.
m.def("imputer_loss", &imputer_loss, "calculate imputer loss");
m.def("imputer_loss_backward", &imputer_loss_backward,
"calculate imputer loss gradient");
m.def("best_alignment", &best_alignment, "get best alignments for ctc");
\ No newline at end of file
// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al:
// We use the equations from above link, but note that [1] has 1-based indexing
// and we (of course) use 0-based. Graves et al call the probabilities y, we use
// log_probs (also calling them inputs). A few optimizations (similar to those
// here, but also some I didn't take) are described in
// 2. Minmin Sun:
#include <ATen/TensorUtils.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <numeric>
#include <type_traits>
using namespace at;
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// Looks up element `idx` of the blank-augmented target sequence.
//   target  base pointer of the target labels
//   offset  start of this sample's labels within `target`
//   stride  element stride between consecutive labels
//   idx     position in the augmented sequence
//   BLANK   label index of the blank symbol
// Returns BLANK for even idx, otherwise the label at position idx / 2.
// No bounds checking is performed on the read.
template <typename target_t>
__device__ static inline int64_t
get_target_prime(const target_t *__restrict__ target, int64_t offset,
int64_t stride, int64_t idx, int64_t BLANK) {
if (idx % 2 == 0) {
// even positions of the augmented sequence are blanks
return BLANK;
} else {
// odd position idx maps back to original label number idx / 2
return target[offset + stride * (idx / 2)];
// this kernel is a relatively straightforward implementation of the alpha
// calculation in the forward backward algorithm (section 4.1). A (minor) twist
// is that we are using log-calculations to enhance numerical stability
// (log_probs and log_alpha). In total it would be more efficient to compute the
// beta in the same kernel (e.g. cudnn does this). While the beta are not needed
// for the loss itself (just the grad), we can return log_alpha+log_beta (so
// same space as currently) and the overhead is small and the use-case for loss
// without grad is relatively limited. We parallelize by batch and target
// sequence. Empirically, it is faster to loop over the input (log probs)
// sequence and do target in parallel, even if it means more frequent
// __syncthreads. In contrast to the cuDNN implementation, we allow large target
// lengths. For this we need that all previous `s` have been computed when we
// start a new block_s. This is why we have our own for loop here.
// Forward (alpha) kernel with force-emit constraints.
// NOTE(review): the kernel's name line, the closing #endif and several
// structural lines (closing braces, return/break/continue statements) are not
// visible in this excerpt; the matching launch site appears to call it
// imputer_loss_log_alpha_gpu_kernel.
//
// Thread mapping: the y dimension selects the batch item b, the x dimension
// (plus block_s) selects the augmented-target position s.
// Parameters, as used in the body:
//   log_alpha_data           output, written at
//                            la_batch_offset + la_input_stride * t + la_target_stride * s
//   log_probs_data           read at lp_batch_offset + lp_input_stride * t + lp_char_stride * label
//   input_lengths / target_lengths  per-batch lengths, indexed by b
//   force_emits_data         read at fe_batch_offset + t; a value > -1 keeps
//                            only state s == value finite at that time step
//   neg_log_likelihood_data  output, one value per batch item
//   tg_batch_offsets         per-batch start offset into targets_data
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ log_alpha_data, const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const target_t *__restrict__ force_emits_data,
scalar_t *__restrict__ neg_log_likelihood_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t fe_batch_stride, int64_t la_batch_stride,
int64_t la_input_stride, int64_t la_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
// bookkeeping
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
// NOTE(review): input_lengths[b], target_lengths[b] and tg_batch_offsets[b]
// are read before the b >= batch_size guard below, so threads past the end of
// the batch read outside those arrays. Consider moving the guard up.
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
int64_t fe_batch_offset = b * fe_batch_stride;
if (b >= batch_size)
target_t force_emit;
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
// s == 0: start in the leading blank
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
case 1:
// s == 1: start in the first label, impossible for an empty target
la = target_length == 0
? neginf
: log_probs_data[lp_batch_offset +
lp_char_stride *
targets_data, tg_batch_offset,
tg_target_stride, 1, BLANK)];
la = neginf;
// force-emit constraint at t = 0: only the forced state keeps its value
force_emit = force_emits_data[fe_batch_offset];
if (force_emit > -1 && force_emit != s) {
la = neginf;
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset +
/* la_input_stride * 0 */ +la_target_stride * s] = la;
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
// the s-2 transition is allowed only when the symbol two positions back
// differs from the current one
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
for (int64_t t = 1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch
if ((t < input_length) && (s < 2 * target_length + 1)) {
// force-emit constraint at time t: states other than the forced one are
// set to neginf
force_emit = force_emits_data[fe_batch_offset + t];
if (force_emit > -1 && force_emit != s) {
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
// only for valid t, s. This is equation (6) and (7), la1, la2, la3 are
// the three summands, lamax is the maximum for the logsumexp trick.
scalar_t la1 =
log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * s];
scalar_t lamax = la1;
scalar_t la2, la3;
if (s > 0) {
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 1)];
if (la2 > lamax)
lamax = la2;
} else {
la2 = neginf;
if (have_three) {
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 2)];
if (la3 > lamax)
lamax = la3;
} else {
la3 = neginf;
if (lamax == neginf) // when all are neginf. (then the whole thing is
// neginf, but we can pretend)
lamax = 0;
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] =
std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) +
std::exp(la3 - lamax)) +
lamax +
log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_char];
} else {
// otherwise we just set to neginf
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
__syncthreads(); // on cuda 9 we might use partial synchronization of only the
// threads within the same batch
// compute the loss (eq (8))
// one thread per batch item combines the two admissible end states
// (final blank and final label) at the last valid time step
// NOTE(review): indexes row input_length - 1, which is negative when
// input_length is 0 -- confirm callers never pass empty inputs.
if (threadIdx.x == 0) {
scalar_t l1 =
log_alpha_data[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2)];
scalar_t l2 =
target_length > 0
? log_alpha_data[la_batch_offset +
la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_data[b] = -log_likelihood;
// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each
// row) and a 1-dimensional tensor where all targets are concatenated (and we
// use target_lengths to figure out where they begin). We return log_alpha
// (currently, might change to (log_alpha+log_beta) to be passed to the
// backward. The dispatch function will only return the loss.
// Host-side forward pass: validates the arguments, derives per-batch target
// offsets, allocates the outputs and launches the alpha kernel.
// Returns (neg_log_likelihood, log_alpha), where neg_log_likelihood has
// batch_size elements and log_alpha has shape
// {batch_size, log_probs.size(0), 2 * max_target_length + 1}.
// NOTE(review): several structural lines (closing braces, the tail of some
// call expressions) are not visible in this excerpt.
// NOTE(review): force_emits is not validated here. It is accessed as target_t
// and only its stride(0) is forwarded, while the kernel indexes it as
// fe_batch_offset + t -- this presumably assumes a (batch, time) layout with
// unit time stride on the same device as log_probs; confirm.
template <typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor>
imputer_loss_gpu_template(const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK) {
// log_probs: input_len x batch_size x num_labels
// targets [int64]: batch_size x target_length OR sum(target_lengths)
CheckedFrom c = "imputer_loss_gpu";
// C++ element type matching target_scalar_type: int for kInt, int64_t otherwise
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
auto targets_arg = TensorArg(targets, "targets", 2);
checkAllSameGPU(c, {log_probs_arg, targets_arg});
checkScalarType(c, targets_arg, target_scalar_type);
checkDim(c, log_probs_arg, 3);
checkDimRange(c, targets_arg, 1, 3);
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels),
"blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size,
"input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size,
"target_lengths must be of size batch_size");
// NOTE(review): these two locals appear unused in the visible code; the kernel
// launch below passes log_probs.stride(...) directly.
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length = 0;
// per-batch start offsets into targets, filled on the CPU and moved to the
// GPU further down
auto tg_batch_offsets =
at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
// offsets are the running sum of target_lengths
int64_t pos = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
tg_target_stride = targets.stride(0);
checkSize(c, targets_arg, 0, pos);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length,
" at dimension 1, but got size ", targets.size(1), " for ",
targets_arg, " (while checking arguments for ", c, ")");
int64_t max_input_length = log_probs.size(0);
// only an upper bound is enforced on input_lengths here
for (int64_t b = 0; b < batch_size; b++) {
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected input_lengths to have value at most ",
max_input_length, ", but got value ", input_lengths[b],
" (while checking arguments for ", c, ")");
// length arrays are materialised as int64 tensors using targets' options so
// the kernel can read them
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_alpha =
at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1},
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
// Very likely, we could be more clever here, e.g. learning (or generalizing
// and reusing) from
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 1024
: 896; // we need 72 or so 32 bit registers for double
// halve the x extent while half of it still covers all 2 * max_target_length + 1
// augmented positions; the remaining thread budget goes to the batch dimension
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
imputer_loss_log_alpha_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_alpha.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, force_emits.data_ptr<target_t>(),
neg_log_likelihood.data_ptr<scalar_t>(), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), force_emits.stride(0),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
return std::make_tuple(neg_log_likelihood, log_alpha);
// The second (backward) half of the forward backward algorithm, (10) and (11).
// This is parallel to the alpha kernel above. (As mentioned above, it might
// make sense to do the calculation in the alpha kernel.)
// Backward (beta) kernel with force-emit constraints.
// NOTE(review): the kernel's name line and several structural lines (closing
// braces, return/continue statements) are not visible in this excerpt; the
// matching launch site appears to call it
// imputer_loss_backward_log_beta_gpu_kernel.
//
// Thread mapping: the y dimension selects the batch item b, the x dimension
// (plus block_s) selects the augmented-target position s.
// Parameters, as used in the body:
//   log_beta_data     output, written at
//                     lb_batch_offset + lb_input_stride * t + lb_target_stride * s
//   log_probs_data    read at lp_batch_offset + lp_input_stride * t + lp_char_stride * label
//   force_emits_data  read at fe_batch_offset + t; a value > -1 keeps only
//                     state s == value finite at that time step
//   tg_batch_offsets  per-batch start offset into targets_data
template <typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ log_beta_data, const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const target_t *__restrict__ force_emits_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t fe_batch_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
// NOTE(review): input_lengths[b], target_lengths[b] and tg_batch_offsets[b]
// are read before the b >= batch_size guard below, so threads past the end of
// the batch read outside those arrays. Consider moving the guard up.
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
int64_t fe_batch_offset = b * fe_batch_stride;
if (b >= batch_size)
target_t force_emit;
// "first" row, the beta initialization before eq (10) (t=target_length -
// differs per batch)
// block_s starts at the largest multiple of blockDim.x not exceeding
// 2 * max_target_length and steps downwards
for (int64_t block_s =
2 * max_target_length - (2 * max_target_length % blockDim.x);
block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t lb;
if (s == 2 * target_length) {
// final blank of the augmented target
lb = log_probs_data[lp_batch_offset +
(input_length - 1) * lp_input_stride +
lp_char_stride * BLANK];
} else if (s == 2 * target_length - 1) { // false for target_length == 0
// final label of the augmented target
int64_t current_target_prime = get_target_prime(
targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
lb = log_probs_data[lp_batch_offset +
(input_length - 1) * lp_input_stride +
lp_char_stride * current_target_prime];
} else {
lb = neginf;
// force-emit constraint at the last valid time step
force_emit = force_emits_data[fe_batch_offset + (input_length - 1)];
if (force_emit > -1 && force_emit != s) {
lb = neginf;
if (s < 2 * max_target_length + 1) {
log_beta_data[lb_batch_offset + (input_length - 1) * lb_input_stride +
lb_target_stride * s] = lb;
// go backward in s
for (int64_t block_s =
2 * max_target_length - (2 * max_target_length % blockDim.x);
block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
// the s+2 transition is allowed only when the symbol two positions ahead
// differs from the current one
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(targets_data, tg_batch_offset, tg_target_stride,
s + 2, BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
// now go backward in t. Note that we need to skip the last timestep that we
// did above.
for (int64_t t = max_input_length - 2; t >= 0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch item
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
// force-emit constraint at time t: states other than the forced one are
// set to neginf
force_emit = force_emits_data[fe_batch_offset + t];
if ((force_emit > -1) && (force_emit != s)) {
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = neginf;
// lb1, lb2, lb3 are the three summands taken from time t + 1; lbmax is the
// maximum used for the logsumexp trick
scalar_t lb1 =
log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
if (s < 2 * target_length) {
lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * (s + 1)];
if (lb2 > lbmax)
lbmax = lb2;
} else {
lb2 = neginf;
if (have_three) {
lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * (s + 2)];
if (lb3 > lbmax)
lbmax = lb3;
} else {
lb3 = neginf;
if (lbmax == neginf)
lbmax = 0;
scalar_t lb = std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) +
std::exp(lb3 - lbmax)) +
lbmax +
log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_target_prime];
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = lb;
} else if ((s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) ||
(s >= 2 * target_length + 1) || (t >= input_length))) {
// positions outside this sample's valid (t, s) range are filled with neginf
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = neginf;
// This implements the subtrahend of equation (16) for all *nonblank*
// characters. It assumes you have probs in gradient_data when called and it
// modifies gradient_data in place to become the gradient. To facilitate this
// inplace update, We don't actually do this in logspace. (The other variant
// implemented uses log_space and the differences seem to be
// not so problematic at least with unit normal distributed test activations.)
// Internally this uses atomicAdd because different threads may write to the
// same gradient position. This is parallelised over b and s again. Note that
// for us, the Z of eqn (16) is actually constant for all t and it is the
// likelihood - this is why we use the negative log likelihood below.
// We also multiply by the input gradient to keep with standard autograd style.
// I took this trick from [2], for moderate alphabet sizes a log-space
// calculation (with an atomic log add) is similarly in performance, but for
// large alphabets the inplace nature is a considerable advantage.
// Subtracts the non-blank alpha*beta contribution from gradient_data.
// NOTE(review): the kernel's name line, the closing #endif and several
// structural lines (return statements, closing braces, the tail of the
// atomicAdd call) are not visible in this excerpt; the matching launch site
// appears to call it imputer_loss_backward_collect_nonblank_gpu_kernel.
//
// Thread mapping: y selects the batch item b, x selects the position s in the
// original (non-augmented) target; the corresponding augmented position is
// 2 * s + 1. Each thread loops over all t < input_length and uses atomicAdd on
// the gradient slot of its label.
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ gradient_data,
const scalar_t *__restrict__ grad_out_data,
int64_t grad_out_batch_stride,
const scalar_t *__restrict__ log_alpha_data,
const scalar_t *__restrict__ log_beta_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const scalar_t *__restrict__ neg_log_likelihood_data,
int64_t gr_input_stride, int64_t gr_batch_stride,
int64_t gr_char_stride, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t num_labels, int64_t BLANK,
bool zero_infinity) {
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t s =
threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into
// targets, not targets prime!
// the batch guard comes before any per-batch array is read
if (b >= batch_size)
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t gr_batch_offset = b * gr_batch_stride;
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
if (s >= target_length)
int64_t target = targets_data[tg_batch_offset + s * tg_target_stride];
scalar_t nll = neg_log_likelihood_data[b];
scalar_t gr = grad_out_data[b * grad_out_batch_stride];
// with zero_infinity set, samples whose loss is infinite are special-cased
if (zero_infinity && nll == INFINITY)
for (int64_t t = 0; t < input_length; t++) {
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * target];
// adds -exp(log_alpha + log_beta + nll - lp), scaled by a factor on the
// following (not visible) line, to the gradient slot of this label
atomicAdd(&gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * target],
-std::exp(log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * (s * 2 + 1)] +
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * (s * 2 + 1)] +
nll - lp) *
// This is the naive implementation of equation (16). It is parallelised in
// batch and input timestep. It appears to be faster than the above method for
// small batch sizes.
// Naive gradient kernel: one thread per (batch item b, time step t).
// NOTE(review): the kernel's name line, the closing #endif and several
// structural lines (return statement, closing braces, tail of the logsumexp
// expression) are not visible in this excerpt; the matching launch site appears
// to call it imputer_loss_backward_collect_gpu_kernel.
//
// Phase 1 accumulates, per label, the log-sum-exp of log_alpha + log_beta over
// all augmented positions s into gradient_data. A slot equal to neginf is
// treated as "empty", so gradient_data is expected to arrive filled with -inf.
// Phase 2 overwrites every label slot with
// (exp(lp) - exp(res + nll - lp)) * gr, or with 0 for padded time steps and,
// when zero_infinity is set, for samples with infinite loss.
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ gradient_data,
const scalar_t *__restrict__ grad_out_data,
int64_t grad_out_batch_stride,
const scalar_t *__restrict__ log_alpha_data,
const scalar_t *__restrict__ log_beta_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const scalar_t *__restrict__ neg_log_likelihood_data,
int64_t gr_input_stride, int64_t gr_batch_stride,
int64_t gr_char_stride, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t num_labels, int64_t BLANK,
bool zero_infinity) {
constexpr scalar_t neginf = -INFINITY;
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
// the range guard comes before any per-batch array is read
if ((t >= max_input_length) || (b >= batch_size))
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t gr_batch_offset = b * gr_batch_stride;
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
for (int s = 0; s < 2 * max_target_length + 1; s++) {
if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
int64_t current_target_prime = get_target_prime(
targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
scalar_t log_alpha_beta =
(log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] +
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s]);
// lcab aliases the gradient slot of this label and is updated in place
scalar_t &lcab = gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * current_target_prime];
if (lcab == neginf) {
// first contribution for this label
lcab = log_alpha_beta;
} else {
// later contributions are combined with a max-shifted logsumexp
scalar_t max = ((lcab > log_alpha_beta) ? lcab : log_alpha_beta);
lcab = std::log(std::exp(lcab - max) + std::exp(log_alpha_beta - max)) +
scalar_t nll = neg_log_likelihood_data[b];
scalar_t gr = grad_out_data[b * grad_out_batch_stride];
for (int64_t c = 0; c < num_labels; c++) {
scalar_t &res = gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * c];
if (t < input_length && (!zero_infinity || nll != INFINITY)) {
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * c];
res = (std::exp(lp) - std::exp(res + nll - lp)) * gr;
} else {
res = 0.;
// This zeroes the gradients that correspond to out-of-sequence positions.
// Those gradients should not be used in any model update since the input
// elements are padded.
// Zeroes gradient entries at padded time steps: one thread per
// (batch item b, time step t); when t >= input_lengths[b], all num_labels
// entries at (t, b) are set to 0.
// NOTE(review): the kernel's name line, the closing #endif, the return
// statement and the closing braces are not visible in this excerpt; the launch
// site appears to call it imputer_loss_zero_padded_gradients.
template <typename scalar_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
scalar_t *__restrict__ gradient_data, /* (T, B, D) layout */
const int64_t *__restrict__ input_lengths, /* (B, ) layout */
int64_t gr_timestep_stride, int64_t gr_batch_stride,
int64_t gr_label_stride, int64_t max_input_length, /* T */
int64_t batch_size, /* B */
int64_t num_labels /* D */) {
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
// the range guard comes before input_lengths is read
if (b >= batch_size || t >= max_input_length) {
// NOTE(review): an int64 length is stored in the floating-point scalar_t, so
// the comparison below is done in floating point; int64_t would be the
// natural type here.
scalar_t input_length = input_lengths[b];
if (t >= input_length) {
for (int l = 0; l < num_labels; l++)
gradient_data[t * gr_timestep_stride + b * gr_batch_stride +
l * gr_label_stride] = 0.0f;
// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when
// backpropagating through a (well-checked) forward.
template <typename scalar_t, ScalarType target_scalar_type>
Tensor imputer_loss_backward_gpu_template(
const Tensor &grad_out, const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, const Tensor &neg_log_likelihood,
const Tensor &log_alpha, int64_t BLANK, bool zero_infinity) {
constexpr scalar_t neginf = -INFINITY;
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length;
auto tg_batch_offsets =
at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
max_target_length = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
tg_target_stride = targets.stride(0);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
tg_target_stride = targets.stride(1);
max_target_length =
log_alpha.size(2) / 2; // targets.size(1) might be larger
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor grad =
at::full_like(log_probs, neginf,
// log(sum (alpha beta))
// As above, there may be better configurations to use.
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 1024
: 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_backward_log_beta_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, force_emits.data_ptr<target_t>(),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
force_emits.stride(0), log_beta.stride(0), log_beta.stride(1),
log_beta.stride(2), tg_batch_offsets.data_ptr<int64_t>(),
tg_target_stride, batch_size, BLANK);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
// Very crude heuristic for what is a small problem., based on linearly
// regressing problem dimensions on the (capped) difference of timings. Note
// that for OK problems target length <= input length, so we only consider
// input length.
bool is_large = (2 * log_probs.size(0) + (24 * batch_size) / 10 +
(2 * num_labels) / 10) > 450;
if (is_large) { // large alphabet, large batch
// this computes the probs, minuend in (16)
exp_out(grad, log_probs);
// now we compute the subtrahend for the blanks. It is a straightforward
// reduction because we know that blanks are in every other position. maybe
// we should kernelize this, too.
auto grad_blank = grad.narrow(2, BLANK, 1);
grad_blank -=
{batch_size, log_alpha.size(1), max_target_length + 1},
{log_alpha.stride(0), log_alpha.stride(1),
log_alpha.stride(2) * 2}) +
{batch_size, log_beta.size(1), max_target_length + 1},
{log_beta.stride(0), log_beta.stride(1),
log_beta.stride(2) * 2}),
2, true)
.permute({1, 0, 2})
.add_(neg_log_likelihood.view({1, batch_size, 1}))
.sub_(log_probs.narrow(2, BLANK, 1))
// scale by output gradient (blanks and first summand of non-blanks)
grad *= grad_out.view({1, batch_size, 1});
if (zero_infinity) {
grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) ==
at::zeros({}, grad.options()), grad);
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid(std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch, 1);
imputer_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), grad_out.data_ptr<scalar_t>(),
grad_out.stride(0), log_alpha.data_ptr<scalar_t>(),
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0),
log_beta.stride(1), log_beta.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
num_labels, BLANK, zero_infinity);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
threads_input /= 2;
threads_batch = std::min(max_threads / threads_input, (int)batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input - 1) / threads_input,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_backward_collect_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), grad_out.data_ptr<scalar_t>(),
grad_out.stride(0), log_alpha.data_ptr<scalar_t>(),
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0),
log_beta.stride(1), log_beta.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
num_labels, BLANK, zero_infinity);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
// zero those invalid gradient elements due to padding
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0)) {
threads_input /= 2;
threads_batch = std::min(max_threads / threads_input, (int)batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input - 1) / threads_input,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), grad.size(0),
grad.size(1), grad.size(2));
return grad;
// Forward entry point of the imputer loss; returns a pair of tensors.
// Selects the imputer_loss_gpu_template instantiation from the dtype of
// `targets`: kLong when targets.scalar_type() == kLong, kInt otherwise.
// `zero_infinity` is deliberately ignored here (it only matters for backward).
// NOTE(review): `scalar_t` is not declared in the visible lines; presumably it
// comes from a floating-type dispatch keyed on log_probs.scalar_type() and
// labelled "imputer_loss_cuda" -- confirm.
// NOTE(review): this listing looks truncated (the call that opens the lambda,
// the trailing template-call arguments and the closing braces are not
// visible); check the complete file before changing the code.
std::tuple<Tensor, Tensor>
imputer_loss_op(const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
(void)zero_infinity; // only used for backward
log_probs.scalar_type(), "imputer_loss_cuda", [&] {
if (targets.scalar_type() == kLong) {
return imputer_loss_gpu_template<scalar_t, kLong>(
log_probs, targets, force_emits, input_lengths, target_lengths,
} else {
return imputer_loss_gpu_template<scalar_t, kInt>(
log_probs, targets, force_emits, input_lengths, target_lengths,
// Backward entry point of the imputer loss; returns a single gradient tensor.
// Forwards the incoming gradient, the forward inputs, and the forward results
// (`neg_log_likelihood`, `log_alpha`) to imputer_loss_backward_gpu_template,
// instantiated for kLong when targets.scalar_type() == kLong and for kInt
// otherwise.
// NOTE(review): `scalar_t` is not declared in the visible lines; presumably it
// comes from a floating-type dispatch keyed on log_probs.scalar_type() and
// labelled "imputer_loss_backward_cuda" -- confirm.
// NOTE(review): this listing looks truncated (the call that opens the lambda,
// the trailing template-call arguments and the closing braces are not
// visible); check the complete file before changing the code.
Tensor imputer_loss_backward_op(
const Tensor &grad, const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, const Tensor &neg_log_likelihood,
const Tensor &log_alpha, int64_t BLANK, bool zero_infinity) {
log_probs.scalar_type(), "imputer_loss_backward_cuda", [&] {
if (targets.scalar_type() == kLong) {
return imputer_loss_backward_gpu_template<scalar_t, kLong>(
grad, log_probs, targets, force_emits, input_lengths,
target_lengths, neg_log_likelihood, log_alpha, BLANK,
} else {
return imputer_loss_backward_gpu_template<scalar_t, kInt>(
grad, log_probs, targets, force_emits, input_lengths,
target_lengths, neg_log_likelihood, log_alpha, BLANK,
import os
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
# Directory that contains this module; extension sources are located relative to it.
module_path = os.path.dirname(__file__)
# Load the native extension through torch.utils.cpp_extension.load and bind the
# resulting module to `imputer`.
# NOTE(review): the call looks truncated here (extension name, keyword for the
# source list and closing brackets are not visible), and the second and third
# paths join an empty file name -- the intended source file names appear to
# have been lost; restore them from the complete file.
imputer = load(
os.path.join(module_path, "imputer.cpp"),
os.path.join(module_path, ""),
os.path.join(module_path, ""),
# Custom autograd Function around the native extension: forward calls
# imputer.imputer_loss and returns the loss; backward calls
# imputer.imputer_loss_backward.
# NOTE(review): this listing looks truncated (decorators, the forward
# parameter list, the native-call arguments and the save/restore of tensors on
# ctx are not visible); check the complete file before changing the code.
class ImputerLossFunction(Function):
def forward(
# NOTE(review): the two assignments below are garbled; presumably they convert
# the length tensors to CPU int64 -- confirm against the complete file.
input_lengths ="cpu", dtype=torch.int64)
target_lengths ="cpu", dtype=torch.int64)
# The native forward yields the loss together with log_alpha.
loss, log_alpha = imputer.imputer_loss(
# Non-tensor settings are kept as plain attributes on ctx for backward.
ctx.blank = blank
ctx.zero_infinity = zero_infinity
return loss
def backward(ctx, grad_output):
# Seven values are unpacked here; the right-hand side is not visible.
log_prob, targets, force_emits, input_lengths, target_lengths, loss, log_alpha = (
blank = ctx.blank
zero_infinity = ctx.zero_infinity
grad_input = imputer.imputer_loss_backward(
# Only the first slot carries a gradient; the remaining six are None.
return grad_input, None, None, None, None, None, None
# Functional alias used to invoke the autograd Function.
imputer_loss_fn = ImputerLossFunction.apply
# Functional imputer loss: obtains per-sequence losses from imputer_loss_fn,
# optionally zeroes infinite entries, then applies the requested reduction
# ("mean", "sum" or "none"); any other reduction value raises ValueError.
# NOTE(review): this listing looks truncated (the parameter list, the end of
# the docstring, the arguments of the imputer_loss_fn call, the right-hand side
# of `target_length` and several closing brackets are not visible); check the
# complete file before changing the code.
def imputer_loss(
"""The Imputer loss
log_prob (T, N, C): C = number of characters in alphabet including blank
T = input length
N = batch size
log probability of the outputs (e.g. torch.log_softmax of logits)
targets (N, S): S = maximum number of characters in target sequences
force_emits (N, T): sequence of ctc states that should be occur given times
that is, if force_emits is state s at time t, only ctc paths
that pass state s at time t will be enabled, and will be zero out the rest
this will be same as using cross entropy loss at time t
value should be in range [-1, 2 * S + 1), valid ctc states
-1 will means that it could be any states at time t (normal ctc paths)
input_lengths (N): lengths of log_prob
target_lengths (N): lengths of targets
blank (int): index of blank tokens (default 0)
reduction (str): reduction methods applied to the output. 'none' | 'mean' | 'sum'
zero_infinity (bool): if true imputer loss will zero out infinities.
infinities mostly occur when it is impossible to generate
target sequences using input sequences
(e.g. input sequences are shorter than target sequences)
# Unreduced loss from the autograd Function.
loss = imputer_loss_fn(
# NOTE(review): garbled assignments; presumably conversions of the length
# tensors to CPU int64 -- confirm.
input_lengths ="cpu", dtype=torch.int64)
target_lengths ="cpu", dtype=torch.int64)
# With zero_infinity set, entries equal to +inf are replaced by zero.
if zero_infinity:
inf = float("inf")
loss = torch.where(loss == inf, loss.new_zeros(1), loss)
# "mean" divides the loss by target_length before averaging.
if reduction == "mean":
target_length =
return (loss / target_length).mean()
elif reduction == "sum":
return loss.sum()
elif reduction == "none":
return loss
raise ValueError(
f"Supported reduction modes are: mean, sum, none; got {reduction}"
# nn.Module wrapper around imputer_loss: the constructor stores blank,
# reduction and zero_infinity on the instance, and forward returns the result
# of an imputer_loss call.
# NOTE(review): this listing looks truncated (the arguments of the
# imputer_loss call, its closing bracket and the docstring section headers are
# not visible); check the complete file before changing the code.
class ImputerLoss(nn.Module):
def __init__(self, blank=0, reduction="mean", zero_infinity=False):
"""The Imputer loss
blank (int): index of blank tokens (default 0)
reduction (str): reduction methods applied to the output. 'none' | 'mean' | 'sum'
zero_infinity (bool): if true imputer loss will zero out infinities.
infinities mostly occur when it is impossible to generate
target sequences using input sequences
(e.g. input sequences are shorter than target sequences)
log_prob (T, N, C): C = number of characters in alphabet including blank
T = input length
N = batch size
log probability of the outputs (e.g. torch.log_softmax of logits)
targets (N, S): S = maximum number of characters in target sequences
force_emits (N, T): sequence of ctc states that should be occur given times
that is, if force_emits is state s at time t, only ctc paths
that pass state s at time t will be enabled, and will be zero out the rest
this will be same as using cross entropy loss at time t
value should be in range [-1, 2 * S + 1), valid ctc states
-1 will means that it could be any states at time t (normal ctc paths)
input_lengths (N): lengths of log_prob
target_lengths (N): lengths of targets"""
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
def forward(self, log_prob, targets, force_emits, input_lengths, target_lengths):
return imputer_loss(
# The triple-quoted string below is a bare string literal, not live code: it
# holds another ImputerLoss implementation (masked logits, F.ctc_loss plus a
# masked F.cross_entropy term) that is disabled by being kept inside a string.
"""class ImputerLoss(nn.Module):
def __init__(self, blank=0, reduction="mean", zero_infinity=False, mask_eps=1e-8):
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
self.mask_eps = math.log(mask_eps)
def forward(
self, logit, targets_ctc, targets_ce, mask, input_lengths, targets_ctc_lengths
n_target = logit.shape[-1]
mask_e = mask.unsqueeze(-1)
mask_exp = mask_e.repeat(1, 1, n_target)
log_p_mask = logit.masked_fill(mask_exp == 1, self.mask_eps)
mask_exp[:, :, self.blank] = 0
log_p_mask = log_p_mask.masked_fill((mask_e == 1) & (mask_exp == 0), 0)
log_p_mask = torch.log_softmax(log_p_mask, 2)
ctc_loss = F.ctc_loss(
log_p_mask.transpose(0, 1),
ce_loss = F.cross_entropy(
logit.view(-1, n_target), targets_ce.view(-1), reduction="none"
ce_loss = mask.view(-1) * ce_loss
if self.reduction == "mean":
ce_loss = ce_loss.mean()
elif self.reduction == "sum":
ce_loss = ce_loss.sum()
return ctc_loss + ce_loss"""
def get_alignment_path(log_alpha, path):
    """Trace an alignment backwards through a table of backpointers.

    Args:
        log_alpha: 2-D array indexed ``[state, time]``. Only its last column
            is read, to pick the state the path ends in.
        path: 2-D array indexed ``[state, time]``; ``path[s, t]`` is taken as
            the state occupied at time ``t - 1`` when state ``s`` is occupied
            at time ``t``.

    Returns:
        List of states ordered from the first to the last time step, with one
        entry per column of ``path``.
    """
    num_states = log_alpha.shape[0]
    if num_states == 1:
        current_state = 0
    else:
        # The final state is chosen among the last two rows of log_alpha;
        # argmax is relative to that 2-row slice, hence the offset.
        current_state = log_alpha[-2:, -1].argmax() + (num_states - 2)

    path_decode = [current_state]
    # Walk from the last time step down to step 1, following backpointers.
    for t in range(path.shape[1] - 1, 0, -1):
        prev_state = path[current_state, t]
        path_decode.append(prev_state)
        current_state = prev_state

    # States were collected last-to-first; return them in time order.
    return path_decode[::-1]
def ctc_decode(seq, blank=0):
    """Collapse a CTC label sequence: drop blanks and merge adjacent repeats.

    Args:
        seq: iterable of labels.
        blank: label treated as the blank symbol (default 0).

    Returns:
        List of the non-blank labels, where a label that directly repeats the
        previous element of ``seq`` is emitted only once. Repeats separated
        by a blank are kept as separate labels.
    """
    result = []

    prev = -1  # sentinel: nothing has been seen yet
    for s in seq:
        if s == blank:
            # A blank is never emitted, but it breaks a run of repeats.
            prev = s
            continue

        # Emit the first label unconditionally, later ones only on change.
        if prev == -1 or s != prev:
            result.append(s)

        prev = s

    return result
def best_alignment(
    log_prob, targets, input_lengths, target_lengths, blank=0, zero_infinity=False
):
    """Get best alignment (maximum probability sequence of ctc states)
       conditioned on log probabilities and target sequences

    Input:
        log_prob (T, N, C): C = number of characters in alphabet including blank
                            T = input length
                            N = batch size
                            log probability of the outputs (e.g. torch.log_softmax of logits)
        targets (N, S): S = maximum number of characters in target sequences
        input_lengths (N): lengths of log_prob
        target_lengths (N): lengths of targets
        blank (int): index of blank tokens (default 0)
        zero_infinity (bool): if true imputer loss will zero out infinities.
                              infinities mostly occur when it is impossible to generate
                              target sequences using input sequences
                              (e.g. input sequences are shorter than target sequences)

    Output:
        best_aligns (List[List[int]]): sequence of ctc states that have maximum probabilties
                                       given log probabilties, and compatible with target sequences"""
    # The first returned value is not used by this function.
    _, log_alpha, alignment = imputer.best_alignment(
        log_prob, targets, input_lengths, target_lengths, blank, zero_infinity
    )

    # Swap the last two axes so each batch item is indexed [state, time],
    # which is the layout get_alignment_path reads.
    log_alpha = log_alpha.transpose(1, 2).detach().cpu().numpy()
    alignment = alignment.transpose(1, 2).detach().cpu().numpy()

    best_aligns = []

    for log_a, align, input_len, target_len in zip(
        log_alpha, alignment, input_lengths, target_lengths
    ):
        # Trim each item to its own 2 * target_len + 1 states and input_len
        # time steps before backtracking.
        state_len = target_len * 2 + 1
        log_a = log_a[:state_len, :input_len]
        align = align[:state_len, :input_len]

        best_aligns.append(get_alignment_path(log_a, align))

    return best_aligns
......@@ -107,7 +107,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, cfg.common_eval.post_process,
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
