Commit e40eac14 by xuchen

optimize the information dump

parent d946bc3b
......@@ -23,12 +23,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
def __init__(self, task, label_smoothing,
sentence_avg,
cfg: CtcCriterionConfig,
ctc_weight=0.0):
ctc_weight=0.0,
save_dir=None):
super().__init__(task, sentence_avg, label_smoothing)
self.report_accuracy = True
self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
self.save_dir = save_dir
@staticmethod
def add_args(parser):
......@@ -62,6 +64,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
encoder_out = model.encoder(src_tokens, src_lengths,
text_src_tokens, text_src_lengths)
else:
if self.training and getattr(model.encoder, "sae_ground_truth_ratio", 0) != 0:
ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
encoder_out = model.encoder(src_tokens, src_lengths,
ctc_alignment_oracle=ctc_alignment_oracle)
else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
use_mixup = False
......
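The criterion hunk above only builds the CTC alignment oracle when it will actually be consumed: during training, and only if the encoder's sae_ground_truth_ratio is non-zero. A minimal standalone sketch of that dispatch, assuming the interfaces shown in this diff (get_ground_truth_alignment, ctc_alignment_oracle); the helper name is hypothetical:

def maybe_build_alignment_oracle(criterion, model, sample):
    # Skip the forced-alignment pass unless the SAE adapter will
    # actually mix ground truth into its distribution.
    if criterion.training and getattr(model.encoder, "sae_ground_truth_ratio", 0) != 0:
        return criterion.ctc_criterion.get_ground_truth_alignment(model, sample)
    return None

# oracle = maybe_build_alignment_oracle(self, model, sample)
# encoder_out = model.encoder(src_tokens, src_lengths, ctc_alignment_oracle=oracle)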
......@@ -2,7 +2,6 @@ import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.models import (
......@@ -138,7 +137,6 @@ class CTCDecoder(object):
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(self.beam_size, self.vocab_size - 1)
# from fairseq.sequence_generator import EnsembleModel
from fairseq.sequence_generator import EnsembleModel
if isinstance(models, EnsembleModel):
self.model = models
......@@ -240,8 +238,12 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
......@@ -276,7 +278,9 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
......@@ -310,8 +314,6 @@ def base_architecture(args):
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
@register_model_architecture("s2t_ctc", "s2t_ctc_s")
def s2t_ctc_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
......
......@@ -125,7 +125,7 @@ class S2TSATEModel(S2TTransformerModel):
# target CTC
parser.add_argument(
"--target-ctc-layer",
default=None,
default=0,
type=int,
help="ctc layer for target sentence",
)
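With the default changed from None to 0, the option always parses as an int; a value of 0 is later resolved by the textual encoder (see the "if self.ctc_layer == 0" hunk below) to mean "place the target CTC on the last textual layer". A tiny illustration with hypothetical values:

# Hypothetical values: target_ctc_layer == 0 resolves to the last textual layer.
text_encoder_layers = 6
target_ctc_layer = 0
ctc_layer = text_encoder_layers if target_ctc_layer == 0 else target_ctc_layer
assert ctc_layer == 6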
......@@ -233,15 +233,15 @@ class S2TSATEModel(S2TTransformerModel):
return cls(encoder, decoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, **kwargs
)
return decoder_out
......@@ -286,7 +286,9 @@ class TextualEncoder(FairseqEncoder):
self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = getattr(args, "target_ctc_layer", layer_num)
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.ctc_layer == 0:
self.ctc_layer = layer_num
self.inter_ctc = True if self.ctc_layer != layer_num else False
if self.inter_ctc:
logger.info("Target CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
......@@ -294,13 +296,16 @@ class TextualEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if embed_tokens is not None:
if embed_tokens is not None and args.share_target_ctc_and_embed and \
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if self.target_interleaved_ctc_layers is not None:
target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
for layer_idx in target_interleaved_ctc_layers:
......@@ -337,7 +342,7 @@ class TextualEncoder(FairseqEncoder):
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
def forward(self, x, encoder_padding_mask=None, history=None):
def forward(self, x, encoder_padding_mask=None, history=None, **kwargs):
if self.encoder_embed_norm:
x = self.embed_ln(x)
......@@ -356,7 +361,7 @@ class TextualEncoder(FairseqEncoder):
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone())
target_ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Target Layer %d" % layer_idx)
if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
......@@ -365,11 +370,23 @@ class TextualEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["target"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["target"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
if history is not None:
history.push(x)
......@@ -381,7 +398,7 @@ class TextualEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x)
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")
return x, target_ctc_logit, target_interleaved_ctc_logits
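The ground-truth gating added above draws a per-position Bernoulli mask at sae_ground_truth_ratio and marks the remaining positions with -1 so they are not force-emitted. A self-contained sketch of that sampling step, with toy tensors in place of a real CTC alignment:

import torch

def sample_oracle_mask(oracle, best_aligns_pad, gt_ratio):
    # Keep each position's oracle label with probability gt_ratio;
    # everything else gets force_emit = -1, i.e. left to the model.
    oracle_mask = (torch.rand(oracle.size(), device=oracle.device) < gt_ratio).bool()
    force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
    return oracle_mask, force_emit

# Toy usage: batch of 2 sequences of length 5.
oracle = torch.randint(0, 10, (2, 5))
best_aligns_pad = torch.randint(0, 10, (2, 5))
mask, force_emit = sample_oracle_mask(oracle, best_aligns_pad, gt_ratio=0.3)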
......@@ -435,6 +452,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
......@@ -443,21 +461,34 @@ class S2TSATEEncoder(FairseqEncoder):
self.history = None
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
if hasattr(self, "ctc"):
if hasattr(self.acoustic_encoder, "ctc"):
assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
if hasattr(self.textual_encoder, "ctc"):
assert tgt_dict is not None
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
if lang == "source":
if hasattr(self.acoustic_encoder, "ctc"):
return self.acoustic_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
else:
logger.error("No ctc module in textual encoder")
else:
if hasattr(self.textual_encoder, "ctc"):
return self.textual_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
else:
logger.error("No ctc module in textual encoder")
def forward(self, src_tokens, src_lengths=None, **kwargs):
if self.history is not None:
self.history.clean()
if self.freeze_acoustic_encoder:
with torch.no_grad():
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
else:
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
......@@ -490,16 +521,16 @@ class S2TSATEEncoder(FairseqEncoder):
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
self.history, **kwargs)
else:
x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
self.history)
self.history, **kwargs)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []), # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C
"target_ctc_logit": [target_ctc_logit], # B x T x C
"target_interleaved_ctc_logits": target_interleaved_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
......
......@@ -447,6 +447,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--sae-ground-truth-ratio",
default=0,
type=float,
help="the ratio for ground truth in sae",
)
parser.add_argument(
"--share-sae-and-ctc",
action="store_true",
help="share the weight of ctc and sae",
......@@ -570,13 +576,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
lprobs.batch_first = True
return lprobs
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
......@@ -655,6 +661,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
......@@ -681,6 +688,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gt_ratio": self.sae_ground_truth_ratio,
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
......@@ -717,6 +725,14 @@ class S2TTransformerEncoder(FairseqEncoder):
assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
dictionary)
else:
logger.error("No ctc module in textual encoder")
def set_debug_var(self, debug_var_flag):
self.debug_var = debug_var_flag
......@@ -879,7 +895,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone(), encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Source Layer %d" % layer_idx)
# interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -889,15 +905,27 @@ class S2TTransformerEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask)
logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
self.show_debug(x, "x after sae")
# gather cosine similarity
......@@ -917,7 +945,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask)
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
self.show_debug(x, "x after ctc")
return {
......@@ -925,6 +953,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"interleaved_ctc_logits": interleaved_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
# "oracle": [oracle, oracle_mask, force_emit],
"mixup": mixup,
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
......
......@@ -315,7 +315,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
parser.add_argument(
"--interleaved-ctc-upsampling-ratio",
default=2,
type=int,
type=float,
help="upsampling ratio of the representation for CTC calculation",
)
parser.add_argument(
......@@ -355,6 +355,24 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
action="store_true",
help="share the weight of ctc and sae",
)
parser.add_argument(
"--sae-embed-norm",
default=False,
action="store_true",
help="use the layer norm for embed output",
)
parser.add_argument(
"--sae-out-norm",
default=False,
action="store_true",
help="use the layer norm for final output",
)
parser.add_argument(
"--sae-ground-truth-ratio",
default=0,
type=float,
help="the ratio for ground truth in sae",
)
# fmt: on
@classmethod
......@@ -625,6 +643,11 @@ class TransformerCTCEncoder(FairseqEncoder):
logger.info("Interleaved CTC loss in layer %d" % layer_idx)
self.un_sample = torch.nn.Upsample(scale_factor=self.interleaved_ctc_upsampling_ratio, mode="linear",
align_corners=True)
self.down_sample = torch.nn.Upsample(scale_factor=1 / self.interleaved_ctc_upsampling_ratio, mode="linear",
align_corners=True)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
......@@ -633,10 +656,14 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
strategy = {
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
"gt_ratio": self.sae_ground_truth_ratio,
}
self.sae = Adapter(embed_dim, args.sae_adapter,
......@@ -645,9 +672,9 @@ class TransformerCTCEncoder(FairseqEncoder):
)
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
if hasattr(self, "ctc"):
self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
# if hasattr(self, "ctc"):
# self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
# stride=self.interleaved_ctc_upsampling_ratio)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -679,6 +706,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
**kwargs
):
"""
Args:
......@@ -706,7 +734,8 @@ class TransformerCTCEncoder(FairseqEncoder):
return self.forward_scriptable(src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings)
token_embeddings,
**kwargs)
def upsampling(self, x, padding):
ratio = self.interleaved_ctc_upsampling_ratio
......@@ -714,12 +743,17 @@ class TransformerCTCEncoder(FairseqEncoder):
return x
if len(x.size()) == 3:
bsz, seq_len, dim = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
# bsz, seq_len, dim = x.size()
# up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
seq_len, bsz, dim = x.size()
x = x.permute(1, 2, 0)
up_x = self.un_sample(x)
up_x = up_x.permute(2, 0, 1)
else:
bsz, seq_len = x.size()
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)
# output_length = int(seq_len * ratio * 2/3)
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
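The replacement above trades the old repeat-and-reshape upsampling for torch.nn.Upsample with linear interpolation. Because Upsample's linear mode expects (batch, channels, length), the (T, B, C) activations are permuted before and after the call; a small standalone sketch with a toy ratio of 2:

import torch
import torch.nn as nn

ratio = 2.0
un_sample = nn.Upsample(scale_factor=ratio, mode="linear", align_corners=True)

x = torch.randn(7, 3, 16)             # (T, B, C), as inside the encoder
up_x = un_sample(x.permute(1, 2, 0))  # -> (B, C, T), interpolate along time
up_x = up_x.permute(2, 0, 1)          # back to (T * ratio, B, C)
assert up_x.shape == (14, 3, 16)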
......@@ -742,6 +776,14 @@ class TransformerCTCEncoder(FairseqEncoder):
assert tgt_dict is not None
self.ctc.set_infer(ctc_infer, post_process, tgt_dict)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
if hasattr(self, "ctc"):
return self.ctc.valid(lprobs, targets, input_lengths,
dictionary)
else:
logger.error("No ctc module in the encoder")
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
......@@ -752,6 +794,7 @@ class TransformerCTCEncoder(FairseqEncoder):
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
**kwargs
):
"""
Args:
......@@ -783,10 +826,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.history is not None:
self.history.clean()
ctc_padding_mask = encoder_padding_mask
if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
# ctc_padding_mask = encoder_padding_mask
# if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
# src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
# ctc_padding_mask = encoder_padding_mask
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
......@@ -796,6 +839,8 @@ class TransformerCTCEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_padding_mask = encoder_padding_mask
encoder_states = []
......@@ -824,6 +869,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), ctc_padding_mask)
# Interleaved CTC
......@@ -833,18 +879,31 @@ class TransformerCTCEncoder(FairseqEncoder):
if p < self.interleaved_ctc_drop_prob:
break
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, ctc_padding_mask)
interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, _ = self.sae([norm_x, prob])
# x = x.permute(1, 2, 0)
# x = self.pool(x)
# x = x.permute(2, 0, 1)
# encoder_padding_mask = org_encoder_padding_mask
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
if self.sae_ground_truth_ratio > 0:
ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
oracle, best_aligns_pad = ctc_alignment_oracle["source"]
oracle_mask = (torch.rand(oracle.size(),
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
x = x.permute(1, 2, 0)
# x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
x = self.down_sample(x)
x = x.permute(2, 0, 1)
if self.history is not None:
self.history.push(x)
......
......@@ -98,7 +98,7 @@ class Adapter(nn.Module):
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if "league" in self.adapter_type:
if self.cal_context:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
if self.distribution_cutoff is not None:
self.distribution_cutoff = int(self.distribution_cutoff)
......@@ -107,17 +107,21 @@ class Adapter(nn.Module):
self.drop_prob = strategy.get("drop_prob", 0)
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.out_norm = strategy.get("out_norm", False)
if self.out_norm:
self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None):
def forward(self, x, padding=None, oracle=None, oracle_mask=None):
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size()
org_distribution = distribution
distribution = distribution.contiguous().view(-1, distribution.size(-1))
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
linear_out = None
soft_out = None
......@@ -125,18 +129,32 @@ class Adapter(nn.Module):
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
# )
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
distribution = torch.where(
threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
)
distribution = distribution.view(-1, distribution.size(-1))
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution = torch.where(
# threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
# )
# distribution = distribution.view(-1, vocab_size)
distribution[:, 0] = 0
distribution = distribution / distribution.sum(-1, keepdim=True)
if self.ground_truth_ratio > 0 and oracle is not None:
oracle = oracle.unsqueeze(-1)
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)).\
to(distribution.dtype).transpose(0, 1)
oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution
soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
else:
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm:
soft_out = self.embed_ln(soft_out)
......
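At positions selected by the oracle mask, the adapter above swaps the model's CTC distribution for a one-hot distribution built from the oracle alignment, keeping the predicted distribution elsewhere. A toy, self-contained sketch of that blend; the torch.where form here is equivalent to the oracle_mask * one_hot + ~oracle_mask * distribution expression in the diff (real tensors are (T, B, V)):

import torch

bsz, seq_len, vocab = 2, 4, 6
pred = torch.softmax(torch.randn(bsz, seq_len, vocab), dim=-1)  # model's distribution
oracle = torch.randint(0, vocab, (bsz, seq_len))                # oracle label per position
keep = torch.rand(bsz, seq_len) < 0.5                           # positions forced to the oracle

one_hot = torch.nn.functional.one_hot(oracle, vocab).to(pred.dtype)
mixed = torch.where(keep.unsqueeze(-1), one_hot, pred)          # oracle where kept, prediction elsewhere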
......@@ -37,13 +37,14 @@ class CTC(nn.Module):
self.dictionary = dictionary
self.infer_decoding = False
self.post_process = "sentencepiece"
self.blank_idx = 0
def set_infer(self, is_infer, text_post_process, dictionary):
self.infer_decoding = is_infer
self.post_process = text_post_process
self.dictionary = dictionary
def forward(self, x, padding=None):
def forward(self, x, padding=None, tag=None):
if self.need_layernorm:
x = self.LayerNorm(x)
......@@ -52,7 +53,7 @@ class CTC(nn.Module):
if not self.training and self.infer_decoding:
assert self.dictionary is not None
input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)
return x
......@@ -65,7 +66,7 @@ class CTC(nn.Module):
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
def infer(self, logits_or_probs, lengths):
def infer(self, logits_or_probs, lengths, tag=None):
for lp, inp_l in zip(
logits_or_probs,
lengths,
......@@ -78,27 +79,30 @@ class CTC(nn.Module):
pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))
if tag is not None:
logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
else:
logger.info("CTC prediction: %s" % (" ".join(pred_words_raw)))
def valid(self, logits_or_probs, target, lengths):
def valid(self, logits_or_probs, targets, input_lengths, dictionary):
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
with torch.no_grad():
for lp, t, inp_l in zip(
logits_or_probs,
target,
lengths,
targets,
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
p = (t != dictionary.pad()) & (t != dictionary.eos())
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units = dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
......@@ -109,10 +113,13 @@ class CTC(nn.Module):
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_units = dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
return c_err, c_len, w_errs, w_len, wv_errs
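The rewritten valid() above decodes greedily (argmax, merge repeats, drop blanks and pad) and accumulates word edit distances with the editdistance package, with the dictionary now passed in explicitly instead of read from self.task. A minimal sketch of that decoding and word-error accounting for a single utterance, assuming editdistance is installed:

import torch
import editdistance

BLANK_IDX = 0

def ctc_greedy_decode(lprobs):
    # lprobs: (T, V) log-probabilities for one utterance.
    toks = lprobs.argmax(dim=-1).unique_consecutive()  # merge repeated frames
    return toks[toks != BLANK_IDX].tolist()            # drop blanks

def word_errors(pred_words, targ_words):
    return editdistance.eval(pred_words, targ_words), len(targ_words)

lprobs = torch.randn(20, 30).log_softmax(-1)
pred_words = [str(t) for t in ctc_greedy_decode(lprobs)]
errs, n_words = word_errors(pred_words, ["3", "7", "7", "1"])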
from .imputer import imputer_loss, ImputerLoss, best_alignment, ctc_decode
#include <torch/extension.h>
#include <tuple>
std::tuple<torch::Tensor, torch::Tensor>
imputer_loss_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, at::IntArrayRef input_lengths,
at::IntArrayRef target_lengths, int64_t BLANK,
bool zero_infinity);
torch::Tensor imputer_loss_backward_op(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor> imputer_loss(
const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
imputer_loss_op(log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl,
BLANK, zero_infinity);
return res;
}
torch::Tensor imputer_loss_backward(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
const torch::Tensor &input_lengths, const torch::Tensor &target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
torch::Tensor res;
res = imputer_loss_backward_op(
grad, log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl, neg_log_likelihood,
log_alpha, BLANK, zero_infinity);
return res;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK,
bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
best_alignment_op(log_probs, targets.to(log_probs.device(), at::kLong),
il, tl, BLANK, zero_infinity);
return res;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("imputer_loss", &imputer_loss, "calculate imputer loss");
m.def("imputer_loss_backward", &imputer_loss_backward,
"calculate imputer loss gradient");
m.def("best_alignment", &best_alignment, "get best alignments for ctc");
}
\ No newline at end of file
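The best_alignment binding above is what the new .imputer import (and the criterion's get_ground_truth_alignment) ultimately relies on to obtain forced CTC alignments. A hedged sketch of preparing its inputs with toy shapes; the call itself is shown positionally to match the C++ signature and is commented out because it needs the compiled extension:

import torch

# Toy shapes: T frames, batch B, vocab V, target length S; index 0 is the blank.
T, B, V, S = 50, 2, 100, 12
log_probs = torch.randn(T, B, V).log_softmax(-1)
targets = torch.randint(1, V, (B, S))
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), S, dtype=torch.long)

# from .imputer import best_alignment   # package-relative import, as in the __init__ above
# res = best_alignment(log_probs, targets, input_lengths, target_lengths, 0, True)
# `res` is the three-tensor tuple returned by the binding above.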
......@@ -107,7 +107,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
)
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, cfg.common_eval.post_process,
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
......