Commit ed623111 by xuchen

fix the bug of the ctc

parent 0187e5d6
......@@ -188,7 +188,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
return loss
def reduce_metrics(self, logging_outputs) -> None:
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item(
......@@ -197,10 +198,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
if self.ctc_weight > 0:
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
......@@ -215,13 +215,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
if self.ctc_weight > 0:
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论