Commit a201a883 by xuchen

Try more settings of adapter

parent 5d84c743
......@@ -56,6 +56,11 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of interleaved CTC loss"},
)
aligned_target_ctc: bool = field(
default=False,
metadata={"help": "calculate target ctc by aligned text"},
)
target_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of CTC loss for target sentence"},
......@@ -157,6 +162,7 @@ class CtcCriterion(FairseqCriterion):
self.cal_all_ctc = cfg.cal_all_ctc
self.ctc_weight = ctc_weight
self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.aligned_target_ctc = cfg.aligned_target_ctc
self.target_ctc_weight = cfg.target_ctc_weight
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
......@@ -314,6 +320,12 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_num += 1
return ctc_self_distill_num, ctc_self_distill_loss
def get_target_text(self, sample):
if self.aligned_target_ctc and "aligned_target" in sample:
return sample["aligned_target"]["tokens"]
else:
return sample["target"]
def compute_ctc_loss(self, model, sample, net_output, logging_output):
if "transcript" in sample:
tokens = sample["transcript"]["tokens"]
......@@ -405,7 +417,7 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_loss = 0
target_interleaved_ctc_num = 0
if self.use_target_ctc:
target_tokens = sample["target"]
target_tokens = self.get_target_text(sample)
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_no_padding_mask = ~target_pad_mask
......@@ -557,7 +569,7 @@ class CtcCriterion(FairseqCriterion):
if target_lprobs is not None:
target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu()
target_tokens = sample["target"]
target_tokens = self.get_target_text(sample)
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
target_tokens = target_tokens[idx]
......
......@@ -283,6 +283,8 @@ def base_architecture(args):
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
......
import logging
import math
import os
import torch
import torch.nn as nn
......@@ -124,6 +125,11 @@ class S2TSATEModel(S2TTransformerModel):
)
# target CTC
parser.add_argument(
"--target-sae-adapter",
type=str,
help="adapter type of target sae ",
)
parser.add_argument(
"--target-ctc-layer",
default=0,
type=int,
......@@ -300,7 +306,6 @@ class TextualEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
......@@ -330,11 +335,14 @@ class TextualEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae = Adapter(embed_dim, args.sae_adapter,
self.sae = Adapter(embed_dim, args.target_sae_adapter,
len(dictionary),
strategy=strategy)
if args.share_target_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
......@@ -372,7 +380,6 @@ class TextualEncoder(FairseqEncoder):
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment
oracle = None
......@@ -386,7 +393,8 @@ class TextualEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
if self.sae.adapter_type != "none":
x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
if history is not None:
history.push(x)
......@@ -398,7 +406,7 @@ class TextualEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output", is_top=True)
return x, target_ctc_logit, target_interleaved_ctc_logits
......@@ -460,13 +468,19 @@ class S2TSATEEncoder(FairseqEncoder):
else:
self.history = None
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
if hasattr(self.acoustic_encoder, "ctc"):
assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
logger.info("Acoustic Encoder CTC Inference")
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict,
path=path + ".src_ctc" if path is not None else None)
# path=os.path.join(path, "src_ctc") if path is not None else None)
if hasattr(self.textual_encoder, "ctc"):
assert tgt_dict is not None
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
logger.info("Textual Encoder CTC Inference")
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict,
path=path + ".tgt_ctc" if path is not None else None)
# path=os.path.join(path, "tgt_ctc") if path is not None else None)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
if lang == "source":
......@@ -500,11 +514,11 @@ class S2TSATEEncoder(FairseqEncoder):
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
# ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
else:
ctc_logit = None
ctc_prob = None
x = (encoder_out, ctc_prob)
# ctc_prob = None
x = (encoder_out, ctc_logit)
x, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
......@@ -677,10 +691,13 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
......
......@@ -415,7 +415,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="the position of interleaved ctc layers, separated by comma ",
)
parser.add_argument(
"--interleaved-ctc-temperature",
"--sae-ctc-temperature",
default=1,
type=float,
help="temperature of the CTC probability in sae",
......@@ -447,6 +447,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--sae-gumbel",
action="store_true",
help="use gumbel softmax in sae",
)
parser.add_argument(
"--sae-distribution-hard",
action="store_true",
help="use hard distribution in sae",
)
parser.add_argument(
"--sae-ground-truth-ratio",
default=0,
type=float,
......@@ -643,7 +653,8 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
self.history = None
self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
# self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
......@@ -659,11 +670,12 @@ class S2TTransformerEncoder(FairseqEncoder):
embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
self.interleaved_ctc_layers = []
self.use_inter_ctc = False
if args.interleaved_ctc_layers is not None:
self.use_inter_ctc = True
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
for layer_idx in interleaved_ctc_layers:
layer_idx = int(layer_idx)
......@@ -687,7 +699,10 @@ class S2TTransformerEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"gt_ratio": self.sae_ground_truth_ratio,
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
......@@ -720,10 +735,18 @@ class S2TTransformerEncoder(FairseqEncoder):
# debug the variance
self.debug_var = False
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
self.update_num = 0
self.curr_temp = 0
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
self.update_num = num_updates
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
if hasattr(self, "ctc"):
assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict)
self.ctc.set_infer(ctc_infer, post_process, src_dict,
path=path + ".ctc" if path is not None else None)
def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"):
......@@ -906,13 +929,8 @@ class S2TTransformerEncoder(FairseqEncoder):
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment
oracle = None
oracle_mask = None
......@@ -925,7 +943,8 @@ class S2TTransformerEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
if self.sae.adapter_type != "none":
x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
self.show_debug(x, "x after sae")
# gather cosine similarity
......@@ -945,7 +964,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output", is_top=True)
self.show_debug(x, "x after ctc")
return {
......@@ -1145,6 +1164,8 @@ def base_architecture(args):
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
......
......@@ -319,7 +319,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="upsampling ratio of the representation for CTC calculation",
)
parser.add_argument(
"--interleaved-ctc-temperature",
"--sae-ctc-temperature",
default=1,
type=float,
help="temperature of the CTC probability in sae",
......@@ -351,6 +351,16 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae",
)
parser.add_argument(
"--sae-gumbel",
action="store_true",
help="use gumbel softmax in sae",
)
parser.add_argument(
"--sae-distribution-hard",
action="store_true",
help="use hard distribution in sae",
)
parser.add_argument(
"--share-ctc-and-sae",
action="store_true",
help="share the weight of ctc and sae",
......@@ -629,7 +639,6 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_upsampling_ratio = int(args.interleaved_ctc_upsampling_ratio)
self.interleaved_ctc_layers = []
......@@ -661,7 +670,10 @@ class TransformerCTCEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"drop_prob": getattr(args, "sae_drop_prob", 0),
"gt_ratio": self.sae_ground_truth_ratio,
}
......@@ -743,9 +755,6 @@ class TransformerCTCEncoder(FairseqEncoder):
return x
if len(x.size()) == 3:
# bsz, seq_len, dim = x.size()
# up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
seq_len, bsz, dim = x.size()
x = x.permute(1, 2, 0)
up_x = self.un_sample(x)
......@@ -755,20 +764,25 @@ class TransformerCTCEncoder(FairseqEncoder):
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)
# output_length = int(seq_len * ratio * 2/3)
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
# select_matrix[:, 1::ratio] = 1
# mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
# mask = mask.sort(dim=-1)[0]
#
# if len(x.size()) == 3:
# out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
# else:
# out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
# out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
out_x = up_x
out_padding = up_padding
perturb = False
if perturb:
output_length = int(seq_len * ratio * 2/3)
select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
select_matrix[:, 1::ratio] = 1
mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
mask = mask.sort(dim=-1)[0]
if len(x.size()) == 3:
up_x = up_x.transpose(0, 1)
out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
out_x = out_x.transpose(0, 1)
else:
out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
else:
out_x = up_x.contiguous()
out_padding = up_padding.contiguous()
return out_x, out_padding
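The previously commented-out random subsampling is now kept behind a perturb flag (disabled by default). A self-contained sketch of what that branch computes for the 2D case, assuming an already upsampled tensor:
# Illustration only, not part of the diff: random ~2/3 subsampling after upsampling.
import torch

def random_subsample(up_x, up_padding, seq_len, ratio):
    """up_x, up_padding: (bsz, seq_len * ratio); keep ~2/3 of the frames,
    always retaining index 1 of every ratio-sized block."""
    bsz = up_x.size(0)
    output_length = int(seq_len * ratio * 2 / 3)
    select_matrix = torch.rand(bsz, ratio * seq_len, device=up_x.device)
    select_matrix[:, 1::ratio] = 1                        # force-keep one frame per block
    mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
    mask = mask.sort(dim=-1)[0]                           # restore temporal order
    out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
    out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
    return out_x, out_padding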
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
......@@ -869,8 +883,8 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), ctc_padding_mask)
up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(up_x, ctc_padding_mask)
# Interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
......@@ -879,12 +893,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if p < self.interleaved_ctc_drop_prob:
break
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(x)
up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(up_x)
logit = self.ctc(norm_x, ctc_padding_mask)
interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment
oracle = None
......@@ -898,7 +910,7 @@ class TransformerCTCEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
x, _ = self.sae([norm_x, logit], ctc_padding_mask, oracle, oracle_mask)
x = x.permute(1, 2, 0)
# x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
......@@ -915,7 +927,8 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, ctc_padding_mask)
up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(up_x, ctc_padding_mask)
# The Pytorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead.
......@@ -1592,6 +1605,8 @@ def base_architecture(args):
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
@register_model_architecture("transformer_ctc", "transformer_ctc_relative")
......
......@@ -100,34 +100,45 @@ class Adapter(nn.Module):
if self.cal_context:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
self.distribution_temperature = strategy.get("ctc_temperature", 1.0)
self.gumbel = strategy.get("gumbel", False)
self.distribution_hard = strategy.get("distribution_hard", False)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.drop_prob = strategy.get("drop_prob", 0)
if self.distribution_cutoff is not None:
self.distribution_cutoff = int(self.distribution_cutoff)
logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
self.drop_prob = strategy.get("drop_prob", 0)
if self.distribution_temperature != 1.0:
logger.info("Temperature: %f" % self.distribution_temperature)
if self.gumbel:
logger.info("Gumbel softmax.")
if self.distribution_hard:
logger.info("Hard distribution.")
if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob)
logger.info("Drop probability: %f" % self.drop_prob)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.out_norm = strategy.get("out_norm", False)
if self.out_norm:
self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None, oracle=None, oracle_mask=None):
representation, distribution = x
distribution = distribution.type_as(representation)
representation, logit = x
seq_len, bsz, dim = representation.size()
org_distribution = distribution
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
linear_out = None
soft_out = None
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
org_distribution = distribution
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
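The adapter now receives raw logits and normalizes them itself; with --sae-gumbel it draws (optionally hard) Gumbel-softmax samples during training. A minimal sketch of the two modes (shapes and values are illustrative):
# Illustration only: the two distribution modes added to Adapter.forward.
import torch
import torch.nn.functional as F

logit = torch.randn(10, 2, 100)        # (seq_len, bsz, vocab), as produced by the CTC head
tau, hard, gumbel, training = 1.0, True, True, True

if training and gumbel:
    # hard=True yields one-hot samples while gradients flow through the soft sample
    distribution = F.gumbel_softmax(logit, tau=tau, hard=hard)
else:
    distribution = F.softmax(logit / tau, dim=-1)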
......@@ -184,11 +195,15 @@ class Adapter(nn.Module):
out = representation
elif self.adapter_type == "shrink":
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
lengths = (~padding).long().sum(-1)
with torch.no_grad():
batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
prob_ctc = distribution.transpose(0, 1) # T x B x D -> B x T x D
for b in range(prob_ctc.shape[0]):
predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
......
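For the "shrink" adapter the distribution is likewise recomputed from the logits before collapsing. A small sketch of the CTC-style collapsing that follows (run-length merging of the argmax path; shapes are illustrative):
# Illustration only: collapse consecutive argmax repeats into (token, run_length) pairs.
from itertools import groupby
import torch

prob_ctc = torch.softmax(torch.randn(2, 7, 5), dim=-1)   # B x T x V
lengths = torch.tensor([7, 5])

batch_predicted = []
for b in range(prob_ctc.shape[0]):
    predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
    batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
# e.g. [[(3, 2), (1, 1), (4, 4)], [(0, 5)]]: token id and how many frames it spans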
......@@ -39,18 +39,23 @@ class CTC(nn.Module):
self.post_process = "sentencepiece"
self.blank_idx = 0
def set_infer(self, is_infer, text_post_process, dictionary):
def set_infer(self, is_infer, text_post_process, dictionary, path):
self.infer_decoding = is_infer
self.post_process = text_post_process
self.dictionary = dictionary
self.path = path
if self.path is not None:
self.save_stream = open(self.path, "a")
else:
self.save_stream = None
def forward(self, x, padding=None, tag=None):
def forward(self, x, padding=None, tag=None, is_top=False):
if self.need_layernorm:
x = self.LayerNorm(x)
x = self.ctc_projection(self.ctc_dropout_module(x))
if not self.training and self.infer_decoding:
if not self.training and self.infer_decoding and is_top:
assert self.dictionary is not None
input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)
......@@ -79,6 +84,9 @@ class CTC(nn.Module):
pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if self.save_stream is not None:
self.save_stream.write(" ".join(pred_words_raw) + "\n")
if tag is not None:
logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
else:
......
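With the is_top flag, greedy decoding and logging now happen only for the top-layer CTC call, and, when a path is passed to set_infer, each hypothesis is appended to that file. A minimal sketch of the new dump behaviour (class name and output path are hypothetical stand-ins):
# Illustration only: append-mode dump of CTC hypotheses, mirroring CTC.set_infer/infer.
class PredictionSaver:
    def __init__(self, path=None):
        # opened in append mode so repeated calls keep accumulating hypotheses
        self.save_stream = open(path, "a") if path is not None else None

    def save(self, pred_words_raw):
        if self.save_stream is not None:
            self.save_stream.write(" ".join(pred_words_raw) + "\n")
            self.save_stream.flush()

saver = PredictionSaver("generate.txt.src_ctc")   # hypothetical output path
saver.save(["ctc", "hypothesis", "words"])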
......@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict)
src_dict, tgt_dict, translation_path) # os.path.dirname(translation_path))
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
......