# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


try:
    from fairseq.model_parallel.megatron.mpu.cross_entropy import (
        vocab_parallel_cross_entropy,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        target = sample["target"]

        # Compute per-token cross entropy over the vocab-parallel logit
        # partitions, then zero out padding positions before summing.
        loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        loss = (loss * (target != self.padding_idx)).sum()
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Divide by log(2) to report losses in base 2 (bits per token).
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
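

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). This criterion is normally
# selected through fairseq's CLI via the name passed to @register_criterion
# above. The dataset path, architecture, and parallel-size flags below are
# assumptions shown only for illustration:
#
#   fairseq-train data-bin/wikitext-103 \
#       --task language_modeling \
#       --arch transformer_lm_megatron \
#       --criterion vocab_parallel_cross_entropy \
#       --model-parallel-size 2
#
# With model parallelism enabled, each worker holds only a slice of the
# output vocabulary, which is why the loss must go through
# vocab_parallel_cross_entropy rather than a plain cross_entropy.
# ---------------------------------------------------------------------------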