Commit 2de89089 by xuchen

fix the bugs of sae for MT

parent 380d7794
@@ -41,7 +41,7 @@ interleaved-ctc-weight: 0.3
 interleaved-ctc-layers: 6,9
 interleaved-ctc-temperature: 1.0
 interleaved-ctc-drop-prob: 0
-interleaved_ctc_upsampling_ratio: 2
+interleaved_ctc_upsampling_ratio: 3
 sae-adapter: league
 sae-drop-prob: 0.0
......
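Note on the hunk above: raising interleaved_ctc_upsampling_ratio from 2 to 3 interacts with the rewritten upsampling() further down, which keeps only 2/3 of the duplicated positions; the effective sequence growth is therefore 3 * 2/3 = 2x (see the sketch after that hunk).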
@@ -3,7 +3,7 @@
 gpu_num=1
 data_dir=
-test_subset=(test)
+test_subset=(valid test)
 exp_name=

 if [ "$#" -eq 1 ]; then
@@ -14,7 +14,7 @@ sacrebleu=1
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=80000
+max_tokens=20000
 dec_model=checkpoint_best.pt

 cmd="./run.sh
......
 arch: s2t_transformer_s
 share-decoder-input-output-embed: True
+share-ctc-and-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
......
 ctc-weight: 0.3
+share-ctc-and-embed: True
 interleaved-ctc-weight: 0.2
 interleaved-ctc-layers: 6,9
 interleaved-ctc-temperature: 1.0
......
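Note: share-ctc-and-embed: True (added to both configs above) presumably ties the CTC output projection to the decoder embedding matrix; it pairs with the change below that passes decoder_embed_tokens into the acoustic encoders.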
@@ -11,6 +11,7 @@ from omegaconf import II
 from typing import Optional
 import numpy as np
 import logging
+import editdistance

 import torch
 import torch.nn.functional as F
@@ -65,6 +66,10 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight of the self distillation CTC loss"},
     )
+    ctc_self_distill_prob: float = field(
+        default=0.1,
+        metadata={"help": "probability to use distillation loss"},
+    )
     wer_kenlm_model: Optional[str] = field(
         default=None,
@@ -137,6 +142,7 @@ class CtcCriterion(FairseqCriterion):
         self.target_ctc_weight = cfg.target_ctc_weight
         self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
         self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
+        self.ctc_self_distill_prob = cfg.ctc_self_distill_prob
         self.ctc_entropy = cfg.ctc_entropy
         self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
         self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
@@ -333,7 +339,8 @@ class CtcCriterion(FairseqCriterion):
         # calculate the self distillation CTC loss
         ctc_self_distill_loss = 0
         ctc_self_distill_num = 0
-        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
+                torch.rand() < self.ctc_self_distill_prob:
             for i in range(interleaved_ctc_num):
                 out = net_output["interleaved_ctc_logits"][i]
                 if type(out) == list:
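Note on the new gating condition: torch.rand expects a size argument, so torch.rand() as written raises a TypeError on the PyTorch versions I am aware of. A minimal sketch of the intended per-batch Bernoulli gate (hypothetical fix, assuming the goal is to apply self-distillation on a random ctc_self_distill_prob fraction of batches):

```python
import torch

ctc_self_distill_prob = 0.1  # the new config default

# Hypothetical fix sketch: draw one uniform sample per batch and
# compare it against the configured probability.
if torch.rand(1).item() < ctc_self_distill_prob:
    pass  # compute the self-distillation loss for this batch
```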
@@ -347,7 +354,8 @@ class CtcCriterion(FairseqCriterion):
                 loss = F.kl_div(
                     F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
-                    F.softmax(ctc_logit, dim=-1, dtype=torch.float32),
+                    F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
+                    log_target=True,
                     reduction="none",
                 )
                 loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
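For reference, a self-contained sketch of the corrected KL term (illustrative shapes, assumed T x B x V): with log_target=True, F.kl_div takes both arguments as log-probabilities, avoiding the underflow of materializing softmax for the teacher, and detaching the final-layer CTC logits stops gradients from flowing into the teacher.

```python
import torch
import torch.nn.functional as F

T, B, V = 8, 2, 10  # illustrative sizes
inter_ctc_logit = torch.randn(T, B, V, requires_grad=True)  # student: interleaved CTC layer
ctc_logit = torch.randn(T, B, V)                            # teacher: final CTC layer

# KL(teacher || student), computed entirely in log space
loss = F.kl_div(
    F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
    F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
    log_target=True,
    reduction="none",
)
# sum over the vocabulary, then mask padded positions as the hunk above does
per_position = loss.sum(-1)  # T x B
```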
@@ -379,8 +387,6 @@ class CtcCriterion(FairseqCriterion):
                 logger.warning("Target CTC loss %f!" % target_ctc_loss)

         if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
-            import editdistance
-
             with torch.no_grad():
                 lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
                 target = tokens
......
@@ -399,9 +399,9 @@ class S2TSATEEncoder(FairseqEncoder):
         # acoustic encoder
         acoustic_encoder_type = args.acoustic_encoder
         if acoustic_encoder_type == "transformer":
-            self.acoustic_encoder = S2TTransformerEncoder(args, task)
+            self.acoustic_encoder = S2TTransformerEncoder(args, task, decoder_embed_tokens)
         elif acoustic_encoder_type == "pds":
-            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task)
+            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task, decoder_embed_tokens)
         else:
             logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
......
@@ -708,18 +708,29 @@ class TransformerCTCEncoder(FairseqEncoder):
             return_all_hiddens,
             token_embeddings)

-    def upsampling(self, x):
+    def upsampling(self, x, padding):
         ratio = self.interleaved_ctc_upsampling_ratio
         if ratio <= 1:
             return x

-        seq_len, bsz, dim = x.size()
-        x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
-        return x
+        bsz, seq_len, dim = x.size()
+        up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
+        up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
+
+        output_length = int(seq_len * ratio * 2/3)
+        select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
+        select_matrix[:, 1::ratio] = 1
+        threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
+        select_matrix = (select_matrix > threshold)
+        assert all(select_matrix.sum(dim=-1).eq(output_length))
+
+        out_x = up_x[select_matrix, :].reshape(bsz, -1, dim).contiguous()
+        out_padding = up_padding[select_matrix].reshape(bsz, -1).contiguous()
+
+        return out_x, out_padding

-    def set_ctc_infer(self, ctc_infer, post_process):
+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
         if hasattr(self, "ctc"):
-            self.ctc.set_infer(ctc_infer, post_process)
+            assert tgt_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, tgt_dict)

     # TorchScript doesn't support super() method so that the scriptable Subclass
     # can't access the base class model in Torchscript.
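A standalone sketch of the new upsampling (illustrative sizes, not from the commit): each position is repeated ratio times, column 1::ratio is pinned to 1 so one copy of every source position always survives, and the top seq_len * ratio * 2/3 scores are kept, so the effective growth is ratio * 2/3 — 2x for the new config value of 3.

```python
import torch

bsz, seq_len, dim, ratio = 2, 5, 4, 3  # illustrative sizes
x = torch.randn(bsz, seq_len, dim)
padding = torch.zeros(bsz, seq_len, dtype=torch.bool)

# repeat every position `ratio` times along the time axis
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)

# keep exactly 2/3 of the copies, chosen at random per sample
output_length = int(seq_len * ratio * 2 / 3)
select = torch.rand(bsz, ratio * seq_len)
select[:, 1::ratio] = 1  # guarantee one surviving copy per source position
threshold = select.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
keep = select > threshold

out_x = up_x[keep, :].reshape(bsz, -1, dim)
out_padding = up_padding[keep].reshape(bsz, -1)
print(out_x.shape, out_padding.shape)  # torch.Size([2, 10, 4]) torch.Size([2, 10])
```

One caveat: the ratio <= 1 early return still yields a bare x, while the call site below unpacks x, encoder_padding_mask = self.upsampling(...), so a ratio of 1 would now misbehave at tuple unpacking.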
@@ -768,21 +779,19 @@ class TransformerCTCEncoder(FairseqEncoder):
         if encoder_padding_mask is not None:
             x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

+        ctc_padding_mask = encoder_padding_mask
+        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
+            x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
+            ctc_padding_mask = encoder_padding_mask
+
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
-        bsz = x.size(1)

         encoder_states = []
         if return_all_hiddens:
             encoder_states.append(x)

-        org_encoder_padding_mask = encoder_padding_mask
-        ctc_padding_mask = encoder_padding_mask
-        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
-            ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
-                expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
-
         # add emb into history
         if self.history is not None:
             self.history.push(x)
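Note: upsampling is now applied once, on the B x T x C input before the transpose, so encoder_padding_mask and ctc_padding_mask describe the same (upsampled) sequence for every layer. Under the old scheme the hidden states were upsampled per layer while the mask was expanded separately, which is presumably the mismatch the commit message refers to.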
@@ -795,10 +804,6 @@ class TransformerCTCEncoder(FairseqEncoder):
             if self.history is not None:
                 x = self.history.pop()

-            if layer_idx + 1 in self.interleaved_ctc_layers:
-                x = self.upsampling(x)
-                encoder_padding_mask = ctc_padding_mask
-
             x = layer(
                 x, encoder_padding_mask=encoder_padding_mask if has_pads else None
             )
@@ -809,7 +814,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             # CTC
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(self.upsampling(x.clone()), ctc_padding_mask)
+                ctc_logit = self.ctc(x.clone(), ctc_padding_mask)

             # Interleaved CTC
             if layer_idx in self.interleaved_ctc_layers:
@@ -826,10 +831,10 @@ class TransformerCTCEncoder(FairseqEncoder):
                 x, _ = self.sae([norm_x, prob])

-                x = x.permute(1, 2, 0)
-                x = self.pool(x)
-                x = x.permute(2, 0, 1)
-                encoder_padding_mask = org_encoder_padding_mask
+                # x = x.permute(1, 2, 0)
+                # x = self.pool(x)
+                # x = x.permute(2, 0, 1)
+                # encoder_padding_mask = org_encoder_padding_mask

             if self.history is not None:
                 self.history.push(x)
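Note: with the single up-front upsampling, the pooling that used to shrink the SAE output back to the original length is commented out here, and the per-layer upsampling()/mask swap in the previous hunk is removed; the sequence stays at the upsampled length throughout the encoder.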
@@ -841,7 +846,7 @@ class TransformerCTCEncoder(FairseqEncoder):
         x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(self.upsampling(x), ctc_padding_mask)
+            ctc_logit = self.ctc(x, ctc_padding_mask)

     # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
     # `forward` so we use a dictionary instead.
......
@@ -78,7 +78,7 @@ class CTC(nn.Module):
         pred_units = self.dictionary.string(pred_units_arr)
         pred_words_raw = post_process(pred_units, self.post_process).split()

-        print(pred_words_raw)
+        logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))

     def valid(self, logits_or_probs, target, lengths):
......