Commit 2de89089 by xuchen

fix the bugs of sae for MT

parent 380d7794
......@@ -41,7 +41,7 @@ interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
-interleaved_ctc_upsampling_ratio: 2
+interleaved_ctc_upsampling_ratio: 3
sae-adapter: league
sae-drop-prob: 0.0
......
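The ratio matters because these configs put CTC on top of a text encoder: CTC can only emit a target no longer than its input, so the source states are repeated interleaved_ctc_upsampling_ratio times before the CTC layers. A minimal sketch of the expansion (illustrative shapes, not the repo's API):

import torch

# Expand each source position `ratio` times so a CTC head has enough
# input steps to emit the (usually longer) target sequence.
x = torch.randn(2, 5, 8)     # (batch, src_len, dim)
ratio = 3                    # interleaved_ctc_upsampling_ratio
up = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(2, 5 * ratio, 8)
print(up.shape)              # torch.Size([2, 15, 8])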
......@@ -3,7 +3,7 @@
gpu_num=1
data_dir=
-test_subset=(test)
+test_subset=(valid test)
exp_name=
if [ "$#" -eq 1 ]; then
......@@ -14,7 +14,7 @@ sacrebleu=1
n_average=10
beam_size=5
len_penalty=1.0
-max_tokens=80000
+max_tokens=20000
dec_model=checkpoint_best.pt
cmd="./run.sh
......
arch: s2t_transformer_s
share-decoder-input-output-embed: True
share-ctc-and-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
......
ctc-weight: 0.3
share-ctc-and-embed: True
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
......
......@@ -11,6 +11,7 @@ from omegaconf import II
from typing import Optional
import numpy as np
import logging
+import editdistance
import torch
import torch.nn.functional as F
......@@ -65,6 +66,10 @@ class CtcCriterionConfig(FairseqDataclass):
        default=0.0,
        metadata={"help": "weight of the self distillation CTC loss"},
    )
+    ctc_self_distill_prob: float = field(
+        default=0.1,
+        metadata={"help": "probability to use distillation loss"},
+    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
......@@ -137,6 +142,7 @@ class CtcCriterion(FairseqCriterion):
        self.target_ctc_weight = cfg.target_ctc_weight
        self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
        self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
+        self.ctc_self_distill_prob = cfg.ctc_self_distill_prob
        self.ctc_entropy = cfg.ctc_entropy
        self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
        self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
......@@ -333,7 +339,8 @@ class CtcCriterion(FairseqCriterion):
        # calculate the self distillation CTC loss
        ctc_self_distill_loss = 0
        ctc_self_distill_num = 0
-        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0:
+        if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and interleaved_ctc_num > 0 and \
+                torch.rand(1).item() < self.ctc_self_distill_prob:
            for i in range(interleaved_ctc_num):
                out = net_output["interleaved_ctc_logits"][i]
                if type(out) == list:
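The new ctc_self_distill_prob option turns self-distillation into a stochastic branch that runs on roughly that fraction of training steps. Note that torch.rand requires an explicit size, so the gate draws a single scalar with torch.rand(1).item(). A self-contained sketch of the pattern:

import torch

# The gate fires on roughly `prob` of the training steps.
prob = 0.1  # ctc_self_distill_prob
fired = sum(torch.rand(1).item() < prob for _ in range(10000))
print(fired / 10000)  # ~0.1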
......@@ -347,7 +354,8 @@ class CtcCriterion(FairseqCriterion):
                loss = F.kl_div(
                    F.log_softmax(inter_ctc_logit, dim=-1, dtype=torch.float32),
-                    F.softmax(ctc_logit, dim=-1, dtype=torch.float32),
+                    F.log_softmax(ctc_logit, dim=-1, dtype=torch.float32).detach(),
+                    log_target=True,
                    reduction="none",
                )
                loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
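The rewritten call passes the teacher distribution in log space and detaches it, so gradients flow only into the interleaved (student) logits; with log_target=True the value is numerically equivalent to the earlier probability-space form. A minimal self-contained check:

import torch
import torch.nn.functional as F

# kl_div with a log-space target (log_target=True) matches the
# probability-space form used before this commit, up to float error.
student = torch.randn(4, 10)
teacher = torch.randn(4, 10)
old = F.kl_div(F.log_softmax(student, -1), F.softmax(teacher, -1),
               reduction="none")
new = F.kl_div(F.log_softmax(student, -1),
               F.log_softmax(teacher, -1).detach(),
               log_target=True, reduction="none")
print(torch.allclose(old, new, atol=1e-6))   # True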
......@@ -379,8 +387,6 @@ class CtcCriterion(FairseqCriterion):
            logger.warning("Target CTC loss %f!" % target_ctc_loss)

        if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
-            import editdistance
-
            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
                target = tokens
......
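Moving the import to module level means editdistance must be installed whenever the criterion loads, not only when the WER branch runs during validation. The library's role here is error accounting between the CTC best path and the reference; typical usage:

import editdistance

# Word-level edit distance between hypothesis and reference.
hyp = "a cat sat on mat".split()
ref = "a cat sat on the mat".split()
print(editdistance.eval(hyp, ref))  # 1 (one insertion needed)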
......@@ -399,9 +399,9 @@ class S2TSATEEncoder(FairseqEncoder):
        # acoustic encoder
        acoustic_encoder_type = args.acoustic_encoder
        if acoustic_encoder_type == "transformer":
-            self.acoustic_encoder = S2TTransformerEncoder(args, task)
+            self.acoustic_encoder = S2TTransformerEncoder(args, task, decoder_embed_tokens)
        elif acoustic_encoder_type == "pds":
-            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task)
+            self.acoustic_encoder = PDSS2TTransformerEncoder(args, task, decoder_embed_tokens)
        else:
            logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
......
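Passing decoder_embed_tokens into the acoustic encoders lines up with the share-ctc-and-embed options in the configs above. A hypothetical sketch of what such sharing typically looks like (names are illustrative, not the repo's API):

import torch.nn as nn

# Tie a CTC output projection to the decoder's embedding matrix so both
# heads score tokens with the same parameters.
vocab_size, embed_dim = 10000, 256
decoder_embed_tokens = nn.Embedding(vocab_size, embed_dim)
ctc_proj = nn.Linear(embed_dim, vocab_size, bias=False)
ctc_proj.weight = decoder_embed_tokens.weight  # shared Parameter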
......@@ -708,18 +708,29 @@ class TransformerCTCEncoder(FairseqEncoder):
            return_all_hiddens,
            token_embeddings)

-    def upsampling(self, x):
+    def upsampling(self, x, padding):
        ratio = self.interleaved_ctc_upsampling_ratio
        if ratio <= 1:
-            return x
+            return x, padding

-        seq_len, bsz, dim = x.size()
-        x = x.unsqueeze(1).expand(-1, ratio, -1, -1).reshape(-1, bsz, dim)
-        return x
-
-    def set_ctc_infer(self, ctc_infer, post_process):
+        bsz, seq_len, dim = x.size()
+        up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
+        up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
+
+        output_length = int(seq_len * ratio * 2 / 3)
+        select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
+        select_matrix[:, 1::ratio] = 1
+        threshold = select_matrix.sort(dim=-1, descending=True)[0][:, output_length:output_length + 1]
+        select_matrix = (select_matrix > threshold)
+        assert all(select_matrix.sum(dim=-1).eq(output_length))
+
+        out_x = up_x[select_matrix, :].reshape(bsz, -1, dim).contiguous()
+        out_padding = up_padding[select_matrix].reshape(bsz, -1).contiguous()
+        return out_x, out_padding
+
+    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
        if hasattr(self, "ctc"):
-            self.ctc.set_infer(ctc_infer, post_process)
+            assert tgt_dict is not None
+            self.ctc.set_infer(ctc_infer, post_process, tgt_dict)

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
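Pulled out of the diff, the new upsampling is easier to read as a standalone function: each position is copied ratio times, one copy per source position is always kept (select_matrix[:, 1::ratio] = 1), and a random subset of the rest is dropped so every sentence ends up at exactly seq_len * ratio * 2/3 positions, with the padding mask carried along. A runnable sketch:

import torch

# Expand each position `ratio` times, then randomly keep 2/3 of the
# expanded positions per sentence; one copy per source token survives.
def upsample(x, padding, ratio=3):
    bsz, seq_len, dim = x.size()
    up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
    up_pad = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
    out_len = int(seq_len * ratio * 2 / 3)
    score = torch.rand(bsz, ratio * seq_len, device=up_x.device)
    score[:, 1::ratio] = 1                      # always keep one copy per token
    thresh = score.sort(dim=-1, descending=True)[0][:, out_len:out_len + 1]
    keep = score > thresh                       # exactly out_len True per row
    out_x = up_x[keep, :].reshape(bsz, -1, dim)
    return out_x, up_pad[keep].reshape(bsz, -1)

x, pad = torch.randn(2, 6, 8), torch.zeros(2, 6, dtype=torch.bool)
y, p = upsample(x, pad)
print(y.shape, p.shape)    # torch.Size([2, 12, 8]) torch.Size([2, 12])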
......@@ -768,21 +779,19 @@ class TransformerCTCEncoder(FairseqEncoder):
        if encoder_padding_mask is not None:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

+        ctc_padding_mask = encoder_padding_mask
+        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
+            x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
+            ctc_padding_mask = encoder_padding_mask

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
-        bsz = x.size(1)

        encoder_states = []
        if return_all_hiddens:
            encoder_states.append(x)

-        org_encoder_padding_mask = encoder_padding_mask
-        ctc_padding_mask = encoder_padding_mask
-        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
-            ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
-                expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)

        # add emb into history
        if self.history is not None:
            self.history.push(x)
......@@ -795,10 +804,6 @@ class TransformerCTCEncoder(FairseqEncoder):
            if self.history is not None:
                x = self.history.pop()

-            if layer_idx + 1 in self.interleaved_ctc_layers:
-                x = self.upsampling(x)
-                encoder_padding_mask = ctc_padding_mask

            x = layer(
                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
            )
......@@ -809,7 +814,7 @@ class TransformerCTCEncoder(FairseqEncoder):
            # CTC
            if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(self.upsampling(x.clone()), ctc_padding_mask)
+                ctc_logit = self.ctc(x.clone(), ctc_padding_mask)

            # Interleaved CTC
            if layer_idx in self.interleaved_ctc_layers:
......@@ -826,10 +831,10 @@ class TransformerCTCEncoder(FairseqEncoder):
                x, _ = self.sae([norm_x, prob])

-                x = x.permute(1, 2, 0)
-                x = self.pool(x)
-                x = x.permute(2, 0, 1)
-                encoder_padding_mask = org_encoder_padding_mask
+                # x = x.permute(1, 2, 0)
+                # x = self.pool(x)
+                # x = x.permute(2, 0, 1)
+                # encoder_padding_mask = org_encoder_padding_mask

            if self.history is not None:
                self.history.push(x)
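The disabled block previously pooled the upsampled states back to their original length after the SAE and restored org_encoder_padding_mask; with upsampling now done once at the top of forward and the padding mask kept in sync, the states simply stay at the upsampled length. A sketch of what such a pool-back typically does, assuming self.pool is a 1-D pooling with kernel and stride equal to the ratio (the diff does not show its definition):

import torch
import torch.nn as nn

# Hypothetical stand-in for the now-disabled pool-back: collapse the
# ratio-times-upsampled sequence back to its original length.
ratio = 3
pool = nn.MaxPool1d(kernel_size=ratio, stride=ratio)
x = torch.randn(15, 2, 8)          # (T*ratio, B, C)
y = pool(x.permute(1, 2, 0))       # (B, C, T*ratio) -> (B, C, T)
y = y.permute(2, 0, 1)             # back to (T, B, C)
print(y.shape)                     # torch.Size([5, 2, 8])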
......@@ -841,7 +846,7 @@ class TransformerCTCEncoder(FairseqEncoder):
            x = self.layer_norm(x)

        if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(self.upsampling(x), ctc_padding_mask)
+            ctc_logit = self.ctc(x, ctc_padding_mask)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
......
......@@ -78,7 +78,7 @@ class CTC(nn.Module):
        pred_units = self.dictionary.string(pred_units_arr)
        pred_words_raw = post_process(pred_units, self.post_process).split()

-        print(pred_words_raw)
+        logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))

    def valid(self, logits_or_probs, target, lengths):
......