Commit 793f553a by xuchen

Optimize the implementation of mixup:

1. use a different mixup coefficient for each sample
2. support an arbitrary mixup ratio
3. add a cross-entropy mixup consistency loss
parent 444a1f46
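The core change is that the mixup interpolation coefficient is no longer a single scalar shared by the whole batch: one value is drawn from the Beta distribution for every mixed pair (see the encoder and decoder hunks below). A minimal standalone sketch of that idea; the function name and the (T, B, C) feature layout are illustrative, not part of the patch:

```python
import torch
from torch.distributions import Beta


def mixup_features(x, idx1, idx2, alpha=0.5):
    """Mix encoder features x of shape (T, B, C) along the batch dimension.

    A separate coefficient is drawn from Beta(alpha, alpha) for every pair
    (idx1[i], idx2[i]), so each mixed sample gets its own interpolation weight.
    """
    beta = Beta(torch.tensor(float(alpha)), torch.tensor(float(alpha)))
    coef = beta.sample([len(idx1)]).to(x.device).type_as(x)  # one coefficient per pair
    mixup_coef = coef.view(1, -1, 1)                         # broadcast over time and channels
    x_mix = mixup_coef * x[:, idx1] + (1 - mixup_coef) * x[:, idx2]
    return x_mix, coef
```

With a small alpha, Beta(alpha, alpha) concentrates mass near 0 and 1, so most mixed samples stay close to one of their two sources.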
@@ -186,7 +186,7 @@ class CtcCriterion(FairseqCriterion):
             self.ctc_entropy + self.ctc_mixup_consistent_weight
         if self.all_ctc_weight > 0:
-            self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
+            self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="none", zero_infinity=True)
         self.ctc_names = []
         self.use_ctc = (self.ctc_weight + self.interleaved_ctc_weight > 0)
@@ -293,12 +293,13 @@ class CtcCriterion(FairseqCriterion):
         loss = 0
         with torch.backends.cudnn.flags(enabled=False):
             for item_targets, item_target_lengths, item_coef in zip(targets, target_lengths, loss_coef):
-                loss += self.ctc_loss(
+                item_loss = self.ctc_loss(
                     lprobs,
                     item_targets,
                     input_lengths,
                     item_target_lengths,
-                ) * item_coef
+                )
+                loss += (item_loss * item_coef).sum()
         return loss, lprobs

     @staticmethod
@@ -403,7 +404,9 @@ class CtcCriterion(FairseqCriterion):
         ctc_loss = 0
         ctc_entropy = 0
+        use_ctc = False
         if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
+            use_ctc = True
             ctc_logit = net_output["ctc_logit"][0]
             all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
@@ -556,10 +559,12 @@ class CtcCriterion(FairseqCriterion):
             ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight

         ctc_mixup_consistent_loss = 0
-        if mixup is True and self.ctc_mixup_consistent_weight > 0:
+        if use_ctc and mixup is True and self.ctc_mixup_consistent_weight > 0:
             mixup_pos = mixup_idx1 != mixup_idx2
             ctc_logit = net_output["ctc_logit"][0]
+            mixup_real_coef = mixup_coef[mixup_pos]
+            loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
             mixup_real_logit = ctc_logit[:, mixup_pos, :]
             mixup_real_idx1 = mixup_idx1[mixup_pos]
             mixup_real_idx2 = mixup_idx2[mixup_pos]
@@ -571,12 +576,12 @@ class CtcCriterion(FairseqCriterion):
             for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
                 loss = F.kl_div(
                     F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
-                    # F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
+                    # F.log_softmax(logit, dim=-1, dtype=torch.float32),
                     F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
                     log_target=True,
                     reduction="none",
                 )
-                ctc_mixup_consistent_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum() * coef
+                ctc_mixup_consistent_loss += (loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
             logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)

         loss = \
@@ -588,11 +593,12 @@ class CtcCriterion(FairseqCriterion):
             self.ctc_entropy * ctc_entropy + \
             self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss

-        logging_output["all_ctc_loss"] = utils.item(loss.data)
+        if loss != 0:
+            logging_output["all_ctc_loss"] = utils.item(loss.data)
         if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
-            # logger.warning("Illegal loss %f!" % loss)
-            if self.ctc_weight != 0:
+            logger.warning("Illegal loss %f!" % loss)
+            if ctc_loss != 0 and (torch.isnan(ctc_loss) or torch.isinf(ctc_loss)):
                 logger.warning("CTC loss %f!" % ctc_loss)
             if self.interleaved_ctc_weight != 0:
                 logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
......
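Switching `CTCLoss` from `reduction="sum"` to `reduction="none"` is what makes the per-sample coefficients usable: the loss now comes back as one value per batch element, so each element can be weighted by its own coefficient before summing, exactly as in the hunk at line 293 above. A simplified standalone sketch; the blank index, tensor layout, and function name are assumptions for illustration, not values from the patch:

```python
import torch

# Illustrative instantiation: blank index 0 and (T, B, V) log-probabilities are assumed.
ctc = torch.nn.CTCLoss(blank=0, reduction="none", zero_infinity=True)


def weighted_ctc(lprobs, input_lengths, targets, target_lengths, loss_coef):
    """Sum the CTC loss over both mixup target views, weighting each sample.

    targets / target_lengths / loss_coef are parallel lists (one entry per view);
    with reduction="none" the criterion returns a (B,) loss vector that can be
    multiplied element-wise by that view's per-sample coefficient.
    """
    loss = 0
    for item_targets, item_target_lengths, item_coef in zip(targets, target_lengths, loss_coef):
        item_loss = ctc(lprobs, item_targets, input_lengths, item_target_lengths)  # shape (B,)
        loss += (item_loss * item_coef).sum()
    return loss
```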
@@ -7,6 +7,8 @@ import math
 from dataclasses import dataclass, field

 import torch
+import torch.nn.functional as F
+
 from fairseq import metrics, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.dataclass import FairseqDataclass
@@ -19,6 +21,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
     )
+    mixup_consistent_weight: float = field(
+        default=0.0,
+        metadata={"help": "the weight for consistency regularization of mixup"},
+    )
     report_accuracy: bool = field(
         default=False,
         metadata={"help": "report accuracy metric"},
@@ -61,12 +67,14 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         label_smoothing,
         ignore_prefix_size=0,
         report_accuracy=False,
+        mixup_consistent_weight=0.0,
     ):
         super().__init__(task)
         self.sentence_avg = sentence_avg
         self.eps = float(label_smoothing)
         self.ignore_prefix_size = ignore_prefix_size
         self.report_accuracy = report_accuracy
+        self.mixup_consistent_weight = mixup_consistent_weight

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -77,7 +85,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample["net_input"])
-        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+
         sample_size = (
             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )
@@ -88,6 +97,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             "nsentences": sample["target"].size(0),
             "sample_size": sample_size,
         }
+        if len(other_loss) != 0:
+            for key, value in other_loss.items():
+                loss += value
+                logging_output[key] = utils.item(value.data)
         if self.report_accuracy:
             n_correct, total = self.compute_accuracy(model, net_output, sample)
             logging_output["n_correct"] = utils.item(n_correct.data)
@@ -104,40 +117,67 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         else:
             lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
             target = target[self.ignore_prefix_size:, :].contiguous()
-        if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
-            mixup = net_output[1]["mixup"]
-            idx1 = mixup["index1"]
-            idx2 = mixup["index2"]
-            target1 = target[idx1].view(-1)
-            target2 = target[idx2].view(-1)
-            target = [target1, target2]
-        else:
-            target = target.view(-1)
         return lprobs.view(-1, lprobs.size(-1)), target

     def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        if type(target) == list:
-            assert "mixup" in net_output[1] and net_output[1]["mixup"] is not None
-            coef = net_output[1]["mixup"]["coef"]
-            loss1, nll_loss1 = label_smoothed_nll_loss(
-                lprobs,
-                target[0],
-                self.eps,
-                ignore_index=self.padding_idx,
-                reduce=reduce,
-            )
-            loss2, nll_loss2 = label_smoothed_nll_loss(
-                lprobs,
-                target[1],
-                self.eps,
-                ignore_index=self.padding_idx,
-                reduce=reduce,
-            )
-            loss = coef * loss1 + (1 - coef) * loss2
-            nll_loss = coef * nll_loss1 + (1 - coef) * nll_loss2
+        loss = nll_loss = 0
+        other_loss = dict()
+
+        if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
+            mixup = net_output[1]["mixup"]
+            mixup_idx1 = mixup["index1"]
+            mixup_idx2 = mixup["index2"]
+            batch_size = len(mixup_idx1)
+
+            target = model.get_targets(sample, net_output)
+            target1 = target[mixup_idx1].view(-1)
+            target2 = target[mixup_idx2].view(-1)
+            targets = [target1, target2]
+
+            mixup_coef = net_output[1]["mixup"]["coef"]
+            loss_coef = [mixup_coef, 1 - mixup_coef]
+
+            for item_target, item_coef in zip(targets, loss_coef):
+                item_loss, item_nll_loss = label_smoothed_nll_loss(
+                    lprobs,
+                    item_target,
+                    self.eps,
+                    ignore_index=self.padding_idx,
+                    reduce=False,
+                )
+                loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
+                nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
+
+            mixup_consistent_loss = 0
+            if self.mixup_consistent_weight > 0:
+                lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))
+                mixup_pos = mixup_idx1 != mixup_idx2
+                mixup_real_coef = mixup_coef[mixup_pos]
+                loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
+                mixup_real_lprobs = lprobs[mixup_pos, :, :]
+                mixup_real_idx1 = mixup_idx1[mixup_pos]
+                mixup_real_idx2 = mixup_idx2[mixup_pos]
+
+                non_padding_mask = ~target.eq(self.padding_idx)
+                no_mixup_lprobs = lprobs[~mixup_pos, :, :]
+                mixup_target_lprobs = [no_mixup_lprobs[mixup_real_idx1, :, :], no_mixup_lprobs[mixup_real_idx2, :, :]]
+                mixup_target_pad_mask = [non_padding_mask[mixup_real_idx1], non_padding_mask[mixup_real_idx2]]
+
+                for tgt_lprobs, pad, coef in zip(mixup_target_lprobs, mixup_target_pad_mask, loss_coef):
+                    item_loss = F.kl_div(
+                        F.log_softmax(mixup_real_lprobs, dim=-1, dtype=torch.float32),
+                        F.log_softmax(tgt_lprobs.detach(), dim=-1, dtype=torch.float32),
+                        log_target=True,
+                        reduction="none",
+                    )
+                    mixup_consistent_loss += (item_loss.sum(-1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
+                other_loss["mixup_consistent_loss"] = mixup_consistent_loss * self.mixup_consistent_weight
         else:
+            target = target.view(-1)
             loss, nll_loss = label_smoothed_nll_loss(
                 lprobs,
                 target,
@@ -145,19 +185,24 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
                 ignore_index=self.padding_idx,
                 reduce=reduce,
             )
-        return loss, nll_loss
+        return loss, nll_loss, other_loss

     def compute_accuracy(self, model, net_output, sample):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        if type(target) == list:
-            n_correct = total = 0
-            for item in target:
-                mask = item.ne(self.padding_idx)
-                n_correct += torch.sum(
-                    lprobs.argmax(1).masked_select(mask).eq(item.masked_select(mask))
-                )
-                total += torch.sum(mask)
+
+        if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
+            mixup = net_output[1]["mixup"]
+            mixup_idx1 = mixup["index1"]
+            mixup_idx2 = mixup["index2"]
+            batch_size = len(mixup_idx1)
+
+            no_mixup_pos = mixup_idx1 == mixup_idx2
+            idx = mixup_idx1[no_mixup_pos]
+            lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
+            target = target[idx].view(-1)
         else:
-            mask = target.ne(self.padding_idx)
-            n_correct = torch.sum(
-                lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+            target = target.view(-1)
+
+        mask = target.ne(self.padding_idx)
+        n_correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
@@ -170,6 +215,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
         nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+        mixup_consistent_loss_sum = sum(log.get("mixup_consistent_loss", 0) for log in logging_outputs)
         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
         sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -182,6 +228,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         metrics.log_derived(
             "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
         )
+        if mixup_consistent_loss_sum != 0:
+            metrics.log_scalar(
+                "mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3
+            )

         total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
         if total > 0:
......
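The new `mixup_consistent_weight` term regularizes the decoder the same way the CTC branch does: the prediction on a mixed sample is pulled toward the detached predictions on its two source samples via KL divergence, each weighted by that pair's coefficient. A simplified sketch of that computation, assuming (B, T, V) log-probabilities and a per-sample coefficient vector (all names are illustrative, not the criterion class itself):

```python
import torch
import torch.nn.functional as F


def mixup_consistency_loss(mixed_lprobs, src1_lprobs, src2_lprobs, coef, pad_mask):
    """KL(mixed || source) against both source samples, weighted per sample.

    mixed_lprobs / src*_lprobs: (B, T, V) log-probabilities, coef: (B,) mixup
    coefficients, pad_mask: (B, T) True at real (non-padding) positions.
    The source predictions are detached so only the mixed branch receives gradients.
    """
    loss = 0
    for tgt_lprobs, item_coef in zip([src1_lprobs, src2_lprobs], [coef, 1 - coef]):
        kl = F.kl_div(mixed_lprobs, tgt_lprobs.detach(), log_target=True, reduction="none")  # (B, T, V)
        kl = kl.sum(-1).masked_fill(~pad_mask, 0.0)  # drop padded positions
        loss += (kl.sum(-1) * item_coef).sum()
    return loss
```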
@@ -24,8 +24,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
                 sentence_avg,
                 cfg: CtcCriterionConfig,
                 ctc_weight=0.0,
-                save_dir=None):
-        super().__init__(task, sentence_avg, label_smoothing)
+                save_dir=None,
+                mixup_consistent_weight=0.0):
+        super().__init__(task, sentence_avg, label_smoothing,
+                         report_accuracy=True,
+                         mixup_consistent_weight=mixup_consistent_weight)
         self.report_accuracy = True
         self.ctc_weight = ctc_weight
@@ -37,13 +40,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         """Add criterion-specific arguments to the parser."""
         LabelSmoothedCrossEntropyCriterion.add_args(parser)
         CtcCriterion.add_args(parser)
-        # parser.add_argument(
-        #     "--ctc-weight",
-        #     default=0.0,
-        #     type=float,
-        #     metavar="D",
-        #     help="weight of CTC loss",
-        # )

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -71,21 +67,18 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         else:
             encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)

-        use_mixup = False
-        if "mixup" in encoder_out and encoder_out["mixup"] is not None:
-            use_mixup = True
-
         net_output = model.decoder(
             prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
         )
-        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

         sample_size = (
             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )
         n_tokens = sample["ntokens"]
         n_sentences = sample["target"].size(0)
-        if use_mixup:
+        if "mixup" in encoder_out and encoder_out["mixup"] is not None:
             sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
             n_tokens //= encoder_out["mixup"]["ratio"]
             n_sentences //= net_output[0].size(0)
@@ -97,6 +90,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             "nsentences": n_sentences,
             "sample_size": sample_size,
         }
+        if len(other_loss) != 0:
+            for key, value in other_loss.items():
+                loss += value
+                logging_output[key] = utils.item(value.data)
         if self.report_accuracy:
             n_correct, total = self.compute_accuracy(model, net_output, sample)
@@ -120,6 +117,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
         loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
         trans_loss_sum = utils.item(
             sum(log.get("trans_loss", 0) for log in logging_outputs)
@@ -127,9 +125,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         nll_loss_sum = utils.item(
             sum(log.get("nll_loss", 0) for log in logging_outputs)
         )
+        mixup_consistent_loss_sum = utils.item(
+            sum(log.get("mixup_consistent_loss", 0) for log in logging_outputs)
+        )
         enc_loss_sum = utils.item(
             sum(log.get("encoder_loss", 0) for log in logging_outputs)
         )
         ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -145,6 +147,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         metrics.log_scalar(
             "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
         )
+        if mixup_consistent_loss_sum != 0:
+            metrics.log_scalar(
+                "mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3
+            )
         metrics.log_derived(
             "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
         )
......
@@ -819,25 +819,39 @@ class S2TTransformerEncoder(FairseqEncoder):
         batch = x.size(1)
         indices = np.random.permutation(batch)
         org_indices = np.arange(batch)
-        if self.mixup_ratio == 1:
-            if len(indices) % 2 != 0:
-                indices = np.append(indices, (indices[-1]))
-            idx1 = indices[0::2]
-            idx2 = indices[1::2]

-            if self.mixup_keep_org:
-                idx1 = np.append(org_indices, idx1)
-                idx2 = np.append(org_indices, idx2)

-        else:
-            mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
-            mix_indices = indices[: mix_size]
-            if self.mixup_keep_org:
-                idx1 = np.append(org_indices, mix_indices[0::2])
-                idx2 = np.append(org_indices, mix_indices[1::2])
-            else:
-                idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
-                idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
+        # if self.mixup_ratio == 1:
+        #     if len(indices) % 2 != 0:
+        #         indices = np.append(indices, (indices[-1]))
+        #     idx1 = indices[0::2]
+        #     idx2 = indices[1::2]
+        #
+        #     if self.mixup_keep_org:
+        #         idx1 = np.append(org_indices, idx1)
+        #         idx2 = np.append(org_indices, idx2)
+        #
+        # else:
+        #     mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
+        #     mix_indices = indices[: mix_size]
+        #     if self.mixup_keep_org:
+        #         idx1 = np.append(org_indices, mix_indices[0::2])
+        #         idx2 = np.append(org_indices, mix_indices[1::2])
+        #     else:
+        #         idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
+        #         idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
+
+        mixup_size = int(batch * self.mixup_ratio)
+        mixup_index1 = np.random.permutation(mixup_size)
+        mixup_index2 = np.random.permutation(mixup_size)
+
+        if self.mixup_keep_org:
+            idx1 = np.append(org_indices, mixup_index1)
+            idx2 = np.append(org_indices, mixup_index2)
+        else:
+            keep_indices = []
+            for i in org_indices:
+                if i not in mixup_index1 and i not in mixup_index2:
+                    keep_indices.append(i)
+            idx1 = np.append(keep_indices, mixup_index1)
+            idx2 = np.append(keep_indices, mixup_index2)

         idx1 = torch.from_numpy(idx1).to(x.device)
         idx2 = torch.from_numpy(idx2).to(x.device)
@@ -845,8 +859,9 @@ class S2TTransformerEncoder(FairseqEncoder):
         x1 = x[:, idx1]
         x2 = x[:, idx2]

-        coef = self.beta.sample().to(x.device).type_as(x)
-        x = (coef * x1 + (1 - coef) * x2)
+        coef = self.beta.sample([len(idx1)]).to(x.device).type_as(x).view(-1)
+        mixup_coef = coef.view(1, -1, 1)
+        x = (mixup_coef * x1 + (1 - mixup_coef) * x2)

         pad1 = encoder_padding_mask[idx1]
         pad2 = encoder_padding_mask[idx2]
......
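For readability, the new index-selection logic from the encoder hunk above can be written as a standalone helper; the behaviour mirrors the patch (an arbitrary mixup_ratio decides how many samples are mixed, and with mixup_keep_org disabled only samples that take part in no mixed pair are kept unmixed). The function and variable names are illustrative:

```python
import numpy as np


def build_mixup_indices(batch, mixup_ratio, keep_org):
    """Choose which samples are mixed, following the scheme in the patch above.

    mixup_ratio may be any value in (0, 1]; the first int(batch * mixup_ratio)
    positions are paired through two independent permutations.
    """
    org_indices = np.arange(batch)
    mixup_size = int(batch * mixup_ratio)
    mixup_index1 = np.random.permutation(mixup_size)
    mixup_index2 = np.random.permutation(mixup_size)

    if keep_org:
        # keep every original (unmixed) sample and append the mixed pairs
        idx1 = np.append(org_indices, mixup_index1)
        idx2 = np.append(org_indices, mixup_index2)
    else:
        # keep only the samples that do not take part in any mixed pair
        keep = np.array([i for i in org_indices
                         if i not in mixup_index1 and i not in mixup_index2],
                        dtype=org_indices.dtype)
        idx1 = np.append(keep, mixup_index1)
        idx2 = np.append(keep, mixup_index2)
    return idx1, idx2
```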
@@ -1058,7 +1058,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             x1 = x[:, idx1]
             x2 = x[:, idx2]
-            x = coef * x1 + (1 - coef) * x2
+            mixup_coef = coef.view(1, -1, 1)
+            x = mixup_coef * x1 + (1 - mixup_coef) * x2

             if self_attn_padding_mask is not None:
                 pad1 = self_attn_padding_mask[idx1]
......
@@ -116,6 +116,16 @@ def parse_args_and_arch(
         is_config_file=True,
         help="Configuration YAML filename (for training)",
     )
+    parser.add_argument(
+        "--train-config5",
+        is_config_file=True,
+        help="Configuration YAML filename (for training)",
+    )
+    parser.add_argument(
+        "--train-config6",
+        is_config_file=True,
+        help="Configuration YAML filename (for training)",
+    )

     if suppress_defaults:
         # Parse args without any default values. This requires us to parse
......