Commit e40eac14 by xuchen

optimize the information dump

parent d946bc3b
......@@ -12,6 +12,8 @@ from typing import Optional
import numpy as np
import logging
import editdistance
import os
import sys
import torch
import torch.nn.functional as F
......@@ -62,14 +64,28 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of interleaved CTC loss for target sentence"},
)
cal_all_ctc: bool = field(
default=False,
metadata={"help": "calculate all ctc results"},
)
ctc_self_distill_weight: float = field(
default=0.0,
metadata={"help": "weight of the self distillation CTC loss"},
)
target_ctc_self_distill_weight: float = field(
default=0.0,
metadata={"help": "weight of the self distillation CTC loss for target sentence"},
)
ctc_self_distill_prob: float = field(
default=0.1,
metadata={"help": "probability to use distillation loss"},
)
ctc_self_distill_temperature: float = field(
default=1,
metadata={"help": "temperature for ctc self distillation"},
)
wer_kenlm_model: Optional[str] = field(
default=None,
......@@ -100,7 +116,7 @@ class CtcCriterionConfig(FairseqDataclass):
@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0, save_dir=None):
super().__init__(task)
if cfg.wer_args is not None:
......@@ -136,29 +152,45 @@ class CtcCriterion(FairseqCriterion):
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
self.sentence_avg = cfg.sentence_avg
self.save_dir = save_dir
self.cal_all_ctc = cfg.cal_all_ctc
self.ctc_weight = ctc_weight
self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_self_distill_prob = cfg.ctc_self_distill_prob
self.target_ctc_self_distill_weight = float(cfg.target_ctc_self_distill_weight)
self.ctc_self_distill_prob = float(cfg.ctc_self_distill_prob)
self.ctc_self_distill_temperature = float(cfg.ctc_self_distill_temperature)
self.ctc_entropy = cfg.ctc_entropy
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \
self.ctc_entropy
if self.all_ctc_weight > 0:
# assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
self.ctc_names = []
self.use_ctc = (self.ctc_weight + self.interleaved_ctc_weight > 0)
self.use_target_ctc = (self.target_ctc_weight + self.target_interleaved_ctc_weight > 0)
self.use_source_distill = self.use_target_distill = False
if self.ctc_self_distill_prob > 0:
if self.ctc_self_distill_weight > 0:
self.use_source_distill = True
if self.target_ctc_self_distill_weight > 0:
self.use_target_distill = True
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
ntokens = sample["ntokens"]
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
......@@ -168,15 +200,119 @@ class CtcCriterion(FairseqCriterion):
loss, logging_output = self.compute_ctc_loss(model, sample, net_output, logging_output)
return loss, sample_size, logging_output
def get_loss(self, lprobs, targets_flat, input_lengths, transcript_lengths):
def get_ground_truth_alignment(self, model, sample):
ctc_alignment_oracle = dict()
try:
from fairseq.torch_imputer import best_alignment, imputer_loss
except ImportError:
logger.warning("Imputer is not available.")
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
with torch.no_grad():
encoder_out = model.encoder(src_tokens, src_lengths)
ctc_logit = None
if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
ctc_logit = encoder_out["ctc_logit"][0]
elif "interleaved_ctc_logits" in encoder_out and len(encoder_out["interleaved_ctc_logits"]) != 0:
ctc_logit = encoder_out["interleaved_ctc_logits"][-1]
ctc_alignment_oracle["source"] = None
if ctc_logit is not None:
if "transcript" in sample:
tokens = sample["transcript"]["tokens"]
else:
tokens = sample["target"]
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
target_lengths = pad_mask.sum(-1)
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
else:
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
best_aligns = best_alignment(ctc_logit.float(), tokens, input_lengths, target_lengths,
self.pad_idx, zero_infinity=True)
best_aligns_pad = torch.tensor([a + [0] * (ctc_logit.size(0) - len(a)) for a in best_aligns],
device=ctc_logit.device, dtype=tokens.dtype)
oracle_pos = torch.div(best_aligns_pad, 2, rounding_mode='floor').clip(max=tokens.shape[1] - 1)
oracle = tokens.gather(-1, oracle_pos)
source_oracle = oracle.masked_fill(best_aligns_pad % 2 == 0, self.blank_idx)
ctc_alignment_oracle["source"] = [source_oracle, best_aligns_pad]
ctc_alignment_oracle["target"] = None
target_ctc_logit = None
if "target_ctc_logit" in encoder_out and len(encoder_out["target_ctc_logit"]) != 0:
target_ctc_logit = encoder_out["ctc_logit"][0]
elif "target_interleaved_ctc_logits" in encoder_out and len(
encoder_out["target_interleaved_ctc_logits"]) != 0:
target_ctc_logit = encoder_out["target_interleaved_ctc_logits"][-1]
if target_ctc_logit is not None:
target_tokens = sample["target"]
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_lengths = target_pad_mask.sum(-1)
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
else:
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
best_aligns = best_alignment(target_ctc_logit.float(), target_tokens, input_lengths, target_lengths,
self.pad_idx, zero_infinity=True)
best_aligns_pad = torch.tensor([a + [0] * (target_ctc_logit.size(0) - len(a)) for a in best_aligns],
device=target_ctc_logit.device, dtype=target_tokens.dtype)
oracle_pos = (best_aligns_pad // 2).clip(max=target_tokens.shape[1] - 1)
oracle = target_tokens.gather(-1, oracle_pos)
target_oracle = oracle.masked_fill(best_aligns_pad % 2 == 0, self.blank_idx)
ctc_alignment_oracle["target"] = [target_oracle, best_aligns_pad]
return ctc_alignment_oracle
def get_ctc_loss(self, model, ctc_logit, targets, input_lengths, target_lengths, loss_coef):
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
loss = 0
with torch.backends.cudnn.flags(enabled=False):
ctc_loss = self.ctc_loss(
lprobs,
targets_flat,
input_lengths,
transcript_lengths,
for item_targets, item_target_lengths, item_coef in zip(targets, target_lengths, loss_coef):
loss += self.ctc_loss(
lprobs,
item_targets,
input_lengths,
item_target_lengths,
) * item_coef
return loss, lprobs
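# Toy usage sketch of the weighted CTC above (hypothetical sizes): under mixup the
# same log-probs are scored against both flattened transcripts and the two losses
# are blended by the mixup coefficient.
import torch
T, B, V = 10, 2, 6
lprobs = torch.randn(T, B, V).log_softmax(-1)          # T x B x V
ctc = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
input_lengths = torch.full((B,), T, dtype=torch.long)
targets = [torch.tensor([1, 2, 3, 2, 4]), torch.tensor([3, 1, 4, 5])]   # flattened
target_lengths = [torch.tensor([3, 2]), torch.tensor([2, 2])]
loss_coef = [0.7, 0.3]                                  # mixup_coef, 1 - mixup_coef
loss = sum(ctc(lprobs, t, input_lengths, l) * c for t, l, c in zip(targets, target_lengths, loss_coef))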
@staticmethod
def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask):
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
for i in range(distill_num):
logit = student_logits[i]
if type(logit) == list:
student_logit = logit[0]
non_padding_mask = ~logit[1]
else:
student_logit = logit
if student_logit.size() != teacher_logit.size():
continue
loss = F.kl_div(
F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
)
return ctc_loss
ctc_self_distill_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
ctc_self_distill_num += 1
return ctc_self_distill_num, ctc_self_distill_loss
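# Toy sketch of the masked KL term computed by get_ctc_self_distill_loss above:
# push the student's log-probs toward the detached teacher's and zero out padded
# frames before summing (shapes are assumptions for illustration).
import torch
import torch.nn.functional as F
T, B, V = 8, 2, 5
teacher_logit, student_logit = torch.randn(T, B, V), torch.randn(T, B, V)
non_padding_mask = torch.ones(B, T, dtype=torch.bool)   # True = real frame
kl = F.kl_div(
F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
)
distill_loss = kl.sum(-1).transpose(0, 1).masked_fill(~non_padding_mask, 0.0).sum()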
def compute_ctc_loss(self, model, sample, net_output, logging_output):
if "transcript" in sample:
......@@ -187,62 +323,36 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["ctc_padding_mask"][0]
else:
non_padding_mask = ~net_output["encoder_padding_mask"][0]
# non_padding_mask = ~net_output["encoder_padding_mask"][0]
mixup = False
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
if "mixup" in net_output and net_output["mixup"] is not None:
mixup = True
mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (tokens != self.pad_idx) & (
tokens != self.eos_idx
)
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
transcript_flat1 = tokens[[mixup_idx1]].masked_select(mask1)
transcript_flat2 = tokens[mixup_idx2].masked_select(mask2)
transcripts1 = tokens[[mixup_idx1]].masked_select(mask1)
transcripts2 = tokens[mixup_idx2].masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
transcript_flat = [transcript_flat1, transcript_flat2]
transcripts = [transcripts1, transcripts2]
transcript_lengths = [transcript_lengths1, transcript_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
else:
transcript_flat = [tokens.masked_select(pad_mask)]
mixup = False
transcripts = [tokens.masked_select(pad_mask)]
transcript_lengths = [pad_mask.sum(-1)]
loss_coef = [1]
ctc_loss = 0
ctc_entropy = 0
all_ctc_logits = dict()
self.ctc_names = []
lprobs = None
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
ctc_logit = net_output["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
ctc_loss += self.get_loss(lprobs, flat, input_lengths, lengths) * coef
if self.ctc_entropy > 0:
if self.ctc_entropy_cutoff != 0:
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
else:
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
target_lprobs = None
interleaved_ctc_num = 0
interleaved_ctc_loss = 0
......@@ -251,128 +361,170 @@ class CtcCriterion(FairseqCriterion):
# calculate the interleaved CTC loss
if self.interleaved_ctc_weight > 0 and interleaved_ctc_num > 0:
logits = net_output["interleaved_ctc_logits"]
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
inter_input_lengths = padding.long().sum(-1)
logit = logits[i]
if type(logit) == list:
inter_ctc_logit = logit[0]
inter_input_lengths = (~logit[1]).long().sum(-1)
else:
inter_ctc_logit = out
inter_ctc_logit = logit
inter_input_lengths = input_lengths
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
for flat, lengths, coef in zip(transcript_flat, transcript_lengths, loss_coef):
interleaved_ctc_loss += self.get_loss(inter_lprobs, flat, inter_input_lengths, lengths) * coef
all_ctc_logits["interleaved_ctc_logit%d" % i] = [inter_ctc_logit, inter_input_lengths]
inter_loss, inter_lprobs = self.get_ctc_loss(
model, inter_ctc_logit, transcripts, inter_input_lengths, transcript_lengths, loss_coef)
interleaved_ctc_loss += inter_loss
lprobs = inter_lprobs
interleaved_ctc_loss /= interleaved_ctc_num
logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
ctc_loss = 0
ctc_entropy = 0
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
ctc_logit = net_output["ctc_logit"][0]
all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
target_ctc_loss = 0
target_interleaved_ctc_loss = 0
ctc_loss, lprobs = self.get_ctc_loss(
model, ctc_logit, transcripts, input_lengths, transcript_lengths, loss_coef)
if self.ctc_entropy > 0:
if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
else:
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
# calculate the target CTC loss
if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
target_ctc_loss = 0
target_interleaved_ctc_loss = 0
target_interleaved_ctc_num = 0
if self.use_target_ctc:
target_tokens = sample["target"]
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_no_padding_mask = ~target_pad_mask
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
target_flat1 = target.masked_select(mask1)
target_flat2 = target.masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
target_flat = [target_flat1, target_flat2]
target_length = [transcript_lengths1, transcript_lengths2]
mask1 = target_pad_mask[mixup_idx1]
mask2 = target_pad_mask[mixup_idx2]
target_tokens1 = target_tokens.masked_select(mask1)
target_tokens2 = target_tokens.masked_select(mask2)
target_lengths1 = mask1.sum(-1)
target_lengths2 = mask2.sum(-1)
target_tokens = [target_tokens1, target_tokens2]
target_lengths = [target_lengths1, target_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
else:
target_flat = [target.masked_select(pad_mask)]
target_length = [pad_mask.sum(-1)]
target_tokens = [target_tokens.masked_select(target_pad_mask)]
target_lengths = [target_pad_mask.sum(-1)]
loss_coef = [1]
if self.target_ctc_weight > 0:
assert "target_ctc_logit" in net_output
target_ctc_logit = net_output["target_ctc_logit"]
tgt_lprobs = model.get_normalized_probs(
[target_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
tgt_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
target_interleaved_ctc_num = 0
if "target_interleaved_ctc_logits" in net_output:
target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
if target_interleaved_ctc_num != 0 and self.target_interleaved_ctc_weight > 0:
logits = net_output["target_interleaved_ctc_logits"]
for i in range(target_interleaved_ctc_num):
out = net_output["target_interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
tgt_input_lengths = padding.long().sum(-1)
logit = logits[i]
if type(logit) == list:
target_inter_ctc_logit = logit[0]
padding = ~logit[1]
inter_input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
tgt_input_lengths = input_lengths
tgt_inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
tgt_inter_lprobs.batch_first = False
target_inter_ctc_logit = logit
inter_input_lengths = input_lengths
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths,
lengths) * coef
all_ctc_logits["target_interleaved_ctc_logit%d" % i] = [target_inter_ctc_logit, inter_input_lengths]
inter_loss, target_inter_lprobs = self.get_ctc_loss(
model, target_inter_ctc_logit, target_tokens, inter_input_lengths, target_lengths, loss_coef)
target_interleaved_ctc_loss += inter_loss
target_lprobs = target_inter_lprobs
target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_interleaved_ctc_loss"] = utils.item(target_interleaved_ctc_loss.data)
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
torch.rand(1).uniform_() < self.ctc_self_distill_prob:
for i in range(interleaved_ctc_num):
out = net_output["interleaved_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
non_padding_mask = ~out[1]
else:
inter_ctc_logit = out
if self.target_ctc_weight > 0:
assert "target_ctc_logit" in net_output
target_ctc_logit = net_output["target_ctc_logit"][0]
all_ctc_logits["target_ctc_logit"] = [target_ctc_logit, input_lengths]
if inter_ctc_logit.size() != ctc_logit.size():
continue
target_ctc_loss, target_lprobs = self.get_ctc_loss(
model, target_ctc_logit, target_tokens, input_lengths, target_lengths, loss_coef)
logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
log_target=True,
reduction="none",
)
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_num += 1
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
if self.use_source_distill or self.use_target_distill:
ctc_self_distill_choice = torch.rand(1).uniform_()
if ctc_self_distill_num != 0:
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
cal_source_distill = cal_target_distill = False
if not self.training:
cal_source_distill = self.use_source_distill
cal_target_distill = self.use_target_distill
else:
if ctc_self_distill_choice <= self.ctc_self_distill_prob:
if self.use_source_distill and self.use_target_distill:
cal_source_distill = ctc_self_distill_choice > self.ctc_self_distill_prob / 2
cal_target_distill = not cal_source_distill
else:
cal_source_distill = self.use_source_distill
cal_target_distill = self.use_target_distill
# source self distillation
if cal_source_distill:
ctc_self_distill_num = 0
non_padding = non_padding_mask
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
teacher_logit = ctc_logit
student_logits = net_output["interleaved_ctc_logits"]
ctc_self_distill_num = interleaved_ctc_num
elif self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 1:
teacher_logit = net_output["interleaved_ctc_logits"][-1]
student_logits = net_output["interleaved_ctc_logits"][:-1]
ctc_self_distill_num = interleaved_ctc_num - 1
if ctc_self_distill_num != 0:
ctc_self_distill_num, source_ctc_self_distill_loss = \
self.get_ctc_self_distill_loss(
ctc_self_distill_num, teacher_logit, student_logits, non_padding)
source_ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(source_ctc_self_distill_loss.data)
ctc_self_distill_loss += source_ctc_self_distill_loss * self.ctc_self_distill_weight
# target self distillation
if cal_target_distill:
ctc_self_distill_num = 0
non_padding = non_padding_mask
if self.target_ctc_weight > 0 and self.target_ctc_self_distill_weight > 0 and target_interleaved_ctc_num > 0:
teacher_logit = target_ctc_logit
student_logits = net_output["target_interleaved_ctc_logits"]
ctc_self_distill_num = target_interleaved_ctc_num
elif self.target_ctc_self_distill_weight > 0 and target_interleaved_ctc_num > 1:
teacher_logit = net_output["target_interleaved_ctc_logits"][-1]
student_logits = net_output["target_interleaved_ctc_logits"][:-1]
ctc_self_distill_num = target_interleaved_ctc_num - 1
if ctc_self_distill_num != 0:
ctc_self_distill_num, target_ctc_self_distill_loss = \
self.get_ctc_self_distill_loss(
ctc_self_distill_num, teacher_logit, student_logits, non_padding)
target_ctc_self_distill_loss /= ctc_self_distill_num
logging_output["target_ctc_self_distill_loss"] = utils.item(target_ctc_self_distill_loss.data)
ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
loss = \
self.ctc_weight * ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
logging_output["all_ctc_loss"] = utils.item(loss.data)
......@@ -386,74 +538,58 @@ class CtcCriterion(FairseqCriterion):
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
target = tokens
if mixup:
idx = mixup_idx1
if mixup_coef < 0.5:
idx = mixup_idx2
target = target[idx]
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
target,
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
if not model.training:
if hasattr(model.encoder, "ctc_valid"):
if lprobs is not None:
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
tokens = tokens[idx]
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
lprobs_t, tokens, input_lengths, self.task.source_dictionary, lang="source")
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
if target_lprobs is not None:
target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu()
target_tokens = sample["target"]
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
target_tokens = target_tokens[idx]
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
target_lprobs_t, target_tokens, input_lengths, self.task.target_dictionary, lang="target")
logging_output["target_wv_errors"] = wv_errs
logging_output["target_w_errors"] = w_errs
logging_output["target_w_total"] = w_len
logging_output["target_c_errors"] = c_err
logging_output["target_c_total"] = c_len
if self.cal_all_ctc:
logging_output["save_dir"] = self.save_dir
for name, items in all_ctc_logits.items():
logit, lengths = items
if "target" in name:
dictionary = self.task.target_dictionary
ctc_tokens = target_tokens
lang = "target"
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dictionary = self.task.source_dictionary
ctc_tokens = tokens
lang = "source"
c_err, c_len, w_errs, w_len, wv_errs = model.encoder.ctc_valid(
logit, ctc_tokens, lengths, dictionary, lang)
cer = c_err * 100 / c_len
wer = w_errs * 100 / w_len
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
logging_output["dump_%s_cer" % name] = cer
logging_output["dump_%s_wer" % name] = wer
return loss, logging_output
......@@ -479,6 +615,9 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
target_ctc_self_distill_loss_sum = utils.item(
sum(log.get("target_ctc_self_distill_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
......@@ -552,6 +691,13 @@ class CtcCriterion(FairseqCriterion):
sample_size,
round=3,
)
if target_ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"target_ctc_self_distill_loss_sum",
target_ctc_self_distill_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
......@@ -589,14 +735,48 @@ class CtcCriterion(FairseqCriterion):
if meters["_w_total"].sum > 0
else float("nan"),
)
# metrics.log_derived(
# "raw_wer",
# lambda meters: safe_round(
# meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
# )
# if meters["_w_total"].sum > 0
# else float("nan"),
# )
target_c_errors = sum(log.get("target_c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_target_c_errors", target_c_errors)
target_c_total = sum(log.get("target_c_total", 0) for log in logging_outputs)
metrics.log_scalar("_target_c_total", target_c_total)
target_w_errors = sum(log.get("target_w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_target_w_errors", target_w_errors)
target_w_total = sum(log.get("target_w_total", 0) for log in logging_outputs)
metrics.log_scalar("_target_w_total", target_w_total)
if target_c_total > 0:
metrics.log_derived(
"target_cer",
lambda meters: safe_round(
meters["_target_c_errors"].sum * 100.0 / meters["_target_c_total"].sum, 3
)
if meters["_target_c_total"].sum > 0
else float("nan"),
)
if target_w_total > 0:
metrics.log_derived(
"target_wer",
lambda meters: safe_round(
meters["_target_w_errors"].sum * 100.0 / meters["_target_w_total"].sum, 3
)
if meters["_target_w_total"].sum > 0
else float("nan"),
)
# save_dir = logging_outputs.get("save_dir", None)
# if save_dir is not None and os.path.exists(save_dir):
# out = open(os.path.join(save_dir, "ctc_results"), "a")
# else:
# out = sys.stdout
#
# for key in logging_outputs:
# if key.startswith("dump"):
# print("%s: %.2f" % (key, logging_outputs[key]), end="\t", file=out)
# print("", file=out)
# out.close()
#
# out = sys.stdout
@staticmethod
def logging_outputs_can_be_summed() -> bool:
......
......@@ -23,12 +23,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
def __init__(self, task, label_smoothing,
sentence_avg,
cfg: CtcCriterionConfig,
ctc_weight=0.0):
ctc_weight=0.0,
save_dir=None):
super().__init__(task, sentence_avg, label_smoothing)
self.report_accuracy = True
self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
self.save_dir = save_dir
@staticmethod
def add_args(parser):
......@@ -62,7 +64,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
encoder_out = model.encoder(src_tokens, src_lengths,
text_src_tokens, text_src_lengths)
else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
if self.training and getattr(model.encoder, "sae_ground_truth_ratio", 0) != 0:
ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
encoder_out = model.encoder(src_tokens, src_lengths,
ctc_alignment_oracle=ctc_alignment_oracle)
else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
......
......@@ -2,7 +2,6 @@ import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.models import (
......@@ -138,7 +137,6 @@ class CTCDecoder(object):
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(self.beam_size, self.vocab_size - 1)
# from fairseq.sequence_generator import EnsembleModel
from fairseq.sequence_generator import EnsembleModel
if isinstance(models, EnsembleModel):
self.model = models
......@@ -240,8 +238,12 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
......@@ -276,7 +278,9 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
......@@ -310,8 +314,6 @@ def base_architecture(args):
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
@register_model_architecture("s2t_ctc", "s2t_ctc_s")
def s2t_ctc_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
......
......@@ -125,7 +125,7 @@ class S2TSATEModel(S2TTransformerModel):
# target CTC
parser.add_argument(
"--target-ctc-layer",
default=None,
default=0,
type=int,
help="ctc layer for target sentence",
)
......@@ -233,15 +233,15 @@ class S2TSATEModel(S2TTransformerModel):
return cls(encoder, decoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, **kwargs
)
return decoder_out
......@@ -286,7 +286,9 @@ class TextualEncoder(FairseqEncoder):
self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = getattr(args, "target_ctc_layer", layer_num)
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.ctc_layer == 0:
self.ctc_layer = layer_num
self.inter_ctc = True if self.ctc_layer != layer_num else False
if self.inter_ctc:
logger.info("Target CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
......@@ -294,13 +296,16 @@ class TextualEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if embed_tokens is not None:
if embed_tokens is not None and args.share_target_ctc_and_embed and \
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if self.target_interleaved_ctc_layers is not None:
target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
for layer_idx in target_interleaved_ctc_layers:
......@@ -337,7 +342,7 @@ class TextualEncoder(FairseqEncoder):
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None):
def forward(self, x, encoder_padding_mask=None, history=None, **kwargs):
if self.encoder_embed_norm:
x = self.embed_ln(x)
......@@ -356,7 +361,7 @@ class TextualEncoder(FairseqEncoder):
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone())
target_ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Target Layer %d" % layer_idx)
if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
......@@ -365,11 +370,23 @@ class TextualEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["target"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["target"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
if history is not None:
history.push(x)
......@@ -381,7 +398,7 @@ class TextualEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x)
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")
return x, target_ctc_logit, target_interleaved_ctc_logits
......@@ -435,6 +452,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
......@@ -443,11 +461,24 @@ class S2TSATEEncoder(FairseqEncoder):
self.history = None
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
if hasattr(self, "ctc"):
if hasattr(self.acoustic_encoder, "ctc"):
assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
if hasattr(self.textual_encoder, "ctc"):
assert tgt_dict is not None
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
if lang == "source":
if hasattr(self.acoustic_encoder, "ctc"):
return self.acoustic_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
else:
logger.error("No ctc module in textual encoder")
else:
if hasattr(self.textual_encoder, "ctc"):
return self.textual_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
else:
logger.error("No ctc module in textual encoder")
def forward(self, src_tokens, src_lengths=None, **kwargs):
if self.history is not None:
......@@ -455,9 +486,9 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_acoustic_encoder:
with torch.no_grad():
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
else:
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
......@@ -490,16 +521,16 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
self.history, **kwargs)
else:
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
self.history, **kwargs)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []), # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C
"target_ctc_logit": [target_ctc_logit], # B x T x C
"target_interleaved_ctc_logits": target_interleaved_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
......
......@@ -447,6 +447,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--sae-ground-truth-ratio",
default=0,
type=float,
help="the ratio for ground truth in sae",
)
parser.add_argument(
"--share-sae-and-ctc",
action="store_true",
help="share the weight of ctc and sae",
......@@ -570,13 +576,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
lprobs.batch_first = True
return lprobs
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
......@@ -655,6 +661,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
......@@ -681,6 +688,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gt_ratio": self.sae_ground_truth_ratio,
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
......@@ -717,6 +725,14 @@ class S2TTransformerEncoder(FairseqEncoder):
assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
dictionary)
else:
logger.error("No ctc module in textual encoder")
def set_debug_var(self, debug_var_flag):
self.debug_var = debug_var_flag
......@@ -879,7 +895,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone(), encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Source Layer %d" % layer_idx)
# interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -889,15 +905,27 @@ class S2TTransformerEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask)
logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
self.show_debug(x, "x after sae")
# gather cosine similarity
......@@ -917,7 +945,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask)
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
self.show_debug(x, "x after ctc")
return {
......@@ -925,6 +953,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
# "oracle": [oracle, oracle_mask, force_emit],
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
......
......@@ -315,7 +315,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
parser.add_argument(
"--interleaved-ctc-upsampling-ratio",
default=2,
type=int,
type=float,
help="upsampling ratio of the representation for CTC calculation",
)
parser.add_argument(
......@@ -355,6 +355,24 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
action="store_true",
help="share the weight of ctc and sae",
)
parser.add_argument(
"--sae-embed-norm",
default=False,
action="store_true",
help="use the layer norm for embed output",
)
parser.add_argument(
"--sae-out-norm",
default=False,
action="store_true",
help="use the layer norm for final output",
)
parser.add_argument(
"--sae-ground-truth-ratio",
default=0,
type=float,
help="the ratio for ground truth in sae",
)
# fmt: on
@classmethod
......@@ -625,6 +643,11 @@ class TransformerCTCEncoder(FairseqEncoder):
logger.info("Interleaved CTC loss in layer %d" % layer_idx)
self.un_sample = torch.nn.Upsample(scale_factor=self.interleaved_ctc_upsampling_ratio, mode="linear",
align_corners=True)
self.down_sample = torch.nn.Upsample(scale_factor=1 / self.interleaved_ctc_upsampling_ratio, mode="linear",
align_corners=True)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
......@@ -633,10 +656,14 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
strategy = {
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
"gt_ratio": self.sae_ground_truth_ratio,
}
self.sae = Adapter(embed_dim, args.sae_adapter,
......@@ -645,9 +672,9 @@ class TransformerCTCEncoder(FairseqEncoder):
)
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
if hasattr(self, "ctc"):
self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
# if hasattr(self, "ctc"):
# self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
# stride=self.interleaved_ctc_upsampling_ratio)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -679,6 +706,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
**kwargs
):
"""
Args:
......@@ -706,7 +734,8 @@ class TransformerCTCEncoder(FairseqEncoder):
return self.forward_scriptable(src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings)
token_embeddings,
**kwargs)
def upsampling(self, x, padding):
ratio = self.interleaved_ctc_upsampling_ratio
......@@ -714,12 +743,17 @@ class TransformerCTCEncoder(FairseqEncoder):
return x
if len(x.size()) == 3:
bsz, seq_len, dim = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
# bsz, seq_len, dim = x.size()
# up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
seq_len, bsz, dim = x.size()
x = x.permute(1, 2, 0)
up_x = self.un_sample(x)
up_x = up_x.permute(2, 0, 1)
else:
bsz, seq_len = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)
# output_length = int(seq_len * ratio * 2/3)
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
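# Shape sketch (toy sizes) for the nn.Upsample path above: a T x B x C sequence is
# permuted to B x C x T for linear interpolation and permuted back afterwards.
import torch
ratio = 2.0
up_sample = torch.nn.Upsample(scale_factor=ratio, mode="linear", align_corners=True)
x = torch.randn(5, 3, 8)                 # T x B x C
up_x = up_sample(x.permute(1, 2, 0))     # B x C x (T * ratio)
up_x = up_x.permute(2, 0, 1)             # (T * ratio) x B x C
assert up_x.shape == (10, 3, 8)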
......@@ -742,6 +776,14 @@ class TransformerCTCEncoder(FairseqEncoder):
assert tgt_dict is not None
self.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
dictionary)
else:
logger.error("No ctc module in textual encoder")
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
......@@ -752,6 +794,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
**kwargs
):
"""
Args:
......@@ -783,10 +826,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.history is not None:
self.history.clean()
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
# ctc_padding_mask = encoder_padding_mask
# if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
# src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
# ctc_padding_mask = encoder_padding_mask
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
......@@ -796,6 +839,8 @@ class TransformerCTCEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
encoder_states = []
......@@ -824,6 +869,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), ctc_padding_mask)
# Interleaved CTC
......@@ -833,18 +879,31 @@ class TransformerCTCEncoder(FairseqEncoder):
if p < self.interleaved_ctc_drop_prob:
break
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, ctc_padding_mask)
interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, _ = self.sae([norm_x, prob])
# x = x.permute(1, 2, 0)
# x = self.pool(x)
# x = x.permute(2, 0, 1)
# encoder_padding_mask = org_encoder_padding_mask
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
x = x.permute(1, 2, 0)
# x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
x = self.down_sample(x)
x = x.permute(2, 0, 1)
if self.history is not None:
self.history.push(x)
......
......@@ -98,7 +98,7 @@ class Adapter(nn.Module):
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if "league" in self.adapter_type:
if self.cal_context:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
if self.distribution_cutoff is not None:
self.distribution_cutoff = int(self.distribution_cutoff)
......@@ -107,17 +107,21 @@ class Adapter(nn.Module):
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.out_norm = strategy.get("out_norm", False)
if self.out_norm:
self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None):
def forward(self, x, padding=None, oracle=None, oracle_mask=None):
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size()
org_distribution = distribution
distribution = distribution.contiguous().view(-1, distribution.size(-1))
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
linear_out = None
soft_out = None
......@@ -125,18 +129,32 @@ class Adapter(nn.Module):
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
# )
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
distribution = torch.where(
threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
)
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution = torch.where(
# threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
# )
# distribution = distribution.view(-1, vocab_size)
distribution[:, 0] = 0
distribution = distribution / distribution.sum(-1, keepdim=True)
if self.ground_truth_ratio > 0 and oracle is not None:
oracle = oracle.unsqueeze(-1)
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)).\
to(distribution.dtype).transpose(0, 1)
oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution
soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
else:
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm:
soft_out = self.embed_ln(soft_out)
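# Toy sketch of the ground-truth mixing above: frames picked by the oracle mask use
# a one-hot of the oracle token, the rest keep the predicted distribution
# (broadcasting is used here instead of repeat; shapes are assumptions).
import torch
T, B, V = 4, 2, 6
org_distribution = torch.softmax(torch.randn(T, B, V), dim=-1)
oracle = torch.randint(0, V, (B, T))                               # B x T oracle tokens
oracle_mask = torch.rand(B, T) < 0.5                               # gt_ratio = 0.5
oracle_one_hot = (oracle.unsqueeze(-1) == torch.arange(V).unsqueeze(0)).float().transpose(0, 1)
mask = oracle_mask.transpose(0, 1).unsqueeze(-1)                   # T x B x 1
modify_dist = mask * oracle_one_hot + ~mask * org_distribution     # T x B x V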
......
......@@ -37,13 +37,14 @@ class CTC(nn.Module):
self.dictionary = dictionary
self.infer_decoding = False
self.post_process = "sentencepiece"
self.blank_idx = 0
def set_infer(self, is_infer, text_post_process, dictionary):
self.infer_decoding = is_infer
self.post_process = text_post_process
self.dictionary = dictionary
def forward(self, x, padding=None):
def forward(self, x, padding=None, tag=None):
if self.need_layernorm:
x = self.LayerNorm(x)
......@@ -52,7 +53,7 @@ class CTC(nn.Module):
if not self.training and self.infer_decoding:
assert self.dictionary is not None
input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)
return x
......@@ -65,7 +66,7 @@ class CTC(nn.Module):
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
def infer(self, logits_or_probs, lengths):
def infer(self, logits_or_probs, lengths, tag=None):
for lp, inp_l in zip(
logits_or_probs,
lengths,
......@@ -78,41 +79,47 @@ class CTC(nn.Module):
pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))
if tag is not None:
logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
else:
logger.info("CTC prediction: %s" % (" ".join(pred_words_raw)))
def valid(self, logits_or_probs, target, lengths):
def valid(self, logits_or_probs, targets, input_lengths, dictionary):
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
logits_or_probs,
target,
lengths,
):
lp = lp[:inp_l].unsqueeze(0)
with torch.no_grad():
for lp, t, inp_l in zip(
logits_or_probs,
targets,
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
p = (t != dictionary.pad()) & (t != dictionary.eos())
targ = t[p]
targ_units = dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
pred_units = dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
w_len += len(targ_words)
w_len += len(targ_words)
\ No newline at end of file
return c_err, c_len, w_errs, w_len, wv_errs
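# Toy sketch of the greedy CTC scoring above: argmax per frame, collapse repeats,
# drop blanks, then count edit-distance errors against the reference units.
import editdistance
import torch
blank_idx = 0
lp = torch.tensor([[0.1, 0.7, 0.2],
[0.1, 0.7, 0.2],
[0.8, 0.1, 0.1],
[0.1, 0.1, 0.8]])                        # T x V frame scores
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != blank_idx].tolist()            # [1, 2]
targ_units_arr = [1, 2, 2]
c_err = editdistance.eval(pred_units_arr, targ_units_arr)    # = 1
c_len = len(targ_units_arr)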
from .imputer import imputer_loss, ImputerLoss, best_alignment, ctc_decode
// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from above link, but note that [1] has 1-based indexing
// and we (of course) use 0-based. Graves et al call the probabilities y, we use
// log_probs (also calling them inputs) A few optimizations (similar to those
// here, but also some I didn't take) are described in
// 2. Minmin Sun:
// http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
#include <ATen/TensorUtils.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <numeric>
#include <type_traits>
using namespace at;
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t
get_target_prime(const target_t *__restrict__ target, int64_t offset,
int64_t stride, int64_t idx, int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
return target[offset + stride * (idx / 2)];
}
}
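// Example (illustration only): for targets l = [a, b], the augmented sequence is
// l' = [BLANK, a, BLANK, b, BLANK]; any even idx returns BLANK, while idx = 3
// returns target[offset + stride * 1], i.e. b.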
// this kernel is a relatively straightforward implementation of the alpha
// calculation in the forward backward algorithm (section 4.1). A (minor) twist
// is that we are using log-calculations to enhance numerical stability
// (log_probs and log_alpha). In total it would be more efficient to compute the
// beta in the same kernel (e.g. cudnn does this). While the beta are not needed
// for the loss itself (just the grad), we can return log_alpha+log_beta (so
// same space as currently) and the overhead is small and the use-case for loss
// without grad is relatively limited. We parallelize by batch and target
// sequence. Empirically, it is faster to loop over the input (log probs)
// sequence and do target in parallel, even if it means more frequent
// __syncthreads. In contrast to the cuDNN implementation, we allow large target
// lengths. For this we need that all previous `s` have been computed when we
// start a new block_s. This is why we have our own for loop here.
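// A brief sketch of the recurrence the kernel below implements (a Viterbi-style
// variant of the CTC alpha recursion, eq (6)/(7), with the sum replaced by a max):
//   log_alpha[t][s] = log_probs[t][l'_s]
//                     + max(log_alpha[t-1][s],
//                           log_alpha[t-1][s-1]   (if s > 0),
//                           log_alpha[t-1][s-2]   (if s > 1 and l'_s != l'_{s-2}))
// paths[t][s] records which of the predecessors attained the maximum, so the
// best alignment can be traced back after the forward pass.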
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_alignment_log_alpha_gpu_kernel(
scalar_t *__restrict__ log_alpha_data, int64_t *__restrict__ paths_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
scalar_t *__restrict__ neg_log_likelihood_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, const int64_t *__restrict__ tg_batch_offsets,
int64_t tg_target_stride, int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
// bookkeeping
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
if (b >= batch_size)
return;
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
la = target_length == 0
? neginf
: log_probs_data[lp_batch_offset +
lp_char_stride *
get_target_prime(
targets_data, tg_batch_offset,
tg_target_stride, 1, BLANK)];
break;
default:
la = neginf;
}
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset +
/* la_input_stride * 0 */ +la_target_stride * s] = la;
}
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t = 1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch
if ((t < input_length) && (s < 2 * target_length + 1)) {
        // only for valid t, s. This follows equations (6) and (7) with the sum
        // replaced by a max: la1, la2, la3 are the three candidates, and
        // lamax / max_path keep track of the best one.
scalar_t la1 =
log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * s];
scalar_t lamax = la1;
int64_t max_path = s;
scalar_t la2, la3;
if (s > 0) {
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 1)];
if (la2 > lamax) {
lamax = la2;
max_path = s - 1;
}
} else {
la2 = neginf;
}
if (have_three) {
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 2)];
if (la3 > lamax) {
lamax = la3;
max_path = s - 2;
}
} else {
la3 = neginf;
}
/*if (lamax == neginf) // when all are neginf. (then the whole thing is
// neginf, but we can pretend)
lamax = 0;*/
int64_t log_alpha_i =
la_batch_offset + la_input_stride * t + la_target_stride * s;
int64_t log_prob_i = lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_char;
log_alpha_data[log_alpha_i] = lamax + log_probs_data[log_prob_i];
paths_data[log_alpha_i] = max_path;
} else {
// otherwise we just set to neginf
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
}
}
}
__syncthreads(); // on cuda 9 we might use partial synchronization of only the
// threads within the same batch
// compute the loss (eq (8))
if (threadIdx.x == 0) {
scalar_t l1 =
log_alpha_data[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2)];
scalar_t l2 =
target_length > 0
? log_alpha_data[la_batch_offset +
la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_data[b] = -log_likelihood;
}
}
// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each
// row) and a 1-dimensional tensor where all targets are concatenated (and we
// use target_lengths to figure out where they begin). We return log_alpha
// (currently; this might change to (log_alpha+log_beta) to be passed to the
// backward). The dispatch function will only return the loss.
template <typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor, Tensor>
best_alignment_gpu_template(const Tensor &log_probs, const Tensor &targets,
IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK) {
// log_probs: input_len x batch_size x num_labels
// targets [int64]: batch_size x target_length OR sum(target_lengths)
CheckedFrom c = "ctc_alignment_gpu";
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
auto targets_arg = TensorArg(targets, "targets", 2);
checkAllSameGPU(c, {log_probs_arg, targets_arg});
checkScalarType(c, targets_arg, target_scalar_type);
checkDim(c, log_probs_arg, 3);
checkDimRange(c, targets_arg, 1, 3);
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels),
"blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size,
"input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size,
"target_lengths must be of size batch_size");
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length = 0;
auto tg_batch_offsets =
at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(0);
checkSize(c, targets_arg, 0, pos);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length,
" at dimension 1, but got size ", targets.size(1), " for ",
targets_arg, " (while checking arguments for ", c, ")");
}
int64_t max_input_length = log_probs.size(0);
for (int64_t b = 0; b < batch_size; b++) {
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected input_lengths to have value at most ",
max_input_length, ", but got value ", input_lengths[b],
" (while checking arguments for ", c, ")");
}
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_alpha =
at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1},
log_probs.options());
Tensor paths = at::full_like(log_alpha, -1, log_alpha.options().dtype(kLong));
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
  // Very likely, we could be more clever here, e.g. learning (or generalizing
// and reusing) from SoftMax.cu...
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 512
: 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
ctc_alignment_log_alpha_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_alpha.data_ptr<scalar_t>(), paths.data_ptr<int64_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(),
log_probs.size(0), targets.data_ptr<target_t>(),
target_lengths_t.data_ptr<int64_t>(), max_target_length,
neg_log_likelihood.data_ptr<scalar_t>(), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
BLANK);
AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors
return std::make_tuple(neg_log_likelihood, log_alpha, paths);
}
std::tuple<Tensor, Tensor, Tensor>
best_alignment_op(const Tensor &log_probs, const Tensor &targets,
IntArrayRef input_lengths, IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity) {
(void)zero_infinity; // only used for backward
return AT_DISPATCH_FLOATING_TYPES(
log_probs.scalar_type(), "ctc_alignment_cuda", [&] {
if (targets.scalar_type() == kLong) {
return best_alignment_gpu_template<scalar_t, kLong>(
log_probs, targets, input_lengths, target_lengths, BLANK);
} else {
return best_alignment_gpu_template<scalar_t, kInt>(
log_probs, targets, input_lengths, target_lengths, BLANK);
}
});
}
#include <torch/extension.h>
#include <tuple>
std::tuple<torch::Tensor, torch::Tensor>
imputer_loss_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, at::IntArrayRef input_lengths,
at::IntArrayRef target_lengths, int64_t BLANK,
bool zero_infinity);
torch::Tensor imputer_loss_backward_op(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor> imputer_loss(
const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
imputer_loss_op(log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl,
BLANK, zero_infinity);
return res;
}
torch::Tensor imputer_loss_backward(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
const torch::Tensor &input_lengths, const torch::Tensor &target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
torch::Tensor res;
res = imputer_loss_backward_op(
grad, log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl, neg_log_likelihood,
log_alpha, BLANK, zero_infinity);
return res;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK,
bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
best_alignment_op(log_probs, targets.to(log_probs.device(), at::kLong),
il, tl, BLANK, zero_infinity);
return res;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("imputer_loss", &imputer_loss, "calculate imputer loss");
m.def("imputer_loss_backward", &imputer_loss_backward,
"calculate imputer loss gradient");
m.def("best_alignment", &best_alignment, "get best alignments for ctc");
}
\ No newline at end of file
// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from the above link, but note that [1] has 1-based indexing
// and we (of course) use 0-based. Graves et al call the probabilities y; we use
// log_probs (also calling them inputs). A few optimizations (similar to those
// here, but also some I didn't take) are described in
// 2. Minmin Sun:
// http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
#include <ATen/TensorUtils.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <numeric>
#include <type_traits>
using namespace at;
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t
get_target_prime(const target_t *__restrict__ target, int64_t offset,
int64_t stride, int64_t idx, int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
return target[offset + stride * (idx / 2)];
}
}
// this kernel is a relatively straightforward implementation of the alpha
// calculation in the forward backward algorithm (section 4.1). A (minor) twist
// is that we are using log-calculations to enhance numerical stability
// (log_probs and log_alpha). In total it would be more efficient to compute the
// beta in the same kernel (e.g. cudnn does this). While the beta are not needed
// for the loss itself (just the grad), we can return log_alpha+log_beta (so
// same space as currently) and the overhead is small and the use-case for loss
// without grad is relatively limited. We parallelize by batch and target
// sequence. Empirically, it is faster to loop over the input (log probs)
// sequence and do target in parallel, even if it means more frequent
// __syncthreads. In contrast to the cuDNN implementation, we allow large target
// lengths. For this we need that all previous `s` have been computed when we
// start a new block_s. This is why we have our own for loop here.
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
imputer_loss_log_alpha_gpu_kernel(
scalar_t *__restrict__ log_alpha_data, const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const target_t *__restrict__ force_emits_data,
scalar_t *__restrict__ neg_log_likelihood_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t fe_batch_stride, int64_t la_batch_stride,
int64_t la_input_stride, int64_t la_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
// bookkeeping
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
int64_t fe_batch_offset = b * fe_batch_stride;
if (b >= batch_size)
return;
target_t force_emit;
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
la = target_length == 0
? neginf
: log_probs_data[lp_batch_offset +
lp_char_stride *
get_target_prime(
targets_data, tg_batch_offset,
tg_target_stride, 1, BLANK)];
break;
default:
la = neginf;
}
force_emit = force_emits_data[fe_batch_offset];
if (force_emit > -1 && force_emit != s) {
la = neginf;
}
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset +
/* la_input_stride * 0 */ +la_target_stride * s] = la;
}
for (int64_t block_s = 0; block_s < 2 * max_target_length + 1;
block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t = 1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch
if ((t < input_length) && (s < 2 * target_length + 1)) {
force_emit = force_emits_data[fe_batch_offset + t];
if (force_emit > -1 && force_emit != s) {
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
continue;
}
// only for valid t, s. This is equation (6) and (7), la1, la2, la3 are
// the three summands, lamax is the maximum for the logsumexp trick.
scalar_t la1 =
log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * s];
scalar_t lamax = la1;
scalar_t la2, la3;
if (s > 0) {
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 1)];
if (la2 > lamax)
lamax = la2;
} else {
la2 = neginf;
}
if (have_three) {
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t - 1) +
la_target_stride * (s - 2)];
if (la3 > lamax)
lamax = la3;
} else {
la3 = neginf;
}
if (lamax == neginf) // when all are neginf. (then the whole thing is
// neginf, but we can pretend)
lamax = 0;
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] =
std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) +
std::exp(la3 - lamax)) +
lamax +
log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_char];
} else {
// otherwise we just set to neginf
if (s < 2 * max_target_length + 1)
log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] = neginf;
}
}
}
__syncthreads(); // on cuda 9 we might use partial synchronization of only the
// threads within the same batch
// compute the loss (eq (8))
if (threadIdx.x == 0) {
scalar_t l1 =
log_alpha_data[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2)];
scalar_t l2 =
target_length > 0
? log_alpha_data[la_batch_offset +
la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_data[b] = -log_likelihood;
}
}
// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each
// row) and a 1-dimensional tensor where all targets are concatenated (and we
// use target_lengths to figure out where they begin). We return log_alpha
// (currently; this might change to (log_alpha+log_beta) to be passed to the
// backward). The dispatch function will only return the loss.
template <typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor>
imputer_loss_gpu_template(const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK) {
// log_probs: input_len x batch_size x num_labels
// targets [int64]: batch_size x target_length OR sum(target_lengths)
CheckedFrom c = "imputer_loss_gpu";
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
auto targets_arg = TensorArg(targets, "targets", 2);
checkAllSameGPU(c, {log_probs_arg, targets_arg});
checkScalarType(c, targets_arg, target_scalar_type);
checkDim(c, log_probs_arg, 3);
checkDimRange(c, targets_arg, 1, 3);
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels),
"blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size,
"input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size,
"target_lengths must be of size batch_size");
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length = 0;
auto tg_batch_offsets =
at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(0);
checkSize(c, targets_arg, 0, pos);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length,
" at dimension 1, but got size ", targets.size(1), " for ",
targets_arg, " (while checking arguments for ", c, ")");
}
int64_t max_input_length = log_probs.size(0);
for (int64_t b = 0; b < batch_size; b++) {
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected input_lengths to have value at most ",
max_input_length, ", but got value ", input_lengths[b],
" (while checking arguments for ", c, ")");
}
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_alpha =
at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1},
log_probs.options());
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
  // Very likely, we could be more clever here, e.g. learning (or generalizing
// and reusing) from SoftMax.cu...
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 1024
: 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
imputer_loss_log_alpha_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_alpha.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, force_emits.data_ptr<target_t>(),
neg_log_likelihood.data_ptr<scalar_t>(), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), force_emits.stride(0),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
BLANK);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
return std::make_tuple(neg_log_likelihood, log_alpha);
}
// The second (backward) half of the forward backward algorithm, (10) and (11).
// This is parallel to the alpha kernel above. (As mentioned above, it might
// make sense to do the calculation in the alpha kernel.)
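// A brief sketch of the recurrence the kernel below implements (eq (10)/(11)),
// with the imputer-specific twist that states ruled out by force_emits are set
// to -inf at the corresponding time step:
//   log_beta[t][s] = log_probs[t][l'_s]
//                    + logsumexp(log_beta[t+1][s],
//                                log_beta[t+1][s+1]  (if s < 2*target_length),
//                                log_beta[t+1][s+2]  (if s < 2*target_length - 1
//                                                     and l'_{s+2} != l'_s))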
template <typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
imputer_loss_backward_log_beta_gpu_kernel(
scalar_t *__restrict__ log_beta_data, const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const target_t *__restrict__ force_emits_data, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t fe_batch_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t BLANK) {
constexpr scalar_t neginf = -INFINITY;
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
int64_t fe_batch_offset = b * fe_batch_stride;
if (b >= batch_size)
return;
target_t force_emit;
// "first" row, the beta initiaization before eq (10) (t=target_length -
// differes per batch)
for (int64_t block_s =
2 * max_target_length - (2 * max_target_length % blockDim.x);
block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t lb;
if (s == 2 * target_length) {
lb = log_probs_data[lp_batch_offset +
(input_length - 1) * lp_input_stride +
lp_char_stride * BLANK];
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
lb = log_probs_data[lp_batch_offset +
(input_length - 1) * lp_input_stride +
lp_char_stride * current_target_prime];
} else {
lb = neginf;
}
force_emit = force_emits_data[fe_batch_offset + (input_length - 1)];
if (force_emit > -1 && force_emit != s) {
lb = neginf;
}
if (s < 2 * max_target_length + 1) {
log_beta_data[lb_batch_offset + (input_length - 1) * lb_input_stride +
lb_target_stride * s] = lb;
}
}
// go backward in s
for (int64_t block_s =
2 * max_target_length - (2 * max_target_length % blockDim.x);
block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(targets_data, tg_batch_offset,
tg_target_stride, s, BLANK);
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(targets_data, tg_batch_offset, tg_target_stride,
s + 2, BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
}
// now go backward in t. Note that we need to skip the last timestep that we
// did above.
for (int64_t t = max_input_length - 2; t >= 0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only
// the threads within the same batch item
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
force_emit = force_emits_data[fe_batch_offset + t];
if ((force_emit > -1) && (force_emit != s)) {
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = neginf;
continue;
}
scalar_t lb1 =
log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
if (s < 2 * target_length) {
lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * (s + 1)];
if (lb2 > lbmax)
lbmax = lb2;
} else {
lb2 = neginf;
}
if (have_three) {
lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t + 1) +
lb_target_stride * (s + 2)];
if (lb3 > lbmax)
lbmax = lb3;
} else {
lb3 = neginf;
}
if (lbmax == neginf)
lbmax = 0;
scalar_t lb = std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) +
std::exp(lb3 - lbmax)) +
lbmax +
log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * current_target_prime];
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = lb;
} else if ((s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) ||
(s >= 2 * target_length + 1) || (t >= input_length))) {
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s] = neginf;
}
}
}
}
// This implements the subtrahend of equation (16) for all *nonblank*
// characters. It assumes you have probs in gradient_data when called and it
// modifies gradient_data to be the gradient. In order to facilitate this
// in-place update, we don't actually do this in logspace. (The other variant
// implemented uses log_space and the differences seem to be
// not so problematic at least with unit normal distributed test activations.)
// Internally this uses atomicAdd because different threads may write to the
// same gradient position. This is parallelised over b and s again. Note that
// for us, the Z of eqn (16) is actually constant for all t and it is the
// likelihood - this is why we use the negative log likelihood below.
// We also multiply by the input gradient to keep with standard autograd style.
// I took this trick from [2]; for moderate alphabet sizes a log-space
// calculation (with an atomic log add) is similar in performance, but for
// large alphabets the inplace nature is a considerable advantage.
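// In short, for each nonblank target character c = targets[s] this kernel adds
//   -grad_out[b] * exp(log_alpha[t][2s+1] + log_beta[t][2s+1] + nll[b] - log_probs[t][b][c])
// to gradient_data[t][b][c] via atomicAdd, for every valid t; gradient_data is
// assumed to already hold the minuend (the probabilities, scaled by grad_out).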
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
imputer_loss_backward_collect_nonblank_gpu_kernel(
scalar_t *__restrict__ gradient_data,
const scalar_t *__restrict__ grad_out_data,
int64_t grad_out_batch_stride,
const scalar_t *__restrict__ log_alpha_data,
const scalar_t *__restrict__ log_beta_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const scalar_t *__restrict__ neg_log_likelihood_data,
int64_t gr_input_stride, int64_t gr_batch_stride,
int64_t gr_char_stride, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t num_labels, int64_t BLANK,
bool zero_infinity) {
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t s =
threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into
// targets, not targets prime!
if (b >= batch_size)
return;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t gr_batch_offset = b * gr_batch_stride;
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
if (s >= target_length)
return;
int64_t target = targets_data[tg_batch_offset + s * tg_target_stride];
scalar_t nll = neg_log_likelihood_data[b];
scalar_t gr = grad_out_data[b * grad_out_batch_stride];
if (zero_infinity && nll == INFINITY)
return;
for (int64_t t = 0; t < input_length; t++) {
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * target];
atomicAdd(&gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * target],
-std::exp(log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * (s * 2 + 1)] +
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * (s * 2 + 1)] +
nll - lp) *
gr);
}
}
// This is the naive implementation of equation (16). It is parallelised in
// batch and input timestep. It appears to be faster than the above method for
// small batch sizes.
template <typename scalar_t, typename target_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
imputer_loss_backward_collect_gpu_kernel(
scalar_t *__restrict__ gradient_data,
const scalar_t *__restrict__ grad_out_data,
int64_t grad_out_batch_stride,
const scalar_t *__restrict__ log_alpha_data,
const scalar_t *__restrict__ log_beta_data,
const scalar_t *log_probs_data,
const int64_t *__restrict__ input_lengths, int64_t max_input_length,
const target_t *__restrict__ targets_data,
const int64_t *__restrict__ target_lengths, int64_t max_target_length,
const scalar_t *__restrict__ neg_log_likelihood_data,
int64_t gr_input_stride, int64_t gr_batch_stride,
int64_t gr_char_stride, int64_t lp_input_stride,
int64_t lp_batch_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride,
int64_t la_target_stride, int64_t lb_batch_stride,
int64_t lb_input_stride, int64_t lb_target_stride,
const int64_t *__restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t num_labels, int64_t BLANK,
bool zero_infinity) {
constexpr scalar_t neginf = -INFINITY;
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
if ((t >= max_input_length) || (b >= batch_size))
return;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t gr_batch_offset = b * gr_batch_stride;
int64_t lp_batch_offset = b * lp_batch_stride;
int64_t la_batch_offset = b * la_batch_stride;
int64_t lb_batch_offset = b * lb_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
for (int s = 0; s < 2 * max_target_length + 1; s++) {
if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
int64_t current_target_prime = get_target_prime(
targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
scalar_t log_alpha_beta =
(log_alpha_data[la_batch_offset + la_input_stride * t +
la_target_stride * s] +
log_beta_data[lb_batch_offset + lb_input_stride * t +
lb_target_stride * s]);
scalar_t &lcab = gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * current_target_prime];
if (lcab == neginf) {
lcab = log_alpha_beta;
} else {
scalar_t max = ((lcab > log_alpha_beta) ? lcab : log_alpha_beta);
lcab = std::log(std::exp(lcab - max) + std::exp(log_alpha_beta - max)) +
max;
}
}
}
scalar_t nll = neg_log_likelihood_data[b];
scalar_t gr = grad_out_data[b * grad_out_batch_stride];
for (int64_t c = 0; c < num_labels; c++) {
scalar_t &res = gradient_data[gr_batch_offset + t * gr_input_stride +
gr_char_stride * c];
if (t < input_length && (!zero_infinity || nll != INFINITY)) {
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride +
lp_char_stride * c];
res = (std::exp(lp) - std::exp(res + nll - lp)) * gr;
} else {
res = 0.;
}
}
}
// This zeroes the gradients that correspond to out-of-sequence positions.
// Those gradients should not be used in any model update since the input
// elements are padded
template <typename scalar_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
imputer_loss_zero_padded_gradients(
scalar_t *__restrict__ gradient_data, /* (T, B, D) layout */
const int64_t *__restrict__ input_lengths, /* (B, ) layout */
int64_t gr_timestep_stride, int64_t gr_batch_stride,
int64_t gr_label_stride, int64_t max_input_length, /* T */
int64_t batch_size, /* B */
int64_t num_labels /* D */) {
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (b >= batch_size || t >= max_input_length) {
return;
}
scalar_t input_length = input_lengths[b];
if (t >= input_length) {
for (int l = 0; l < num_labels; l++)
gradient_data[t * gr_timestep_stride + b * gr_batch_stride +
l * gr_label_stride] = 0.0f;
}
}
// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when
// backpropagating through a (well-checked) forward.
template <typename scalar_t, ScalarType target_scalar_type>
Tensor imputer_loss_backward_gpu_template(
const Tensor &grad_out, const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, const Tensor &neg_log_likelihood,
const Tensor &log_alpha, int64_t BLANK, bool zero_infinity) {
constexpr scalar_t neginf = -INFINITY;
using target_t =
typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length;
auto tg_batch_offsets =
at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
max_target_length = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(0);
} else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
}
tg_target_stride = targets.stride(1);
max_target_length =
log_alpha.size(2) / 2; // targets.size(1) might be larger
}
auto target_lengths_t =
at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t =
at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
log_beta.fill_(neginf);
Tensor grad =
at::full_like(log_probs, neginf,
LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for
// log(sum (alpha beta))
// As above, there may be better configurations to use.
constexpr int max_threads =
std::is_same<scalar_t, float>::value
? 1024
: 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2 * max_target_length + 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
{
dim3 block(threads_target, threads_batch);
dim3 grid((2 * max_target_length + 1 + threads_target - 1) / threads_target,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_backward_log_beta_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, force_emits.data_ptr<target_t>(),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
force_emits.stride(0), log_beta.stride(0), log_beta.stride(1),
log_beta.stride(2), tg_batch_offsets.data_ptr<int64_t>(),
tg_target_stride, batch_size, BLANK);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
}
  // Very crude heuristic for what is a small problem, based on linearly
// regressing problem dimensions on the (capped) difference of timings. Note
// that for OK problems target length <= input length, so we only consider
// input length.
bool is_large = (2 * log_probs.size(0) + (24 * batch_size) / 10 +
(2 * num_labels) / 10) > 450;
if (is_large) { // large alphabet, large batch
// this computes the probs, minuend in (16)
exp_out(grad, log_probs);
// now we compute the subtrahend for the blanks. It is a straightforward
// reduction because we know that blanks are in every other position. maybe
// we should kernelize this, too.
auto grad_blank = grad.narrow(2, BLANK, 1);
grad_blank -=
(at::logsumexp(
log_alpha.as_strided(
{batch_size, log_alpha.size(1), max_target_length + 1},
{log_alpha.stride(0), log_alpha.stride(1),
log_alpha.stride(2) * 2}) +
log_beta.as_strided(
{batch_size, log_beta.size(1), max_target_length + 1},
{log_beta.stride(0), log_beta.stride(1),
log_beta.stride(2) * 2}),
2, true)
.permute({1, 0, 2})
.add_(neg_log_likelihood.view({1, batch_size, 1}))
.sub_(log_probs.narrow(2, BLANK, 1))
.exp_());
// scale by output gradient (blanks and first summand of non-blanks)
grad *= grad_out.view({1, batch_size, 1});
if (zero_infinity) {
grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) ==
Scalar(INFINITY),
at::zeros({}, grad.options()), grad);
}
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int)batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid(std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch, 1);
imputer_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), grad_out.data_ptr<scalar_t>(),
grad_out.stride(0), log_alpha.data_ptr<scalar_t>(),
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0),
log_beta.stride(1), log_beta.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
num_labels, BLANK, zero_infinity);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int)batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input - 1) / threads_input,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_backward_collect_gpu_kernel<scalar_t, target_t>
<<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), grad_out.data_ptr<scalar_t>(),
grad_out.stride(0), log_alpha.data_ptr<scalar_t>(),
log_beta.data_ptr<scalar_t>(), log_probs.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(), log_probs.size(0),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
max_target_length, neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0),
log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0),
log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0),
log_beta.stride(1), log_beta.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, batch_size,
num_labels, BLANK, zero_infinity);
C10_CUDA_CHECK(cudaGetLastError()); // catch launch errors
}
  // zero those invalid gradient elements due to padding
{
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0)) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int)batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input - 1) / threads_input,
(batch_size + threads_batch - 1) / threads_batch);
imputer_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(),
grad.stride(0), grad.stride(1), grad.stride(2), grad.size(0),
grad.size(1), grad.size(2));
C10_CUDA_CHECK(cudaGetLastError());
}
return grad;
}
std::tuple<Tensor, Tensor>
imputer_loss_op(const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
(void)zero_infinity; // only used for backward
return AT_DISPATCH_FLOATING_TYPES(
log_probs.scalar_type(), "imputer_loss_cuda", [&] {
if (targets.scalar_type() == kLong) {
return imputer_loss_gpu_template<scalar_t, kLong>(
log_probs, targets, force_emits, input_lengths, target_lengths,
BLANK);
} else {
return imputer_loss_gpu_template<scalar_t, kInt>(
log_probs, targets, force_emits, input_lengths, target_lengths,
BLANK);
}
});
}
Tensor imputer_loss_backward_op(
const Tensor &grad, const Tensor &log_probs, const Tensor &targets,
const Tensor &force_emits, IntArrayRef input_lengths,
IntArrayRef target_lengths, const Tensor &neg_log_likelihood,
const Tensor &log_alpha, int64_t BLANK, bool zero_infinity) {
return AT_DISPATCH_FLOATING_TYPES(
log_probs.scalar_type(), "imputer_loss_backward_cuda", [&] {
if (targets.scalar_type() == kLong) {
return imputer_loss_backward_gpu_template<scalar_t, kLong>(
grad, log_probs, targets, force_emits, input_lengths,
target_lengths, neg_log_likelihood, log_alpha, BLANK,
zero_infinity);
} else {
return imputer_loss_backward_gpu_template<scalar_t, kInt>(
grad, log_probs, targets, force_emits, input_lengths,
target_lengths, neg_log_likelihood, log_alpha, BLANK,
zero_infinity);
}
});
}
import os
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
imputer = load(
"imputer_fn",
sources=[
os.path.join(module_path, "imputer.cpp"),
os.path.join(module_path, "imputer.cu"),
os.path.join(module_path, "best_alignment.cu"),
],
)
class ImputerLossFunction(Function):
@staticmethod
def forward(
ctx,
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
blank,
zero_infinity,
):
input_lengths = input_lengths.to("cpu", dtype=torch.int64)
target_lengths = target_lengths.to("cpu", dtype=torch.int64)
loss, log_alpha = imputer.imputer_loss(
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
blank,
zero_infinity,
)
ctx.save_for_backward(
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
loss,
log_alpha,
)
ctx.blank = blank
ctx.zero_infinity = zero_infinity
return loss
@staticmethod
def backward(ctx, grad_output):
log_prob, targets, force_emits, input_lengths, target_lengths, loss, log_alpha = (
ctx.saved_tensors
)
blank = ctx.blank
zero_infinity = ctx.zero_infinity
grad_input = imputer.imputer_loss_backward(
grad_output,
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
loss,
log_alpha,
blank,
zero_infinity,
)
return grad_input, None, None, None, None, None, None
imputer_loss_fn = ImputerLossFunction.apply
def imputer_loss(
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
blank=0,
reduction="mean",
zero_infinity=False,
):
"""The Imputer loss
Parameters:
log_prob (T, N, C): C = number of characters in alphabet including blank
T = input length
N = batch size
log probability of the outputs (e.g. torch.log_softmax of logits)
targets (N, S): S = maximum number of characters in target sequences
force_emits (N, T): sequence of ctc states that should occur at the given times;
that is, if force_emits is state s at time t, only ctc paths
that pass through state s at time t are enabled and the rest are zeroed out,
which is the same as using a cross entropy loss at time t
values should be in the range [-1, 2 * S + 1), i.e. valid ctc states;
-1 means that any state is allowed at time t (normal ctc paths)
input_lengths (N): lengths of log_prob
target_lengths (N): lengths of targets
blank (int): index of blank tokens (default 0)
reduction (str): reduction methods applied to the output. 'none' | 'mean' | 'sum'
zero_infinity (bool): if true, the imputer loss will zero out infinities.
infinities mostly occur when it is impossible to generate
target sequences using input sequences
(e.g. input sequences are shorter than target sequences)
"""
loss = imputer_loss_fn(
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
blank,
zero_infinity,
)
input_lengths = input_lengths.to("cpu", dtype=torch.int64)
target_lengths = target_lengths.to("cpu", dtype=torch.int64)
if zero_infinity:
inf = float("inf")
loss = torch.where(loss == inf, loss.new_zeros(1), loss)
if reduction == "mean":
target_length = target_lengths.to(loss).clamp(min=1)
return (loss / target_length).mean()
elif reduction == "sum":
return loss.sum()
elif reduction == "none":
return loss
else:
raise ValueError(
f"Supported reduction modes are: mean, sum, none; got {reduction}"
)
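# A minimal usage sketch (shapes are illustrative and the tensors below are
# hypothetical; the extension kernels run on CUDA tensors). A force_emits row of
# all -1 places no constraint, so the loss behaves like ordinary CTC there:
#
#   T, N, C, S = 50, 4, 32, 10
#   log_prob = torch.randn(T, N, C).log_softmax(2).cuda()
#   targets = torch.randint(1, C, (N, S))
#   force_emits = torch.full((N, T), -1, dtype=torch.long)
#   input_lengths = torch.full((N,), T, dtype=torch.long)
#   target_lengths = torch.full((N,), S, dtype=torch.long)
#   loss = imputer_loss(log_prob, targets, force_emits,
#                       input_lengths, target_lengths, blank=0)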
class ImputerLoss(nn.Module):
def __init__(self, blank=0, reduction="mean", zero_infinity=False):
"""The Imputer loss
Parameters:
blank (int): index of blank tokens (default 0)
reduction (str): reduction methods applied to the output. 'none' | 'mean' | 'sum'
zero_infinity (bool): if true, the imputer loss will zero out infinities.
infinities mostly occur when it is impossible to generate
target sequences using input sequences
(e.g. input sequences are shorter than target sequences)
Input:
log_prob (T, N, C): C = number of characters in alphabet including blank
T = input length
N = batch size
log probability of the outputs (e.g. torch.log_softmax of logits)
targets (N, S): S = maximum number of characters in target sequences
force_emits (N, T): sequence of ctc states that should occur at the given times;
that is, if force_emits is state s at time t, only ctc paths
that pass through state s at time t are enabled and the rest are zeroed out,
which is the same as using a cross entropy loss at time t
values should be in the range [-1, 2 * S + 1), i.e. valid ctc states;
-1 means that any state is allowed at time t (normal ctc paths)
input_lengths (N): lengths of log_prob
target_lengths (N): lengths of targets"""
super().__init__()
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
def forward(self, log_prob, targets, force_emits, input_lengths, target_lengths):
return imputer_loss(
log_prob,
targets,
force_emits,
input_lengths,
target_lengths,
self.blank,
self.reduction,
self.zero_infinity,
)
"""class ImputerLoss(nn.Module):
def __init__(self, blank=0, reduction="mean", zero_infinity=False, mask_eps=1e-8):
super().__init__()
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
self.mask_eps = math.log(mask_eps)
def forward(
self, logit, targets_ctc, targets_ce, mask, input_lengths, targets_ctc_lengths
):
n_target = logit.shape[-1]
mask_e = mask.unsqueeze(-1)
mask_exp = mask_e.repeat(1, 1, n_target)
log_p_mask = logit.masked_fill(mask_exp == 1, self.mask_eps)
mask_exp[:, :, self.blank] = 0
log_p_mask = log_p_mask.masked_fill((mask_e == 1) & (mask_exp == 0), 0)
log_p_mask = torch.log_softmax(log_p_mask, 2)
ctc_loss = F.ctc_loss(
log_p_mask.transpose(0, 1),
targets_ctc,
input_lengths,
targets_ctc_lengths,
blank=self.blank,
reduction=self.reduction,
zero_infinity=self.zero_infinity,
)
ce_loss = F.cross_entropy(
logit.view(-1, n_target), targets_ce.view(-1), reduction="none"
)
ce_loss = mask.view(-1) * ce_loss
if self.reduction == "mean":
ce_loss = ce_loss.mean()
elif self.reduction == "sum":
ce_loss = ce_loss.sum()
return ctc_loss + ce_loss"""
def get_alignment_path(log_alpha, path):
if log_alpha.shape[0] == 1:
current_state = 0
else:
current_state = log_alpha[-2:, -1].argmax() + (log_alpha.shape[0] - 2)
path_decode = [current_state]
for t in range(path.shape[1] - 1, 0, -1):
prev_state = path[current_state, t]
path_decode.append(prev_state)
current_state = prev_state
return path_decode[::-1]
def ctc_decode(seq, blank=0):
result = []
prev = -1
for s in seq:
if s == blank:
prev = s
continue
if prev == -1:
result.append(s)
else:
if s != prev:
result.append(s)
prev = s
return result
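# For example, ctc_decode([0, 3, 3, 0, 3, 5, 5, 0], blank=0) drops the blanks and
# collapses repeats that are not separated by a blank, returning [3, 3, 5].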
def best_alignment(
log_prob, targets, input_lengths, target_lengths, blank=0, zero_infinity=False
):
"""Get best alignment (maximum probability sequence of ctc states)
conditioned on log probabilities and target sequences
Input:
log_prob (T, N, C): C = number of characters in alphabet including blank
T = input length
N = batch size
log probability of the outputs (e.g. torch.log_softmax of logits)
targets (N, S): S = maximum number of characters in target sequences
input_lengths (N): lengths of log_prob
target_lengths (N): lengths of targets
blank (int): index of blank tokens (default 0)
zero_infinity (bool): if true, the imputer loss will zero out infinities.
infinities mostly occur when it is impossible to generate
target sequences using input sequences
(e.g. input sequences are shorter than target sequences)
Output:
best_aligns (List[List[int]]): sequences of ctc states that have maximum probabilities
given the log probabilities and are compatible with the target sequences"""
nll, log_alpha, alignment = imputer.best_alignment(
log_prob, targets, input_lengths, target_lengths, blank, zero_infinity
)
log_alpha = log_alpha.transpose(1, 2).detach().cpu().numpy()
alignment = alignment.transpose(1, 2).detach().cpu().numpy()
best_aligns = []
for log_a, align, input_len, target_len in zip(
log_alpha, alignment, input_lengths, target_lengths
):
state_len = target_len * 2 + 1
log_a = log_a[:state_len, :input_len]
align = align[:state_len, :input_len]
best_aligns.append(get_alignment_path(log_a, align))
return best_aligns
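# A minimal usage sketch (shapes are illustrative and the tensors are hypothetical;
# the extension kernels run on CUDA tensors):
#
#   T, N, C, S = 50, 4, 32, 10
#   log_prob = torch.randn(T, N, C).log_softmax(2).cuda()
#   targets = torch.randint(1, C, (N, S))
#   input_lengths = torch.full((N,), T, dtype=torch.long)
#   target_lengths = torch.full((N,), S, dtype=torch.long)
#   aligns = best_alignment(log_prob, targets, input_lengths, target_lengths, blank=0)
#   # each aligns[i] lists input_lengths[i] ctc states in [0, 2 * S + 1);
#   # even states correspond to blanks, and state 2 * s + 1 emits targets[i][s]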
......@@ -107,7 +107,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
)
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, cfg.common_eval.post_process,
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
......