Commit ed623111 by xuchen

fix the bug of the ctc

parent 0187e5d6
...@@ -188,7 +188,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -188,7 +188,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
return loss return loss
def reduce_metrics(self, logging_outputs) -> None: @staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item( trans_loss_sum = utils.item(
...@@ -197,10 +198,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -197,10 +198,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item( nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs) sum(log.get("nll_loss", 0) for log in logging_outputs)
) )
if self.ctc_weight > 0: ctc_loss_sum = utils.item(
ctc_loss_sum = utils.item( sum(log.get("ctc_loss", 0) for log in logging_outputs)
sum(log.get("ctc_loss", 0) for log in logging_outputs) )
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
...@@ -215,13 +215,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -215,13 +215,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar( metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
) )
if self.ctc_weight > 0: metrics.log_scalar(
metrics.log_scalar( "ctc_loss",
"ctc_loss", ctc_loss_sum / sample_size / math.log(2),
ctc_loss_sum / sample_size / math.log(2), sample_size,
sample_size, round=3,
round=3, )
)
metrics.log_derived( metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论