Commit 793f553a by xuchen

Optimize the implementation of mixup:

1. sample a different mixup coefficient for each sample
2. support an arbitrary mixup ratio
3. add a cross-entropy mixup consistency loss
parent 444a1f46
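
The core of change 1 is that the Beta-distributed mixing coefficient is now drawn once per mixed pair instead of once per batch, so every pair (x1, x2) gets its own lambda. A minimal sketch of that step, assuming (T, B, C) encoder features and a Beta(alpha, alpha) prior (the function name and defaults here are illustrative, not the repository's):

```python
import torch

def per_sample_mixup(x1, x2, alpha=0.5):
    """x1, x2: (T, B, C) features of the two samples forming each mixed pair."""
    beta = torch.distributions.Beta(alpha, alpha)
    # one coefficient per pair in the batch, instead of a single scalar
    coef = beta.sample([x1.size(1)]).to(x1.device).type_as(x1)   # shape (B,)
    mixup_coef = coef.view(1, -1, 1)                             # broadcast over time and channels
    return mixup_coef * x1 + (1 - mixup_coef) * x2, coef
```

The same per-pair `coef` is then reused below to weight the two CTC/CE targets and the consistency terms.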
......@@ -186,7 +186,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_entropy + self.ctc_mixup_consistent_weight
if self.all_ctc_weight > 0:
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="none", zero_infinity=True)
self.ctc_names = []
self.use_ctc = (self.ctc_weight + self.interleaved_ctc_weight > 0)
......@@ -293,12 +293,13 @@ class CtcCriterion(FairseqCriterion):
loss = 0
with torch.backends.cudnn.flags(enabled=False):
for item_targets, item_target_lengths, item_coef in zip(targets, target_lengths, loss_coef):
loss += self.ctc_loss(
item_loss = self.ctc_loss(
lprobs,
item_targets,
input_lengths,
item_target_lengths,
) * item_coef
)
loss += (item_loss * item_coef).sum()
return loss, lprobs
@staticmethod
......@@ -403,7 +404,9 @@ class CtcCriterion(FairseqCriterion):
ctc_loss = 0
ctc_entropy = 0
use_ctc = False
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
use_ctc = True
ctc_logit = net_output["ctc_logit"][0]
all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
......@@ -556,10 +559,12 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
ctc_mixup_consistent_loss = 0
if mixup is True and self.ctc_mixup_consistent_weight > 0:
if use_ctc and mixup is True and self.ctc_mixup_consistent_weight > 0:
mixup_pos = mixup_idx1 != mixup_idx2
ctc_logit = net_output["ctc_logit"][0]
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_logit = ctc_logit[:, mixup_pos, :]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
......@@ -571,12 +576,12 @@ class CtcCriterion(FairseqCriterion):
for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
loss = F.kl_div(
F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
# F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
# F.log_softmax(logit, dim=-1, dtype=torch.float32),
F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
)
ctc_mixup_consistent_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum() * coef
ctc_mixup_consistent_loss += (loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)
loss = \
......@@ -588,16 +593,17 @@ class CtcCriterion(FairseqCriterion):
self.ctc_entropy * ctc_entropy + \
self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
logging_output["all_ctc_loss"] = utils.item(loss.data)
if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
# logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
if loss != 0:
logging_output["all_ctc_loss"] = utils.item(loss.data)
if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
logger.warning("Illegal loss %f!" % loss)
if ctc_loss != 0 and (torch.isnan(ctc_loss) or torch.isinf(ctc_loss)):
logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
# CER is not completely accurate and is for reference only.
if not model.training:
......
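
Because the per-pair coefficients differ, the CTC loss switches from `reduction="sum"` to `reduction="none"` so each sample's loss can be scaled by its own coefficient before summing. A standalone sketch of that pattern (variable names are illustrative and do not match the repository exactly):

```python
import torch

ctc_loss_fn = torch.nn.CTCLoss(blank=0, reduction="none", zero_infinity=True)

def weighted_ctc_loss(lprobs, input_lengths, targets, target_lengths, loss_coef):
    """lprobs: (T, B, V) log-probabilities; targets, target_lengths and loss_coef are
    parallel lists, e.g. the two mixup targets with coefficients coef and 1 - coef."""
    loss = 0
    with torch.backends.cudnn.flags(enabled=False):
        for item_targets, item_target_lengths, item_coef in zip(
            targets, target_lengths, loss_coef
        ):
            # (B,) per-sample losses, each scaled by its own mixup coefficient
            item_loss = ctc_loss_fn(lprobs, item_targets, input_lengths, item_target_lengths)
            loss = loss + (item_loss * item_coef).sum()
    return loss
```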
......@@ -7,6 +7,8 @@ import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
......@@ -19,6 +21,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
mixup_consistent_weight: float = field(
default=0.0,
metadata={"help": "the weight for consistency regularization of mixup"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
......@@ -61,12 +67,14 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
mixup_consistent_weight=0.0,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = float(label_smoothing)
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
self.mixup_consistent_weight = mixup_consistent_weight
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......@@ -77,7 +85,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
......@@ -88,6 +97,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if len(other_loss) != 0:
for key, value in other_loss.items():
loss += value
logging_output[key] = utils.item(value.data)
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
......@@ -104,40 +117,67 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
else:
lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
target = target[self.ignore_prefix_size:, :].contiguous()
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
target1 = target[idx1].view(-1)
target2 = target[idx2].view(-1)
target = [target1, target2]
else:
target = target.view(-1)
return lprobs.view(-1, lprobs.size(-1)), target
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
if type(target) == list:
assert "mixup" in net_output[1] and net_output[1]["mixup"] is not None
coef = net_output[1]["mixup"]["coef"]
loss1, nll_loss1 = label_smoothed_nll_loss(
lprobs,
target[0],
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
loss2, nll_loss2 = label_smoothed_nll_loss(
lprobs,
target[1],
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
loss = coef * loss1 + (1 - coef) * loss2
nll_loss = coef * nll_loss1 + (1 - coef) * nll_loss2
loss = nll_loss = 0
other_loss = dict()
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"]
batch_size = len(mixup_idx1)
target = model.get_targets(sample, net_output)
target1 = target[mixup_idx1].view(-1)
target2 = target[mixup_idx2].view(-1)
targets = [target1, target2]
mixup_coef = net_output[1]["mixup"]["coef"]
loss_coef = [mixup_coef, 1 - mixup_coef]
for item_target, item_coef in zip(targets, loss_coef):
item_loss, item_nll_loss = label_smoothed_nll_loss(
lprobs,
item_target,
self.eps,
ignore_index=self.padding_idx,
reduce=False,
)
loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
mixup_consistent_loss = 0
if self.mixup_consistent_weight > 0:
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))
mixup_pos = mixup_idx1 != mixup_idx2
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_lprobs = lprobs[mixup_pos, :, :]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
non_padding_mask = ~target.eq(self.padding_idx)
no_mixup_lprobs = lprobs[~mixup_pos, :, :]
mixup_target_lprobs = [no_mixup_lprobs[mixup_real_idx1, :, :], no_mixup_lprobs[mixup_real_idx2, :, :]]
mixup_target_pad_mask = [non_padding_mask[mixup_real_idx1], non_padding_mask[mixup_real_idx2]]
for tgt_lprobs, pad, coef in zip(mixup_target_lprobs, mixup_target_pad_mask, loss_coef):
item_loss = F.kl_div(
F.log_softmax(mixup_real_lprobs, dim=-1, dtype=torch.float32),
F.log_softmax(tgt_lprobs.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
)
mixup_consistent_loss += (item_loss.sum(-1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
other_loss["mixup_consistent_loss"] = mixup_consistent_loss * self.mixup_consistent_weight
else:
target = target.view(-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
......@@ -145,24 +185,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
return loss, nll_loss, other_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
if type(target) == list:
n_correct = total = 0
for item in target:
mask = item.ne(self.padding_idx)
n_correct += torch.sum(
lprobs.argmax(1).masked_select(mask).eq(item.masked_select(mask))
)
total += torch.sum(mask)
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"]
batch_size = len(mixup_idx1)
no_mixup_pos = mixup_idx1 == mixup_idx2
idx = mixup_idx1[no_mixup_pos]
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
target = target[idx].view(-1)
else:
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
target = target.view(-1)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
......@@ -170,6 +215,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
mixup_consistent_loss_sum = sum(log.get("mixup_consistent_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
......@@ -182,6 +228,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
if mixup_consistent_loss_sum != 0:
metrics.log_scalar(
"mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
......
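
Change 3 adds a consistency regularizer: the distribution produced for a mixed sample is pulled, via KL divergence, toward the distributions of its two unmixed source samples, weighted by the per-pair coefficients. The same pattern is applied to the CTC logits and to the decoder output. A rough sketch, assuming (B, L, V) log-probabilities and boolean non-padding masks (names here are assumptions, not the repository's helpers):

```python
import torch
import torch.nn.functional as F

def mixup_consistency_loss(mixed_lprobs, src1_lprobs, src2_lprobs, coef, pad_mask1, pad_mask2):
    """KL(mixed || source) against each of the two source sequences, weighted by the
    per-sample coefficients coef and 1 - coef; padding positions are zeroed out."""
    loss = 0
    for tgt_lprobs, pad, c in zip(
        [src1_lprobs, src2_lprobs], [pad_mask1, pad_mask2], [coef, 1 - coef]
    ):
        kl = F.kl_div(
            mixed_lprobs,          # log-probabilities of the mixed example
            tgt_lprobs.detach(),   # the unmixed example acts as a fixed teacher
            log_target=True,
            reduction="none",
        )
        # sum over vocabulary, mask padding, sum per sentence, weight, sum over batch
        loss = loss + (kl.sum(-1).masked_fill(~pad, 0.0).sum(-1) * c).sum()
    return loss
```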
......@@ -24,8 +24,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
sentence_avg,
cfg: CtcCriterionConfig,
ctc_weight=0.0,
save_dir=None):
super().__init__(task, sentence_avg, label_smoothing)
save_dir=None,
mixup_consistent_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing,
report_accuracy=True,
mixup_consistent_weight=mixup_consistent_weight)
self.report_accuracy = True
self.ctc_weight = ctc_weight
......@@ -37,13 +40,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
"""Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser)
CtcCriterion.add_args(parser)
# parser.add_argument(
# "--ctc-weight",
# default=0.0,
# type=float,
# metavar="D",
# help="weight of CTC loss",
# )
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......@@ -71,21 +67,18 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
use_mixup = True
net_output = model.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
n_tokens = sample["ntokens"]
n_sentences = sample["target"].size(0)
if use_mixup:
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
n_tokens //= encoder_out["mixup"]["ratio"]
n_sentences //= net_output[0].size(0)
......@@ -97,6 +90,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
"nsentences": n_sentences,
"sample_size": sample_size,
}
if len(other_loss) != 0:
for key, value in other_loss.items():
loss += value
logging_output[key] = utils.item(value.data)
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
......@@ -120,6 +117,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item(
sum(log.get("trans_loss", 0) for log in logging_outputs)
......@@ -127,9 +125,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
mixup_consistent_loss_sum = utils.item(
sum(log.get("mixup_consistent_loss", 0) for log in logging_outputs)
)
enc_loss_sum = utils.item(
sum(log.get("encoder_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
......@@ -145,6 +147,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
if mixup_consistent_loss_sum != 0:
metrics.log_scalar(
"mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
......
......@@ -819,25 +819,39 @@ class S2TTransformerEncoder(FairseqEncoder):
batch = x.size(1)
indices = np.random.permutation(batch)
org_indices = np.arange(batch)
if self.mixup_ratio == 1:
if len(indices) % 2 != 0:
indices = np.append(indices, (indices[-1]))
idx1 = indices[0::2]
idx2 = indices[1::2]
if self.mixup_keep_org:
idx1 = np.append(org_indices, idx1)
idx2 = np.append(org_indices, idx2)
# if self.mixup_ratio == 1:
# if len(indices) % 2 != 0:
# indices = np.append(indices, (indices[-1]))
# idx1 = indices[0::2]
# idx2 = indices[1::2]
#
# if self.mixup_keep_org:
# idx1 = np.append(org_indices, idx1)
# idx2 = np.append(org_indices, idx2)
#
# else:
# mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
# mix_indices = indices[: mix_size]
# if self.mixup_keep_org:
# idx1 = np.append(org_indices, mix_indices[0::2])
# idx2 = np.append(org_indices, mix_indices[1::2])
# else:
# idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
# idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
mixup_size = int(batch * self.mixup_ratio)
mixup_index1 = np.random.permutation(mixup_size)
mixup_index2 = np.random.permutation(mixup_size)
if self.mixup_keep_org:
idx1 = np.append(org_indices, mixup_index1)
idx2 = np.append(org_indices, mixup_index2)
else:
mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
mix_indices = indices[: mix_size]
if self.mixup_keep_org:
idx1 = np.append(org_indices, mix_indices[0::2])
idx2 = np.append(org_indices, mix_indices[1::2])
else:
idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
keep_indices = []
for i in org_indices:
if i not in mixup_index1 and i not in mixup_index2:
keep_indices.append(i)
idx1 = np.append(keep_indices, mixup_index1)
idx2 = np.append(keep_indices, mixup_index2)
idx1 = torch.from_numpy(idx1).to(x.device)
idx2 = torch.from_numpy(idx2).to(x.device)
......@@ -845,8 +859,9 @@ class S2TTransformerEncoder(FairseqEncoder):
x1 = x[:, idx1]
x2 = x[:, idx2]
coef = self.beta.sample().to(x.device).type_as(x)
x = (coef * x1 + (1 - coef) * x2)
coef = self.beta.sample([len(idx1)]).to(x.device).type_as(x).view(-1)
mixup_coef = coef.view(1, -1, 1)
x = (mixup_coef * x1 + (1 - mixup_coef) * x2)
pad1 = encoder_padding_mask[idx1]
pad2 = encoder_padding_mask[idx2]
......
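
Change 2 replaces the old pairwise-shuffle scheme with an arbitrary ratio: only the first `int(batch * mixup_ratio)` samples take part in mixing, and, unless the originals are kept wholesale, the remaining samples pass through unmixed (their idx1 == idx2). A standalone sketch of that index bookkeeping (illustrative; it mirrors the logic of the hunk above rather than copying it verbatim):

```python
import numpy as np

def build_mixup_indices(batch, mixup_ratio, keep_org):
    org_indices = np.arange(batch)
    mixup_size = int(batch * mixup_ratio)

    # the first `mixup_size` samples are paired with each other in two random orders
    mixup_index1 = np.random.permutation(mixup_size)
    mixup_index2 = np.random.permutation(mixup_size)

    if keep_org:
        # keep every original sample and append the mixed pairs
        idx1 = np.append(org_indices, mixup_index1)
        idx2 = np.append(org_indices, mixup_index2)
    else:
        # keep only the samples that appear in no mixed pair (those beyond mixup_size)
        keep_indices = [i for i in org_indices
                        if i not in mixup_index1 and i not in mixup_index2]
        idx1 = np.append(keep_indices, mixup_index1)
        idx2 = np.append(keep_indices, mixup_index2)

    # positions where idx1 == idx2 are unmixed pass-through samples
    return idx1.astype(np.int64), idx2.astype(np.int64)
```

Downstream, positions with idx1 != idx2 (the truly mixed ones) are the only positions that contribute to the consistency loss.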
......@@ -1058,7 +1058,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x1 = x[:, idx1]
x2 = x[:, idx2]
x = coef * x1 + (1 - coef) * x2
mixup_coef = coef.view(1, -1, 1)
x = mixup_coef * x1 + (1 - mixup_coef) * x2
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
......
......@@ -116,6 +116,16 @@ def parse_args_and_arch(
is_config_file=True,
help="Configuration YAML filename (for training)",
)
parser.add_argument(
"--train-config5",
is_config_file=True,
help="Configuration YAML filename (for training)",
)
parser.add_argument(
"--train-config6",
is_config_file=True,
help="Configuration YAML filename (for training)",
)
if suppress_defaults:
# Parse args without any default values. This requires us to parse
......