Commit cabfc4ea by xuchen

Daily revision

parent 0a70c5c5
@@ -13,4 +13,4 @@ ctc_infer_sort=${infer_dir}/${tag}_ctc_infer_sort
 cut -f1 ${s2s_infer_file} > ${idx}
 paste ${idx} ${org_ctc_infer_file} > ${ctc_infer}
 sort -n -t $'\t' ${ctc_infer} | cut -f2 > ${ctc_infer_sort}
-python3 ./cal_wer_lcrm.py ${ref} ${ctc_infer_sort}
\ No newline at end of file
+python3 ./cal_wer.py ${ref} ${ctc_infer_sort}
\ No newline at end of file
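This script stitches the CTC hypotheses back into corpus order (the sequence-to-sequence output carries the original line index in field 1) and scores them against the reference. The scorer changes from ./cal_wer_lcrm.py to ./cal_wer.py; neither script is shown in this commit, but the old name suggests text normalization before scoring ("lcrm" commonly stands for lowercase, remove punctuation), while the new one presumably scores verbatim. For orientation only, a minimal sketch of what a cal_wer.py-style scorer computes, assuming whitespace-tokenized, line-aligned files:

    # Hedged sketch: corpus-level WER as Levenshtein distance over whitespace
    # tokens. cal_wer.py itself is not in this diff; this only illustrates the metric.
    import sys

    def edit_distance(ref_tokens, hyp_tokens):
        # Standard dynamic program; substitutions, insertions, deletions all cost 1.
        d = [[0] * (len(hyp_tokens) + 1) for _ in range(len(ref_tokens) + 1)]
        for i in range(len(ref_tokens) + 1):
            d[i][0] = i
        for j in range(len(hyp_tokens) + 1):
            d[0][j] = j
        for i in range(1, len(ref_tokens) + 1):
            for j in range(1, len(hyp_tokens) + 1):
                cost = 0 if ref_tokens[i - 1] == hyp_tokens[j - 1] else 1
                d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
        return d[len(ref_tokens)][len(hyp_tokens)]

    if __name__ == "__main__":
        errors, words = 0, 0
        with open(sys.argv[1]) as f_ref, open(sys.argv[2]) as f_hyp:
            for ref, hyp in zip(f_ref, f_hyp):
                errors += edit_distance(ref.split(), hyp.split())
                words += len(ref.split())
        print(f"WER: {100.0 * errors / max(words, 1):.2f}")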
@@ -300,9 +300,9 @@ class CtcCriterion(FairseqCriterion):
         return loss, lprobs

     @staticmethod
-    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask):
-        ctc_self_distill_loss = 0
-        ctc_self_distill_num = 0
+    def get_ctc_self_distill_loss(distill_num, teacher_logit, student_logits, non_padding_mask, temperature=1.0):
+        ctc_self_distill_losses = []
         for i in range(distill_num):
             logit = student_logits[i]
             if type(logit) == list:
@@ -315,15 +315,15 @@ class CtcCriterion(FairseqCriterion):
                 continue

             loss = F.kl_div(
-                F.log_softmax(student_logit, dim=-1, dtype=torch.float32),
-                F.log_softmax(teacher_logit, dim=-1, dtype=torch.float32),
-                # F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32),
+                F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
+                # F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
+                F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
                 log_target=True,
                 reduction="none",
             )
-            ctc_self_distill_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
-            ctc_self_distill_num += 1
-        return ctc_self_distill_num, ctc_self_distill_loss
+            loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
+            ctc_self_distill_losses.append(loss)
+        return ctc_self_distill_losses

     def get_target_text(self, sample):
         if self.aligned_target_ctc and "aligned_target" in sample:
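get_ctc_self_distill_loss now returns the list of per-layer distillation losses rather than a (count, sum) pair, leaving aggregation to the caller, and gains a temperature argument: student and teacher logits are both divided by temperature before the log-softmax, which softens the two distributions being matched, and the teacher is detached so gradients flow only through the student layers. A self-contained sketch of the same computation (the T x B x V logit and B x T mask shapes are assumptions inferred from the transpose(0, 1) call):

    # Minimal standalone sketch of the temperature-scaled CTC self-distillation
    # loss; mirrors the commit's logic, tensor shapes are assumptions.
    import torch
    import torch.nn.functional as F

    def ctc_self_distill_losses(teacher_logit, student_logits, non_padding_mask, temperature=1.0):
        # teacher_logit: T x B x V; student_logits: list of T x B x V; non_padding_mask: B x T
        losses = []
        for student_logit in student_logits:
            loss = F.kl_div(
                F.log_softmax(student_logit / temperature, dim=-1, dtype=torch.float32),
                F.log_softmax(teacher_logit.detach() / temperature, dim=-1, dtype=torch.float32),
                log_target=True,
                reduction="none",
            )
            # Sum over the vocabulary, zero padded positions, then sum over time and batch.
            loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
            losses.append(loss)
        return losses

    T, B, V = 7, 2, 11
    teacher = torch.randn(T, B, V)
    students = [torch.randn(T, B, V) for _ in range(3)]
    mask = torch.ones(B, T, dtype=torch.bool)
    losses = ctc_self_distill_losses(teacher, students, mask, temperature=2.0)
    print(sum(losses) / len(losses))  # callers average, as in the hunks below

One design note: classical distillation (Hinton et al.) rescales the loss by temperature squared so gradient magnitudes stay comparable across temperatures; this commit does not, so the distillation weight may need retuning when the temperature changes.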
@@ -507,10 +507,17 @@ class CtcCriterion(FairseqCriterion):
                 ctc_self_distill_num = interleaved_ctc_num - 1
                 if ctc_self_distill_num != 0:
-                    ctc_self_distill_num, source_ctc_self_distill_loss = \
+                    source_ctc_self_distill_losses = \
                         self.get_ctc_self_distill_loss(
-                            ctc_self_distill_num, teacher_logit, student_logits, non_padding)
-                    source_ctc_self_distill_loss /= ctc_self_distill_num
+                            ctc_self_distill_num,
+                            teacher_logit,
+                            student_logits,
+                            non_padding,
+                            self.ctc_self_distill_temperature
+                        )
+                    ctc_self_distill_num = len(source_ctc_self_distill_losses)
+                    source_ctc_self_distill_loss = sum(source_ctc_self_distill_losses) / ctc_self_distill_num

                     logging_output["ctc_self_distill_loss"] = utils.item(source_ctc_self_distill_loss.data)
                     ctc_self_distill_loss += source_ctc_self_distill_loss * self.ctc_self_distill_weight
@@ -529,11 +536,18 @@ class CtcCriterion(FairseqCriterion):
                 ctc_self_distill_num = target_interleaved_ctc_num - 1
                 if ctc_self_distill_num != 0:
-                    ctc_self_distill_num, target_ctc_self_distill_loss = \
+                    target_ctc_self_distill_losses = \
                         self.get_ctc_self_distill_loss(
-                            ctc_self_distill_num, teacher_logit, student_logits, non_padding)
-                    target_ctc_self_distill_loss /= ctc_self_distill_num
+                            ctc_self_distill_num,
+                            teacher_logit,
+                            student_logits,
+                            non_padding,
+                            self.ctc_self_distill_temperature
+                        )
+                    ctc_self_distill_num = len(target_ctc_self_distill_losses)
+                    target_ctc_self_distill_loss = sum(target_ctc_self_distill_losses) / ctc_self_distill_num

                     logging_output["target_ctc_self_distill_loss"] = utils.item(target_ctc_self_distill_loss.data)
                     ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
...
@@ -605,6 +605,20 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         }
         return x, encoder_padding_mask, input_lengths, mixup

+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
+        if hasattr(self, "ctc"):
+            assert src_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, src_dict,
+                               path=path + ".ctc" if path is not None else None)
+
+    def ctc_valid(self, lprobs, targets, input_lengths,
+                  dictionary, lang="source"):
+        if hasattr(self, "ctc"):
+            return self.ctc.valid(lprobs, targets, input_lengths,
+                                  dictionary)
+
+        logger.error("No ctc module in textual encoder")
+
     def forward(self, src_tokens, src_lengths):
         batch = src_tokens.size(0)
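The two methods added above expose the encoder's CTC head: set_ctc_infer wires up inference-time decoding/dumping (appending ".ctc" to the supplied output path), and ctc_valid scores CTC during validation. Both guard on hasattr(self, "ctc"); note that ctc_valid only logs an error and implicitly returns None when the module is absent, so callers must tolerate that. A hypothetical call site (the model/task/args names are stand-ins, not part of the commit):

    # Hypothetical wiring; `model`, `task`, and `args` are stand-ins for the
    # fairseq objects that would exist at inference setup time.
    if getattr(args, "ctc_infer", False):
        model.encoder.set_ctc_infer(
            args.ctc_infer,
            post_process="sentencepiece",      # assumption: any fairseq post-process symbol
            src_dict=task.source_dictionary,
            path=args.results_path,            # hypothetical; ".ctc" is appended internally
        )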
@@ -748,7 +762,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)
+            ctc_logit = self.ctc(x, encoder_padding_mask, is_top=True)

         return {
             "encoder_out": [x],  # T x B x C
...
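The CTC head is now invoked with the encoder padding mask and an is_top flag instead of the features alone. The CTC module itself is not part of this diff, so the following is only a speculative reading of the new signature: the mask lets the head ignore padded frames, and is_top plausibly distinguishes the final CTC layer from interleaved intermediate ones.

    # Speculative sketch of a CTC head compatible with the new call signature;
    # the real module is not shown in this commit.
    import torch
    import torch.nn as nn

    class CTC(nn.Module):
        def __init__(self, embed_dim, vocab_size, dropout=0.0):
            super().__init__()
            self.proj = nn.Linear(embed_dim, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, padding_mask=None, is_top=False):
            # x: T x B x C; padding_mask: B x T, True at padded positions (assumption)
            if padding_mask is not None:
                x = x.masked_fill(padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
            # is_top could gate behavior specific to the final CTC layer,
            # e.g. caching logits for self-distillation (assumption).
            return self.proj(self.dropout(x))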
@@ -314,6 +314,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

+    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
     args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
     args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
...
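The last hunk backfills defaults for the embedding-scaling options using fairseq's usual getattr idiom: base_architecture assigns every hyperparameter a default, so configs and checkpoints created before an option existed keep working. (As a side note, args.no_scale_embedding is now assigned twice in this function; the second assignment is redundant but harmless.) The idiom in isolation:

    # fairseq-style defaulting: fill in attributes the user did not set upstream.
    from argparse import Namespace

    def base_architecture(args):
        args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)

    args = Namespace()
    base_architecture(args)
    print(args.encoder_no_scale_embedding)  # -> False unless set on the command line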