Commit 2e6223d1 by xuchen

info

parent 12201609
......@@ -27,7 +27,8 @@ logger = logging.getLogger(__name__)
try:
from fairseq.torch_imputer import best_alignment, imputer_loss
except:
logger.error("Imputer is not available.")
# logger.error("Imputer is not available.")
pass
@dataclass
......@@ -379,15 +380,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx,
)
# if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
# ctc_logit = encoder_out["ctc_logit"][0]
# ctc_alignment_oracle["ctc"] = get_ctc_align(ctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_ctc_logits" in encoder_out and len(encoder_out["inter_ctc_logits"]) != 0:
# ctc_alignment_oracle["inter_ctc"] = []
# for ctc_logit in encoder_out["inter_ctc_logits"]:
# ctc_alignment_oracle["inter_ctc"].append(get_ctc_align(ctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
xctc_logit = None
if "xctc_logit" in encoder_out and len(encoder_out["xctc_logit"]) != 0:
xctc_logit = encoder_out["xctc_logit"][0]
......@@ -406,7 +398,6 @@ class CtcCriterion(FairseqCriterion):
tokens = self.get_ctc_target_text(sample)
target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
# target_pad_mask = (tokens != self.pad_idx)
target_lengths = target_pad_mask.sum(-1)
ctc_alignment_oracle["xctc"] = get_ctc_align(
......@@ -418,15 +409,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx,
)
# if "xctc_logit" in encoder_out and len(encoder_out["xctc_logit"]) != 0:
# xctc_logit = encoder_out["xctc_logit"][0]
# ctc_alignment_oracle["xctc"] = get_ctc_align(xctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_xctc_logits" in encoder_out and len(
# encoder_out["inter_xctc_logits"]) != 0:
# ctc_alignment_oracle["inter_xctc"] = []
# for xctc_logit in encoder_out["inter_xctc_logits"]:
# ctc_alignment_oracle["inter_xctc"].append(get_ctc_align(xctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
axctc_logit = None
if "axctc_logit" in encoder_out and len(encoder_out["axctc_logit"]) != 0:
axctc_logit = encoder_out["axctc_logit"][0]
......@@ -448,15 +430,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx,
)
# if "axctc_logit" in encoder_out and len(encoder_out["axctc_logit"]) != 0:
# axctc_logit = encoder_out["axctc_logit"][0]
# ctc_alignment_oracle["axctc"] = get_ctc_align(axctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_axctc_logits" in encoder_out and len(
# encoder_out["inter_axctc_logits"]) != 0:
# ctc_alignment_oracle["inter_axctc"] = []
# for axctc_logit in encoder_out["inter_axctc_logits"]:
# ctc_alignment_oracle["inter_axctc"].append(get_ctc_align(axctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
return ctc_alignment_oracle
def get_ctc_loss(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论