Commit e40eac14 by xuchen

optimize the information dump

parent d946bc3b
@@ -23,12 +23,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
     def __init__(self, task, label_smoothing,
                  sentence_avg,
                  cfg: CtcCriterionConfig,
-                 ctc_weight=0.0):
+                 ctc_weight=0.0,
+                 save_dir=None):
         super().__init__(task, sentence_avg, label_smoothing)
         self.report_accuracy = True
         self.ctc_weight = ctc_weight
-        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
+        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
+        self.save_dir = save_dir

     @staticmethod
     def add_args(parser):
@@ -62,7 +64,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             encoder_out = model.encoder(src_tokens, src_lengths,
                                         text_src_tokens, text_src_lengths)
         else:
-            encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+            if self.training and getattr(model.encoder, "sae_ground_truth_ratio", 0) != 0:
+                ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
+                encoder_out = model.encoder(src_tokens, src_lengths,
+                                            ctc_alignment_oracle=ctc_alignment_oracle)
+            else:
+                encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)

         use_mixup = False
         if "mixup" in encoder_out and encoder_out["mixup"] is not None:
......
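
Not from the commit: a minimal sketch of the dict shape that get_ground_truth_alignment appears to return, inferred from how the "source"/"target" entries are unpacked later in this diff; names and shapes here are assumptions.

    import torch

    def make_dummy_alignment_oracle(bsz=2, seq_len=5, vocab_size=10):
        # Each side maps to (oracle_labels, best_aligns_pad) or None; both B x T.
        oracle = torch.randint(1, vocab_size, (bsz, seq_len))           # oracle labels
        best_aligns_pad = torch.randint(0, vocab_size, (bsz, seq_len))  # best CTC alignment states
        return {"source": (oracle, best_aligns_pad), "target": None}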
@@ -2,7 +2,6 @@ import logging
 from typing import Dict, Optional

 import torch
-import torch.nn as nn

 from fairseq import checkpoint_utils, utils
 from fairseq.models import (
@@ -138,7 +137,6 @@ class CTCDecoder(object):
         # the max beam size is the dictionary size - 1, since we never select pad
         self.beam_size = min(self.beam_size, self.vocab_size - 1)

-        # from fairseq.sequence_generator import EnsembleModel
         from fairseq.sequence_generator import EnsembleModel
         if isinstance(models, EnsembleModel):
             self.model = models
@@ -240,8 +238,12 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)

     # CTC
     args.ctc_layer = getattr(args, "ctc_layer", 0)
+    args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)

     # Conformer
     args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
@@ -276,7 +278,9 @@ def base_architecture(args):
     # Semantics-augmented Encoding (sae)
     args.sae_adapter = getattr(args, "sae_adapter", "none")
-    args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
+    args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
+    args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
+    args.sae_out_norm = getattr(args, "sae_out_norm", False)
     args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
     args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
@@ -310,8 +314,6 @@ def base_architecture(args):
     args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")

 @register_model_architecture("s2t_ctc", "s2t_ctc_s")
 def s2t_ctc_s(args):
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
......
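
The getattr-defaulting idiom used throughout base_architecture keeps configs from before this commit loadable when new options such as encoder_embed_linear land; a self-contained illustration (not from the commit):

    from argparse import Namespace

    args = Namespace()  # an old config lacking the new fields
    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
    args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
    assert args.encoder_embed_linear is False and args.share_ctc_and_embed is False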
@@ -125,7 +125,7 @@ class S2TSATEModel(S2TTransformerModel):
         # target CTC
         parser.add_argument(
             "--target-ctc-layer",
-            default=None,
+            default=0,
             type=int,
             help="ctc layer for target sentence",
         )
@@ -233,15 +233,15 @@ class S2TSATEModel(S2TTransformerModel):
         return cls(encoder, decoder)

-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         """
         The forward method inherited from the base class has a **kwargs
         argument in its input, which is not supported in torchscript. This
         method overwrites the forward method definition without **kwargs.
         """
-        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
         decoder_out = self.decoder(
-            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
+            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, **kwargs
         )
         return decoder_out
@@ -286,7 +286,9 @@ class TextualEncoder(FairseqEncoder):
         self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
         if self.use_ctc:
             self.ctc_layer = getattr(args, "target_ctc_layer", layer_num)
-            self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
+            if self.ctc_layer == 0:
+                self.ctc_layer = layer_num
+            self.inter_ctc = True if self.ctc_layer != layer_num else False
             if self.inter_ctc:
                 logger.info("Target CTC loss in layer %d" % self.ctc_layer)
             self.ctc = CTC(embed_dim,
@@ -294,13 +296,16 @@ class TextualEncoder(FairseqEncoder):
                            dropout=args.dropout,
                            need_layernorm=True if self.inter_ctc else False)

-            if embed_tokens is not None:
+            if embed_tokens is not None and args.share_target_ctc_and_embed and \
+                    self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
                 self.ctc.ctc_projection.weight = embed_tokens.weight

         self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
         self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
         self.interleaved_ctc_layers = []
         self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
+        self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)

         if self.target_interleaved_ctc_layers is not None:
             target_interleaved_ctc_layers = self.target_interleaved_ctc_layers.split(",")
             for layer_idx in target_interleaved_ctc_layers:
@@ -337,7 +342,7 @@ class TextualEncoder(FairseqEncoder):
             self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob

-    def forward(self, x, encoder_padding_mask=None, history=None):
+    def forward(self, x, encoder_padding_mask=None, history=None, **kwargs):

         if self.encoder_embed_norm:
             x = self.embed_ln(x)
@@ -356,7 +361,7 @@ class TextualEncoder(FairseqEncoder):
             layer_idx += 1

             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                target_ctc_logit = self.ctc(x.clone())
+                target_ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Target Layer %d" % layer_idx)

             if layer_idx != self.layer_num and layer_idx in self.interleaved_ctc_layers:
                 if self.interleaved_ctc_drop_prob > 0:
@@ -365,11 +370,23 @@ class TextualEncoder(FairseqEncoder):
                         break

                 norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x)
+                logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
                 target_interleaved_ctc_logits.append(logit)

                 prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
-                x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
+
+                # CTC alignment
+                oracle = None
+                oracle_mask = None
+                force_emit = None
+                if self.sae_ground_truth_ratio > 0:
+                    ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
+                    if ctc_alignment_oracle is not None and ctc_alignment_oracle["target"] is not None:
+                        oracle, best_aligns_pad = ctc_alignment_oracle["target"]
+                        oracle_mask = (torch.rand(oracle.size(),
+                                                  device=oracle.device) < self.sae_ground_truth_ratio).bool()
+                        force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
+
+                x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)

             if history is not None:
                 history.push(x)
@@ -381,7 +398,7 @@ class TextualEncoder(FairseqEncoder):
             x = self.layer_norm(x)

         if self.use_ctc and target_ctc_logit is None:
-            target_ctc_logit = self.ctc(x)
+            target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")

         return x, target_ctc_logit, target_interleaved_ctc_logits
@@ -435,6 +452,7 @@ class S2TSATEEncoder(FairseqEncoder):
         self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
         self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
+        self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)

         if getattr(args, "use_enc_dlcl", False):
             layer_num = args.encoder_layers + args.text_encoder_layers + 2
@@ -443,11 +461,24 @@ class S2TSATEEncoder(FairseqEncoder):
             self.history = None

     def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
-        if hasattr(self, "ctc"):
+        if hasattr(self.acoustic_encoder, "ctc"):
             assert src_dict is not None
             self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
         if hasattr(self.textual_encoder, "ctc"):
+            assert tgt_dict is not None
             self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)

+    def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
+        if lang == "source":
+            if hasattr(self.acoustic_encoder, "ctc"):
+                return self.acoustic_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
+            else:
+                logger.error("No ctc module in textual encoder")
+        else:
+            if hasattr(self.textual_encoder, "ctc"):
+                return self.textual_encoder.ctc.valid(lprobs, targets, input_lengths, dictionary)
+            else:
+                logger.error("No ctc module in textual encoder")
+
     def forward(self, src_tokens, src_lengths=None, **kwargs):
         if self.history is not None:
@@ -455,9 +486,9 @@ class S2TSATEEncoder(FairseqEncoder):

         if self.freeze_acoustic_encoder:
             with torch.no_grad():
-                acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
+                acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)
         else:
-            acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
+            acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths, **kwargs)

         encoder_out = acoustic_encoder_out["encoder_out"][0]
         encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
@@ -490,16 +521,16 @@ class S2TSATEEncoder(FairseqEncoder):
         if self.freeze_textual_encoder:
             with torch.no_grad():
-                x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
-                                                                                          self.history)
+                x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
+                                                                                          self.history, **kwargs)
         else:
-            x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
-                                                                                      self.history)
+            x, target_ctc_logit, target_interleaved_ctc_logits = self.textual_encoder(x, encoder_padding_mask,
+                                                                                      self.history, **kwargs)

         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [ctc_logit],  # T x B x C
             "interleaved_ctc_logits": acoustic_encoder_out.get("interleaved_ctc_logits", []),  # B x T x C
-            "target_ctc_logit": target_ctc_logit,  # B x T x C
+            "target_ctc_logit": [target_ctc_logit],  # B x T x C
             "target_interleaved_ctc_logits": target_interleaved_ctc_logits,  # B x T x C
             "ctc_padding_mask": [ctc_padding_mask],  # B x T
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
......
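
A minimal standalone sketch of the ground-truth gating pattern added in both encoders above: each position keeps the oracle label with probability sae_ground_truth_ratio, and unsampled positions are marked -1 in force_emit (free emission). Tensor names mirror the commit; the B x T shapes are assumptions.

    import torch

    sae_ground_truth_ratio = 0.3
    oracle = torch.randint(0, 100, (2, 7))           # B x T oracle labels (assumed shape)
    best_aligns_pad = torch.randint(0, 100, (2, 7))  # B x T best CTC alignment states
    oracle_mask = (torch.rand(oracle.size()) < sae_ground_truth_ratio).bool()
    force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)  # -1 = not forced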
@@ -447,6 +447,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             help="cutoff of the distribution in sae",
         )
         parser.add_argument(
+            "--sae-ground-truth-ratio",
+            default=0,
+            type=float,
+            help="the ratio for ground truth in sae",
+        )
+        parser.add_argument(
             "--share-sae-and-ctc",
             action="store_true",
             help="share the weight of ctc and sae",
@@ -570,13 +576,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
         lprobs.batch_first = True
         return lprobs

-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         """
         The forward method inherited from the base class has a **kwargs
         argument in its input, which is not supported in torchscript. This
         method overwrites the forward method definition without **kwargs.
         """
-        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, **kwargs)
         decoder_out = self.decoder(
             prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
         )
@@ -655,6 +661,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
         self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
+        self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
         self.interleaved_ctc_layers = []
         if args.interleaved_ctc_layers is not None:
             interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
@@ -681,6 +688,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 "out_norm": getattr(args, "sae_out_norm", False),
                 "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
                 "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
+                "gt_ratio": self.sae_ground_truth_ratio,
                 "drop_prob": getattr(args, "sae_drop_prob", 0),
             }
@@ -717,6 +725,14 @@ class S2TTransformerEncoder(FairseqEncoder):
             assert src_dict is not None
             self.ctc.set_infer(ctc_infer, post_process, src_dict)

+    def ctc_valid(self, lprobs, targets, input_lengths,
+                  dictionary, lang="source"):
+        if hasattr(self, "ctc"):
+            return self.ctc.valid(lprobs, targets, input_lengths,
+                                  dictionary)
+        else:
+            logger.error("No ctc module in textual encoder")
+
     def set_debug_var(self, debug_var_flag):
         self.debug_var = debug_var_flag
@@ -879,7 +895,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)

             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(x.clone(), encoder_padding_mask)
+                ctc_logit = self.ctc(x.clone(), encoder_padding_mask, "Source Layer %d" % layer_idx)

             # interleaved CTC
             if layer_idx in self.interleaved_ctc_layers:
@@ -889,15 +905,27 @@ class S2TTransformerEncoder(FairseqEncoder):
                         break

                 norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x, encoder_padding_mask)
+                logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
                 interleaved_ctc_logits.append(logit)

                 logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
                                     max=1e8 if logit.dtype == torch.float32 else 1e4)
                 prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
-                x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask)
+
+                # CTC alignment
+                oracle = None
+                oracle_mask = None
+                force_emit = None
+                if self.sae_ground_truth_ratio > 0:
+                    ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
+                    if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
+                        oracle, best_aligns_pad = ctc_alignment_oracle["source"]
+                        oracle_mask = (torch.rand(oracle.size(),
+                                                  device=oracle.device) < self.sae_ground_truth_ratio).bool()
+                        force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
+
+                x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
                 self.show_debug(x, "x after sae")

         # gather cosine similarity
@@ -917,7 +945,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.show_debug(x, "x after encoding layer norm")

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x, encoder_padding_mask)
+            ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
         self.show_debug(x, "x after ctc")

         return {
@@ -925,6 +953,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
             "interleaved_ctc_logits": interleaved_ctc_logits,  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
+            # "oracle": [oracle, oracle_mask, force_emit],
             "mixup": mixup,
             "encoder_embedding": [],  # B x T x C
             "encoder_states": [],  # List[T x B x C]
......
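
The dtype-aware clamp retained as context in the hunk above is worth isolating; a standalone sketch of the same bound-by-dtype trick (bounds are the values used in the code):

    import torch

    def clamp_ctc_logit(logit: torch.Tensor) -> torch.Tensor:
        # float32 can hold +/-1e8; fp16 overflows past ~6.5e4, so use 1e4 there.
        bound = 1e8 if logit.dtype == torch.float32 else 1e4
        return logit.clamp(min=-bound, max=bound)

    print(clamp_ctc_logit(torch.tensor([2e9, -3e9])))  # tensor([ 1e+08, -1e+08])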
@@ -315,7 +315,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
         parser.add_argument(
             "--interleaved-ctc-upsampling-ratio",
             default=2,
-            type=int,
+            type=float,
             help="upsampling ratio of the representation for CTC calculation",
         )
         parser.add_argument(
@@ -355,6 +355,24 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="share the weight of ctc and sae",
         )
+        parser.add_argument(
+            "--sae-embed-norm",
+            default=False,
+            action="store_true",
+            help="use the layer norm for embed output",
+        )
+        parser.add_argument(
+            "--sae-out-norm",
+            default=False,
+            action="store_true",
+            help="use the layer norm for final output",
+        )
+        parser.add_argument(
+            "--sae-ground-truth-ratio",
+            default=0,
+            type=float,
+            help="the ratio for ground truth in sae",
+        )
         # fmt: on

     @classmethod
@@ -625,6 +643,11 @@ class TransformerCTCEncoder(FairseqEncoder):
                 logger.info("Interleaved CTC loss in layer %d" % layer_idx)

+        self.un_sample = torch.nn.Upsample(scale_factor=self.interleaved_ctc_upsampling_ratio, mode="linear",
+                                           align_corners=True)
+        self.down_sample = torch.nn.Upsample(scale_factor=1 / self.interleaved_ctc_upsampling_ratio, mode="linear",
+                                             align_corners=True)
+
         if not self.use_ctc:
             self.ctc = CTC(embed_dim,
                            dictionary_size=decoder_embed_tokens.num_embeddings,
@@ -633,10 +656,14 @@ class TransformerCTCEncoder(FairseqEncoder):
             self.ctc.ctc_projection.weight = decoder_embed_tokens.weight

+        self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
         strategy = {
+            "embed_norm": getattr(args, "sae_embed_norm", False),
+            "out_norm": getattr(args, "sae_out_norm", False),
             "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
             "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
             "drop_prob": getattr(args, "sae_drop_prob", 0),
+            "gt_ratio": self.sae_ground_truth_ratio,
         }

         self.sae = Adapter(embed_dim, args.sae_adapter,
@@ -645,9 +672,9 @@ class TransformerCTCEncoder(FairseqEncoder):
                            )
         if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
             self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight

-        if hasattr(self, "ctc"):
-            self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
-                                     stride=self.interleaved_ctc_upsampling_ratio)
+        # if hasattr(self, "ctc"):
+        #     self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
+        #                              stride=self.interleaved_ctc_upsampling_ratio)

     def build_encoder_layer(self, args):
         layer = TransformerEncoderLayer(args)
@@ -679,6 +706,7 @@ class TransformerCTCEncoder(FairseqEncoder):
         src_lengths: Optional[torch.Tensor] = None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
+        **kwargs
     ):
         """
         Args:
@@ -706,7 +734,8 @@ class TransformerCTCEncoder(FairseqEncoder):
         return self.forward_scriptable(src_tokens,
                                        src_lengths,
                                        return_all_hiddens,
-                                       token_embeddings)
+                                       token_embeddings,
+                                       **kwargs)

     def upsampling(self, x, padding):
         ratio = self.interleaved_ctc_upsampling_ratio
@@ -714,12 +743,17 @@ class TransformerCTCEncoder(FairseqEncoder):
             return x

         if len(x.size()) == 3:
-            bsz, seq_len, dim = x.size()
-            up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
+            # bsz, seq_len, dim = x.size()
+            # up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
+            seq_len, bsz, dim = x.size()
+            x = x.permute(1, 2, 0)
+            up_x = self.un_sample(x)
+            up_x = up_x.permute(2, 0, 1)
         else:
             bsz, seq_len = x.size()
             up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
-        up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
+        up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)

         # output_length = int(seq_len * ratio * 2/3)
         # select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
@@ -742,6 +776,14 @@ class TransformerCTCEncoder(FairseqEncoder):
             assert tgt_dict is not None
             self.ctc.set_infer(ctc_infer, post_process, tgt_dict)

+    def ctc_valid(self, lprobs, targets, input_lengths,
+                  dictionary, lang="source"):
+        if hasattr(self, "ctc"):
+            return self.ctc.valid(lprobs, targets, input_lengths,
+                                  dictionary)
+        else:
+            logger.error("No ctc module in textual encoder")
+
     # TorchScript doesn't support super() method so that the scriptable Subclass
     # can't access the base class model in Torchscript.
     # Current workaround is to add a helper function with different name and
@@ -752,6 +794,7 @@ class TransformerCTCEncoder(FairseqEncoder):
         src_lengths: Optional[torch.Tensor] = None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
+        **kwargs
     ):
         """
         Args:
@@ -783,10 +826,10 @@ class TransformerCTCEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

-        ctc_padding_mask = encoder_padding_mask
-        if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
-            src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
-            ctc_padding_mask = encoder_padding_mask
+        # ctc_padding_mask = encoder_padding_mask
+        # if self.use_ctc or len(self.interleaved_ctc_layers) != 0:
+        #     src_tokens, encoder_padding_mask = self.upsampling(src_tokens, encoder_padding_mask)
+        #     ctc_padding_mask = encoder_padding_mask

         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
@@ -796,6 +839,8 @@ class TransformerCTCEncoder(FairseqEncoder):
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
+        # x, encoder_padding_mask = self.upsampling(x, encoder_padding_mask)
+        ctc_padding_mask = encoder_padding_mask

         encoder_states = []
@@ -824,6 +869,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             # CTC
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
+                x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
                 ctc_logit = self.ctc(x.clone(), ctc_padding_mask)

             # Interleaved CTC
@@ -833,18 +879,31 @@ class TransformerCTCEncoder(FairseqEncoder):
                     if p < self.interleaved_ctc_drop_prob:
                         break

+                x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
                 norm_x = self.layer_norm(x)
                 logit = self.ctc(norm_x, ctc_padding_mask)
                 interleaved_ctc_logits.append(logit)

                 prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
-                x, _ = self.sae([norm_x, prob])
-
-                # x = x.permute(1, 2, 0)
-                # x = self.pool(x)
-                # x = x.permute(2, 0, 1)
-                # encoder_padding_mask = org_encoder_padding_mask
+
+                # CTC alignment
+                oracle = None
+                oracle_mask = None
+                force_emit = None
+                if self.sae_ground_truth_ratio > 0:
+                    ctc_alignment_oracle = kwargs.get("ctc_alignment_oracle", None)
+                    if ctc_alignment_oracle is not None and ctc_alignment_oracle["source"] is not None:
+                        oracle, best_aligns_pad = ctc_alignment_oracle["source"]
+                        oracle_mask = (torch.rand(oracle.size(),
+                                                  device=oracle.device) < self.sae_ground_truth_ratio).bool()
+                        force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
+
+                x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
+
+                x = x.permute(1, 2, 0)
+                # x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
+                x = self.down_sample(x)
+                x = x.permute(2, 0, 1)

             if self.history is not None:
                 self.history.push(x)
......
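
A self-contained sketch of the nn.Upsample-based resampling introduced above: a T x B x C tensor is permuted to B x C x T (the N, C, W layout Upsample expects), linearly interpolated along time, and permuted back; downsampling uses scale_factor = 1/ratio. Toy shapes only.

    import torch
    import torch.nn as nn

    ratio = 2
    up = nn.Upsample(scale_factor=ratio, mode="linear", align_corners=True)
    down = nn.Upsample(scale_factor=1 / ratio, mode="linear", align_corners=True)

    x = torch.randn(6, 2, 8)                           # T x B x C
    up_x = up(x.permute(1, 2, 0)).permute(2, 0, 1)     # (T*ratio) x B x C
    back = down(up_x.permute(1, 2, 0)).permute(2, 0, 1)
    assert back.size(0) == x.size(0)                   # round trip restores T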
@@ -98,7 +98,7 @@ class Adapter(nn.Module):
             self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
             logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)

-        if "league" in self.adapter_type:
+        if self.cal_context:
             self.distribution_cutoff = strategy.get("distribution_cutoff", None)
             if self.distribution_cutoff is not None:
                 self.distribution_cutoff = int(self.distribution_cutoff)
@@ -107,17 +107,21 @@ class Adapter(nn.Module):
         self.drop_prob = strategy.get("drop_prob", 0)
         if self.drop_prob != 0:
             logger.info("Adapter drop probability: %f" % self.drop_prob)
+
+        self.ground_truth_ratio = strategy.get("gt_ratio", 0)

         self.out_norm = strategy.get("out_norm", False)
         if self.out_norm:
             self.out_ln = LayerNorm(dim)

-    def forward(self, x, padding=None):
+    def forward(self, x, padding=None, oracle=None, oracle_mask=None):

         representation, distribution = x
         distribution = distribution.type_as(representation)
         seq_len, bsz, dim = representation.size()
         org_distribution = distribution
-        distribution = distribution.contiguous().view(-1, distribution.size(-1))
+        vocab_size = distribution.size(-1)
+        distribution = distribution.contiguous().view(-1, vocab_size)

         linear_out = None
         soft_out = None
@@ -125,18 +129,32 @@ class Adapter(nn.Module):
             linear_out = self.linear_adapter(representation)
         if self.cal_context:
             if self.distribution_cutoff is not None:
-                cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
+                cutoff = min(int(self.distribution_cutoff), vocab_size - 1)

                 # threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
                 # distribution = torch.where(
                 #     org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
                 # )

-                threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
-                distribution = torch.where(
-                    threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
-                )
-                distribution = distribution.view(-1, distribution.size(-1))
-
-            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
+                # threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
+                # distribution = torch.where(
+                #     threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
+                # )
+                # distribution = distribution.view(-1, vocab_size)
+
+                distribution[:, 0] = 0
+                distribution = distribution / distribution.sum(-1, keepdim=True)
+
+            if self.ground_truth_ratio > 0 and oracle is not None:
+                oracle = oracle.unsqueeze(-1)
+                oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)). \
+                    to(distribution.dtype).transpose(0, 1)
+                oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
+                modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution
+                soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
+            else:
+                soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)

             if self.embed_norm:
                 soft_out = self.embed_ln(soft_out)
......
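
A minimal sketch of the oracle mixing performed in Adapter.forward above: where the boolean mask is set, the CTC posterior is replaced by a one-hot distribution over the oracle label (T x B x V layout; torch.where stands in for the mask-multiply-add in the commit).

    import torch

    seq_len, bsz, vocab = 4, 2, 6
    org_distribution = torch.softmax(torch.randn(seq_len, bsz, vocab), dim=-1)
    oracle = torch.randint(0, vocab, (bsz, seq_len))        # B x T
    oracle_mask = (torch.rand(bsz, seq_len) < 0.5).bool()   # B x T

    one_hot = (oracle.unsqueeze(-1) ==
               torch.arange(vocab).unsqueeze(0)).float().transpose(0, 1)  # T x B x V
    mask = oracle_mask.transpose(0, 1).unsqueeze(-1)                      # T x B x 1
    modify_dist = torch.where(mask, one_hot, org_distribution)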
@@ -37,13 +37,14 @@ class CTC(nn.Module):
         self.dictionary = dictionary
         self.infer_decoding = False
         self.post_process = "sentencepiece"
+        self.blank_idx = 0

     def set_infer(self, is_infer, text_post_process, dictionary):
         self.infer_decoding = is_infer
         self.post_process = text_post_process
         self.dictionary = dictionary

-    def forward(self, x, padding=None):
+    def forward(self, x, padding=None, tag=None):
         if self.need_layernorm:
             x = self.LayerNorm(x)
@@ -52,7 +53,7 @@ class CTC(nn.Module):
         if not self.training and self.infer_decoding:
             assert self.dictionary is not None
             input_lengths = (~padding).sum(-1)
-            self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths)
+            self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)

         return x
@@ -65,7 +66,7 @@ class CTC(nn.Module):
     def argmax(self, x):
         return torch.argmax(self.ctc_projection(x), dim=-1)

-    def infer(self, logits_or_probs, lengths):
+    def infer(self, logits_or_probs, lengths, tag=None):
         for lp, inp_l in zip(
                 logits_or_probs,
                 lengths,
@@ -78,41 +79,47 @@ class CTC(nn.Module):
             pred_units = self.dictionary.string(pred_units_arr)
             pred_words_raw = post_process(pred_units, self.post_process).split()

-            logger.info("\nCTC prediction: %s" % " ".join(pred_words_raw))
+            if tag is not None:
+                logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
+            else:
+                logger.info("CTC prediction: %s" % (" ".join(pred_words_raw)))

-    def valid(self, logits_or_probs, target, lengths):
+    def valid(self, logits_or_probs, targets, input_lengths, dictionary):
         c_err = 0
         c_len = 0
         w_errs = 0
         w_len = 0
+        wv_errs = 0

-        for lp, t, inp_l in zip(
-                logits_or_probs,
-                target,
-                lengths,
-        ):
-            lp = lp[:inp_l].unsqueeze(0)
-
-            p = (t != self.task.target_dictionary.pad()) & (
-                t != self.task.target_dictionary.eos()
-            )
-            targ = t[p]
-            targ_units = self.task.target_dictionary.string(targ)
-            targ_units_arr = targ.tolist()
-
-            toks = lp.argmax(dim=-1).unique_consecutive()
-            pred_units_arr = toks[toks != self.blank_idx].tolist()
-
-            c_err += editdistance.eval(pred_units_arr, targ_units_arr)
-            c_len += len(targ_units_arr)
-
-            targ_words = post_process(targ_units, self.post_process).split()
-
-            pred_units = self.task.target_dictionary.string(pred_units_arr)
-            pred_words_raw = post_process(pred_units, self.post_process).split()
-
-            dist = editdistance.eval(pred_words_raw, targ_words)
-            w_errs += dist
-            w_len += len(targ_words)
+        with torch.no_grad():
+            for lp, t, inp_l in zip(
+                    logits_or_probs,
+                    targets,
+                    input_lengths,
+            ):
+                lp = lp[:inp_l].unsqueeze(0)
+
+                p = (t != dictionary.pad()) & (t != dictionary.eos())
+                targ = t[p]
+                targ_units = dictionary.string(targ)
+                targ_units_arr = targ.tolist()
+
+                toks = lp.argmax(dim=-1).unique_consecutive()
+                pred_units_arr = toks[toks != self.blank_idx].tolist()
+
+                c_err += editdistance.eval(pred_units_arr, targ_units_arr)
+                c_len += len(targ_units_arr)
+
+                targ_words = post_process(targ_units, self.post_process).split()
+
+                pred_units = dictionary.string(pred_units_arr)
+                pred_words_raw = post_process(pred_units, self.post_process).split()
+                dist = editdistance.eval(pred_words_raw, targ_words)
+                w_errs += dist
+                wv_errs += dist
+
+                w_len += len(targ_words)
+
+        return c_err, c_len, w_errs, w_len, wv_errs
\ No newline at end of file
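
A toy run of the error counting done in the refactored valid(): character-level edit distance between greedy CTC predictions and the reference (values are illustrative).

    import editdistance

    pred_units_arr = [5, 9, 9, 3]  # collapsed greedy prediction, blanks removed
    targ_units_arr = [5, 9, 3]
    c_err = editdistance.eval(pred_units_arr, targ_units_arr)  # 1 (one deletion)
    c_len = len(targ_units_arr)
    print("CER: %.2f%%" % (100.0 * c_err / c_len))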
from .imputer import imputer_loss, ImputerLoss, best_alignment, ctc_decode
#include <torch/extension.h>
#include <tuple>
std::tuple<torch::Tensor, torch::Tensor>
imputer_loss_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, at::IntArrayRef input_lengths,
at::IntArrayRef target_lengths, int64_t BLANK,
bool zero_infinity);
torch::Tensor imputer_loss_backward_op(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment_op(const torch::Tensor &log_probs, const torch::Tensor &targets,
at::IntArrayRef input_lengths, at::IntArrayRef target_lengths,
int64_t BLANK, bool zero_infinity);
std::tuple<torch::Tensor, torch::Tensor> imputer_loss(
const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &force_emits, const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
imputer_loss_op(log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl,
BLANK, zero_infinity);
return res;
}
torch::Tensor imputer_loss_backward(
const torch::Tensor &grad, const torch::Tensor &log_probs,
const torch::Tensor &targets, const torch::Tensor &force_emits,
const torch::Tensor &input_lengths, const torch::Tensor &target_lengths,
const torch::Tensor &neg_log_likelihood, const torch::Tensor &log_alpha,
int64_t BLANK, bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
torch::Tensor res;
res = imputer_loss_backward_op(
grad, log_probs, targets.to(log_probs.device(), at::kLong),
force_emits.to(log_probs.device(), at::kLong), il, tl, neg_log_likelihood,
log_alpha, BLANK, zero_infinity);
return res;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
best_alignment(const torch::Tensor &log_probs, const torch::Tensor &targets,
const torch::Tensor &input_lengths,
const torch::Tensor &target_lengths, int64_t BLANK,
bool zero_infinity) {
torch::Tensor ilc =
input_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
torch::Tensor tlc =
target_lengths.to(at::Device(at::kCPU), at::kLong).contiguous();
at::IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
at::IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
auto res =
best_alignment_op(log_probs, targets.to(log_probs.device(), at::kLong),
il, tl, BLANK, zero_infinity);
return res;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("imputer_loss", &imputer_loss, "calculate imputer loss");
m.def("imputer_loss_backward", &imputer_loss_backward,
"calculate imputer loss gradient");
m.def("best_alignment", &best_alignment, "get best alignments for ctc");
}
\ No newline at end of file
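
A hedged sketch of how the bindings above might be called from Python once the extension is compiled; the import path is illustrative, and the argument order follows the pybind signature rather than a documented API.

    import torch
    from imputer import best_alignment  # illustrative import path

    T, B, V, L = 10, 2, 6, 4
    log_probs = torch.randn(T, B, V).log_softmax(-1)     # T x B x V
    targets = torch.randint(1, V, (B, L))                # blank index 0 excluded
    input_lengths = torch.full((B,), T, dtype=torch.long)
    target_lengths = torch.full((B,), L, dtype=torch.long)
    res = best_alignment(log_probs, targets, input_lengths, target_lengths, 0, True)
    # res is a 3-tuple of tensors, per best_alignment_op's declared return type.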
@@ -107,7 +107,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
     )
     for model in models:
         if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
-            model.encoder.set_ctc_infer(cfg.generation.ctc_infer, cfg.common_eval.post_process,
+            model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
                                         src_dict, tgt_dict)

     # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
......
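
For reference, fairseq's post_process with the now-hardcoded "sentencepiece" symbol joins subword units back into words (toy input):

    from fairseq.data.data_utils import post_process

    units = "\u2581he llo \u2581wo rld"  # toy sentencepiece units
    print(post_process(units, "sentencepiece"))  # -> "hello world"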