Commit ed623111 by xuchen

fix the bug of the ctc

parent 0187e5d6
......@@ -188,7 +188,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
return loss
def reduce_metrics(self, logging_outputs) -> None:
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item(
......@@ -197,7 +198,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
if self.ctc_weight > 0:
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
......@@ -215,7 +215,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
if self.ctc_weight > 0:
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论