Commit 2e6223d1 by xuchen

info

parent 12201609
...@@ -27,7 +27,8 @@ logger = logging.getLogger(__name__) ...@@ -27,7 +27,8 @@ logger = logging.getLogger(__name__)
try: try:
from fairseq.torch_imputer import best_alignment, imputer_loss from fairseq.torch_imputer import best_alignment, imputer_loss
except: except:
logger.error("Imputer is not available.") # logger.error("Imputer is not available.")
pass
@dataclass @dataclass
...@@ -379,15 +380,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -379,15 +380,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx, self.blank_idx,
) )
# if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
# ctc_logit = encoder_out["ctc_logit"][0]
# ctc_alignment_oracle["ctc"] = get_ctc_align(ctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_ctc_logits" in encoder_out and len(encoder_out["inter_ctc_logits"]) != 0:
# ctc_alignment_oracle["inter_ctc"] = []
# for ctc_logit in encoder_out["inter_ctc_logits"]:
# ctc_alignment_oracle["inter_ctc"].append(get_ctc_align(ctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
xctc_logit = None xctc_logit = None
if "xctc_logit" in encoder_out and len(encoder_out["xctc_logit"]) != 0: if "xctc_logit" in encoder_out and len(encoder_out["xctc_logit"]) != 0:
xctc_logit = encoder_out["xctc_logit"][0] xctc_logit = encoder_out["xctc_logit"][0]
...@@ -406,7 +398,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -406,7 +398,6 @@ class CtcCriterion(FairseqCriterion):
tokens = self.get_ctc_target_text(sample) tokens = self.get_ctc_target_text(sample)
target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx) target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
# target_pad_mask = (tokens != self.pad_idx)
target_lengths = target_pad_mask.sum(-1) target_lengths = target_pad_mask.sum(-1)
ctc_alignment_oracle["xctc"] = get_ctc_align( ctc_alignment_oracle["xctc"] = get_ctc_align(
...@@ -418,15 +409,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -418,15 +409,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx, self.blank_idx,
) )
# if "xctc_logit" in encoder_out and len(encoder_out["xctc_logit"]) != 0:
# xctc_logit = encoder_out["xctc_logit"][0]
# ctc_alignment_oracle["xctc"] = get_ctc_align(xctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_xctc_logits" in encoder_out and len(
# encoder_out["inter_xctc_logits"]) != 0:
# ctc_alignment_oracle["inter_xctc"] = []
# for xctc_logit in encoder_out["inter_xctc_logits"]:
# ctc_alignment_oracle["inter_xctc"].append(get_ctc_align(xctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
axctc_logit = None axctc_logit = None
if "axctc_logit" in encoder_out and len(encoder_out["axctc_logit"]) != 0: if "axctc_logit" in encoder_out and len(encoder_out["axctc_logit"]) != 0:
axctc_logit = encoder_out["axctc_logit"][0] axctc_logit = encoder_out["axctc_logit"][0]
...@@ -448,15 +430,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -448,15 +430,6 @@ class CtcCriterion(FairseqCriterion):
self.blank_idx, self.blank_idx,
) )
# if "axctc_logit" in encoder_out and len(encoder_out["axctc_logit"]) != 0:
# axctc_logit = encoder_out["axctc_logit"][0]
# ctc_alignment_oracle["axctc"] = get_ctc_align(axctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx)
# if "inter_axctc_logits" in encoder_out and len(
# encoder_out["inter_axctc_logits"]) != 0:
# ctc_alignment_oracle["inter_axctc"] = []
# for axctc_logit in encoder_out["inter_axctc_logits"]:
# ctc_alignment_oracle["inter_axctc"].append(get_ctc_align(axctc_logit, tokens, input_lengths, target_lengths, self.pad_idx, self.blank_idx))
return ctc_alignment_oracle return ctc_alignment_oracle
def get_ctc_loss( def get_ctc_loss(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论