Commit cabfc4ea by xuchen

Daily revision

parent 0a70c5c5
@@ -13,4 +13,4 @@ ctc_infer_sort=${infer_dir}/${tag}_ctc_infer_sort
 cut -f1 ${s2s_infer_file} > ${idx}
 paste ${idx} ${org_ctc_infer_file} > ${ctc_infer}
 sort -n -t $'\t' ${ctc_infer} | cut -f2 > ${ctc_infer_sort}
-python3 ./cal_wer_lcrm.py ${ref} ${ctc_infer_sort}
\ No newline at end of file
+python3 ./cal_wer.py ${ref} ${ctc_infer_sort}
\ No newline at end of file
@@ -300,9 +300,9 @@ class CtcCriterion(FairseqCriterion):
         return loss, lprobs

     @staticmethod
-    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask):
-        ctc_self_distill_loss = 0
-        ctc_self_distill_num = 0
+    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask, temperature=1.0):
+        ctc_self_distill_losses = []
         for i in range(distill_num):
             logit = student_logits[i]
             if type(logit) == list:
@@ -315,15 +315,15 @@ class CtcCriterion(FairseqCriterion):
                 continue

             loss = F.kl_div(
-                F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
-                F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
-                # F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
+                F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
+                # F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
+                F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
                 log_target=True,
                 reduction="none",
             )
-            ctc_self_distill_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
-            ctc_self_distill_num += 1
-        return ctc_self_distill_num, ctc_self_distill_loss
+            loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
+            ctc_self_distill_losses.append(loss)
+        return ctc_self_distill_losses

     def get_target_text(self, sample):
         if self.aligned_target_ctc and "aligned_target" in sample:
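For readability, the refactored helper is assembled below from the two hunks above as a self-contained sketch: it keeps one KL-divergence term per intermediate CTC layer instead of a running sum, detaches the teacher distribution, and scales both sides by a distillation temperature. The tensor shapes (T x B x V logits, B x T non_padding_mask) and the size check standing in for the elided list handling are assumptions for illustration, not part of the commit.

import torch
import torch.nn.functional as F

def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits,
                              non_padding_mask, temperature=1.0):
    # Returns one loss per intermediate (student) CTC layer; the caller decides
    # how to average and weight them.
    ctc_self_distill_losses = []
    for i in range(distill_num):
        student_logit = student_logits[i]
        # Assumption: skip layers whose logits do not match the teacher's shape
        # (stands in for the list handling elided between the two hunks).
        if student_logit.size() != teacher_logit.size():
            continue

        # Temperature-scaled KL divergence between the student distribution and
        # the detached top-layer (teacher) distribution, computed in float32.
        loss = F.kl_div(
            F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
            F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
            log_target=True,
            reduction="none",
        )
        # Sum over the vocabulary, move to B x T, zero padded positions, then sum.
        loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
        ctc_self_distill_losses.append(loss)
    return ctc_self_distill_losses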
@@ -507,10 +507,17 @@ class CtcCriterion(FairseqCriterion):
             ctc_self_distill_num = interleaved_ctc_num - 1
             if ctc_self_distill_num != 0:
-                ctc_self_distill_num, source_ctc_self_distill_loss = \
+                source_ctc_self_distill_losses = \
                     self.get_ctc_self_distill_loss(
-                        ctc_self_distill_num, teacher_logit, student_logits, non_padding)
-                source_ctc_self_distill_loss /= ctc_self_distill_num
+                        ctc_self_distill_num,
+                        teacher_logit,
+                        student_logits,
+                        non_padding,
+                        self.ctc_self_distill_temperature
+                    )
+                ctc_self_distill_num = len(source_ctc_self_distill_losses)
+                source_ctc_self_distill_loss = sum(source_ctc_self_distill_losses) / ctc_self_distill_num

                 logging_output["ctc_self_distill_loss"] = utils.item(source_ctc_self_distill_loss.data)
                 ctc_self_distill_loss += source_ctc_self_distill_loss * self.ctc_self_distill_weight
@@ -529,11 +536,18 @@ class CtcCriterion(FairseqCriterion):
             ctc_self_distill_num = target_interleaved_ctc_num - 1
             if ctc_self_distill_num != 0:
-                ctc_self_distill_num, target_ctc_self_distill_loss = \
+                target_ctc_self_distill_losses = \
                     self.get_ctc_self_distill_loss(
-                        ctc_self_distill_num, teacher_logit, student_logits, non_padding)
+                        ctc_self_distill_num,
+                        teacher_logit,
+                        student_logits,
+                        non_padding,
+                        self.ctc_self_distill_temperature
+                    )
+                ctc_self_distill_num = len(target_ctc_self_distill_losses)
-                target_ctc_self_distill_loss /= ctc_self_distill_num
+                target_ctc_self_distill_loss = sum(target_ctc_self_distill_losses) / ctc_self_distill_num

                 logging_output["target_ctc_self_distill_loss"] = utils.item(target_ctc_self_distill_loss.data)
                 ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
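A hypothetical call of the helper, mirroring the caller changes above: the divisor now comes from the length of the returned list (layers whose logits were skipped no longer count), and the averaged loss is then scaled by the configured self-distillation weight. The shapes, layer count, temperature, and weight value below are made up for illustration.

import torch

# Hypothetical shapes: 50 frames, batch of 4, vocabulary of 100; two student layers.
T, B, V = 50, 4, 100
teacher_logit = torch.randn(T, B, V)
student_logits = [torch.randn(T, B, V) for _ in range(2)]
non_padding = torch.ones(B, T, dtype=torch.bool)

losses = get_ctc_self_distill_loss(
    2, teacher_logit, student_logits, non_padding, temperature=2.0)
ctc_self_distill_num = len(losses)
# Average over the contributing layers, then apply the configured weight
# (ctc_self_distill_weight / target_ctc_self_distill_weight in the code above).
source_ctc_self_distill_loss = sum(losses) / ctc_self_distill_num
ctc_self_distill_loss = source_ctc_self_distill_loss * 1.0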
@@ -605,6 +605,20 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         }
         return x, encoder_padding_mask, input_lengths, mixup

+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
+        if hasattr(self, "ctc"):
+            assert src_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, src_dict,
+                               path=path + ".ctc" if path is not None else None)
+
+    def ctc_valid(self, lprobs, targets, input_lengths,
+                  dictionary, lang="source"):
+        if hasattr(self, "ctc"):
+            return self.ctc.valid(lprobs, targets, input_lengths,
+                                  dictionary)
+
+        logger.error("No ctc module in textual encoder")
+
     def forward(self, src_tokens, src_lengths):
         batch = src_tokens.size(0)
@@ -748,7 +762,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)
+            ctc_logit = self.ctc(x, encoder_padding_mask, is_top=True)

         return {
             "encoder_out": [x],  # T x B x C
@@ -314,6 +314,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
+    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
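The getattr defaults above normally pair with command-line options registered in the model's add_args; a hypothetical registration is sketched below. The option names follow fairseq's dash convention, but the help strings and exact wiring are assumptions, not part of the commit.

import argparse

# Hypothetical registration that would pair with the defaults above; in the
# repository this would live in the model's add_args(parser) staticmethod.
parser = argparse.ArgumentParser()
parser.add_argument("--encoder-no-scale-embedding", action="store_true",
                    help="do not scale encoder embeddings by sqrt(embed_dim)")
parser.add_argument("--encoder-embed-linear", action="store_true",
                    help="apply a linear transform after the encoder embedding")
parser.add_argument("--encoder-embed-norm", action="store_true",
                    help="apply layer normalization to the encoder embedding")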