Commit a201a883 by xuchen

Try more settings of adapter

parent 5d84c743
...@@ -56,6 +56,11 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -56,6 +56,11 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "weight of interleaved CTC loss"}, metadata={"help": "weight of interleaved CTC loss"},
) )
aligned_target_ctc: bool = field(
default=False,
metadata={"help": "calculate target ctc by aligned text"},
)
target_ctc_weight: float = field( target_ctc_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of CTC loss for target sentence"}, metadata={"help": "weight of CTC loss for target sentence"},
...@@ -157,6 +162,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -157,6 +162,7 @@ class CtcCriterion(FairseqCriterion):
self.cal_all_ctc = cfg.cal_all_ctc self.cal_all_ctc = cfg.cal_all_ctc
self.ctc_weight = ctc_weight self.ctc_weight = ctc_weight
self.interleaved_ctc_weight = cfg.interleaved_ctc_weight self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
self.aligned_target_ctc = cfg.aligned_target_ctc
self.target_ctc_weight = cfg.target_ctc_weight self.target_ctc_weight = cfg.target_ctc_weight
self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight
...@@ -314,6 +320,12 @@ class CtcCriterion(FairseqCriterion): ...@@ -314,6 +320,12 @@ class CtcCriterion(FairseqCriterion):
ctc_self_distill_num += 1 ctc_self_distill_num += 1
return ctc_self_distill_num, ctc_self_distill_loss return ctc_self_distill_num, ctc_self_distill_loss
def get_target_text(self, sample):
if self.aligned_target_ctc and "aligned_target" in sample:
return sample["aligned_target"]["tokens"]
else:
return sample["target"]
def compute_ctc_loss(self, model, sample, net_output, logging_output): def compute_ctc_loss(self, model, sample, net_output, logging_output):
if "transcript" in sample: if "transcript" in sample:
tokens = sample["transcript"]["tokens"] tokens = sample["transcript"]["tokens"]
...@@ -405,7 +417,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -405,7 +417,7 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_loss = 0 target_interleaved_ctc_loss = 0
target_interleaved_ctc_num = 0 target_interleaved_ctc_num = 0
if self.use_target_ctc: if self.use_target_ctc:
target_tokens = sample["target"] target_tokens = self.get_target_text(sample)
target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx) target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
target_no_padding_mask = ~target_pad_mask target_no_padding_mask = ~target_pad_mask
...@@ -557,7 +569,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -557,7 +569,7 @@ class CtcCriterion(FairseqCriterion):
if target_lprobs is not None: if target_lprobs is not None:
target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu() target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu()
target_tokens = sample["target"] target_tokens = self.get_target_text(sample)
if mixup: if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2 idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
target_tokens = target_tokens[idx] target_tokens = target_tokens[idx]
......
...@@ -283,6 +283,8 @@ def base_architecture(args): ...@@ -283,6 +283,8 @@ def base_architecture(args):
args.sae_out_norm = getattr(args, "sae_out_norm", False) args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0) args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None) args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup # mixup
args.inter_mixup = getattr(args, "inter_mixup", False) args.inter_mixup = getattr(args, "inter_mixup", False)
......
import logging import logging
import math import math
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -124,6 +125,11 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -124,6 +125,11 @@ class S2TSATEModel(S2TTransformerModel):
) )
# target CTC # target CTC
parser.add_argument( parser.add_argument(
"--target-sae-adapter",
type=str,
help="adapter type of target sae ",
)
parser.add_argument(
"--target-ctc-layer", "--target-ctc-layer",
default=0, default=0,
type=int, type=int,
...@@ -300,7 +306,6 @@ class TextualEncoder(FairseqEncoder): ...@@ -300,7 +306,6 @@ class TextualEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size(): self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_layers = [] self.interleaved_ctc_layers = []
self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None) self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)
...@@ -330,11 +335,14 @@ class TextualEncoder(FairseqEncoder): ...@@ -330,11 +335,14 @@ class TextualEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False), "embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False), "out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None), "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None), "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"drop_prob": getattr(args, "sae_drop_prob", 0), "drop_prob": getattr(args, "sae_drop_prob", 0),
} }
self.sae = Adapter(embed_dim, args.sae_adapter, self.sae = Adapter(embed_dim, args.target_sae_adapter,
len(dictionary), len(dictionary),
strategy=strategy) strategy=strategy)
if args.share_target_sae_and_ctc and hasattr(self.sae, "embed_adapter"): if args.share_target_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
...@@ -372,7 +380,6 @@ class TextualEncoder(FairseqEncoder): ...@@ -372,7 +380,6 @@ class TextualEncoder(FairseqEncoder):
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx) logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
target_interleaved_ctc_logits.append(logit) target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment # CTC alignment
oracle = None oracle = None
...@@ -386,7 +393,8 @@ class TextualEncoder(FairseqEncoder): ...@@ -386,7 +393,8 @@ class TextualEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool() device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1) force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask) if self.sae.adapter_type != "none":
x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
if history is not None: if history is not None:
history.push(x) history.push(x)
...@@ -398,7 +406,7 @@ class TextualEncoder(FairseqEncoder): ...@@ -398,7 +406,7 @@ class TextualEncoder(FairseqEncoder):
x = self.layer_norm(x) x = self.layer_norm(x)
if self.use_ctc and target_ctc_logit is None: if self.use_ctc and target_ctc_logit is None:
target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output") target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output", is_top=True)
return x, target_ctc_logit, target_interleaved_ctc_logits return x, target_ctc_logit, target_interleaved_ctc_logits
...@@ -460,13 +468,19 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -460,13 +468,19 @@ class S2TSATEEncoder(FairseqEncoder):
else: else:
self.history = None self.history = None
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None): def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
if hasattr(self.acoustic_encoder, "ctc"): if hasattr(self.acoustic_encoder, "ctc"):
assert src_dict is not None assert src_dict is not None
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict) logger.info("Acoustic Encoder CTC Inference")
self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict,
path=path + ".src_ctc" if path is not None else None)
# path=os.path.join(path, "src_ctc") if path is not None else None)
if hasattr(self.textual_encoder, "ctc"): if hasattr(self.textual_encoder, "ctc"):
assert tgt_dict is not None assert tgt_dict is not None
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict) logger.info("Textual Encoder CTC Inference")
self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict,
path=path + ".tgt_ctc" if path is not None else None)
# path=os.path.join(path, "tgt_ctc") if path is not None else None)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"): def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
if lang == "source": if lang == "source":
...@@ -500,11 +514,11 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -500,11 +514,11 @@ class S2TSATEEncoder(FairseqEncoder):
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0: if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0] ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32) # ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
else: else:
ctc_logit = None ctc_logit = None
ctc_prob = None # ctc_prob = None
x = (encoder_out, ctc_prob) x = (encoder_out, ctc_logit)
x, encoder_padding_mask = self.adapter(x, encoder_padding_mask) x, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
...@@ -677,10 +691,13 @@ def base_architecture(args): ...@@ -677,10 +691,13 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae) # Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none") args.sae_adapter = getattr(args, "sae_adapter", "none")
args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False) args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False) args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0) args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None) args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup # mixup
args.inter_mixup = getattr(args, "inter_mixup", False) args.inter_mixup = getattr(args, "inter_mixup", False)
......
...@@ -415,7 +415,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -415,7 +415,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="the position of interleaved ctc layers, separated by comma ", help="the position of interleaved ctc layers, separated by comma ",
) )
parser.add_argument( parser.add_argument(
"--interleaved-ctc-temperature", "--sae-ctc-temperature",
default=1, default=1,
type=float, type=float,
help="temperature of the CTC probability in sae", help="temperature of the CTC probability in sae",
...@@ -447,6 +447,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -447,6 +447,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae", help="cutoff of the distribution in sae",
) )
parser.add_argument( parser.add_argument(
"--sae-gumbel",
action="store_true",
help="use gumbel softmax in sae",
)
parser.add_argument(
"--sae-distribution-hard",
action="store_true",
help="use hard distribution in sae",
)
parser.add_argument(
"--sae-ground-truth-ratio", "--sae-ground-truth-ratio",
default=0, default=0,
type=float, type=float,
...@@ -643,7 +653,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -643,7 +653,8 @@ class S2TTransformerEncoder(FairseqEncoder):
else: else:
self.history = None self.history = None
self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0 # self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc: if self.use_ctc:
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
...@@ -659,11 +670,12 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -659,11 +670,12 @@ class S2TTransformerEncoder(FairseqEncoder):
embed_tokens is not None and dim == embed_tokens.embedding_dim: embed_tokens is not None and dim == embed_tokens.embedding_dim:
self.ctc.ctc_projection.weight = embed_tokens.weight self.ctc.ctc_projection.weight = embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0) self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
self.interleaved_ctc_layers = [] self.interleaved_ctc_layers = []
self.use_inter_ctc = False
if args.interleaved_ctc_layers is not None: if args.interleaved_ctc_layers is not None:
self.use_inter_ctc = True
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",") interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
for layer_idx in interleaved_ctc_layers: for layer_idx in interleaved_ctc_layers:
layer_idx = int(layer_idx) layer_idx = int(layer_idx)
...@@ -687,7 +699,10 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -687,7 +699,10 @@ class S2TTransformerEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False), "embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False), "out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None), "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None), "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"gt_ratio": self.sae_ground_truth_ratio, "gt_ratio": self.sae_ground_truth_ratio,
"drop_prob": getattr(args, "sae_drop_prob", 0), "drop_prob": getattr(args, "sae_drop_prob", 0),
} }
...@@ -720,10 +735,18 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -720,10 +735,18 @@ class S2TTransformerEncoder(FairseqEncoder):
# debug the variance # debug the variance
self.debug_var = False self.debug_var = False
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None): self.update_num = 0
self.curr_temp = 0
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
self.update_num = num_updates
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
if hasattr(self, "ctc"): if hasattr(self, "ctc"):
assert src_dict is not None assert src_dict is not None
self.ctc.set_infer(ctc_infer, post_process, src_dict) self.ctc.set_infer(ctc_infer, post_process, src_dict,
path=path + ".ctc" if path is not None else None)
def ctc_valid(self, lprobs, targets, input_lengths, def ctc_valid(self, lprobs, targets, input_lengths,
dictionary, lang="source"): dictionary, lang="source"):
...@@ -906,13 +929,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -906,13 +929,8 @@ class S2TTransformerEncoder(FairseqEncoder):
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx) logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
interleaved_ctc_logits.append(logit) interleaved_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment # CTC alignment
oracle = None oracle = None
oracle_mask = None oracle_mask = None
...@@ -925,7 +943,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -925,7 +943,8 @@ class S2TTransformerEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool() device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1) force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask) if self.sae.adapter_type != "none":
x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
self.show_debug(x, "x after sae") self.show_debug(x, "x after sae")
# gather cosine similarity # gather cosine similarity
...@@ -945,7 +964,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -945,7 +964,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.show_debug(x, "x after encoding layer norm") self.show_debug(x, "x after encoding layer norm")
if self.use_ctc and ctc_logit is None: if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, encoder_padding_mask, "Source output") ctc_logit = self.ctc(x, encoder_padding_mask, "Source output", is_top=True)
self.show_debug(x, "x after ctc") self.show_debug(x, "x after ctc")
return { return {
...@@ -1145,6 +1164,8 @@ def base_architecture(args): ...@@ -1145,6 +1164,8 @@ def base_architecture(args):
args.sae_out_norm = getattr(args, "sae_out_norm", False) args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0) args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None) args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup # mixup
args.inter_mixup = getattr(args, "inter_mixup", False) args.inter_mixup = getattr(args, "inter_mixup", False)
......
...@@ -319,7 +319,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel): ...@@ -319,7 +319,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="upsampling ratio of the representation for CTC calculation", help="upsampling ratio of the representation for CTC calculation",
) )
parser.add_argument( parser.add_argument(
"--interleaved-ctc-temperature", "--sae-ctc-temperature",
default=1, default=1,
type=float, type=float,
help="temperature of the CTC probability in sae", help="temperature of the CTC probability in sae",
...@@ -351,6 +351,16 @@ class TransformerCTCModel(FairseqEncoderDecoderModel): ...@@ -351,6 +351,16 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
help="cutoff of the distribution in sae", help="cutoff of the distribution in sae",
) )
parser.add_argument( parser.add_argument(
"--sae-gumbel",
action="store_true",
help="use gumbel softmax in sae",
)
parser.add_argument(
"--sae-distribution-hard",
action="store_true",
help="use hard distribution in sae",
)
parser.add_argument(
"--share-ctc-and-sae", "--share-ctc-and-sae",
action="store_true", action="store_true",
help="share the weight of ctc and sae", help="share the weight of ctc and sae",
...@@ -629,7 +639,6 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -629,7 +639,6 @@ class TransformerCTCEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = decoder_embed_tokens.weight self.ctc.ctc_projection.weight = decoder_embed_tokens.weight
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_upsampling_ratio = int(args.interleaved_ctc_upsampling_ratio) self.interleaved_ctc_upsampling_ratio = int(args.interleaved_ctc_upsampling_ratio)
self.interleaved_ctc_layers = [] self.interleaved_ctc_layers = []
...@@ -661,7 +670,10 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -661,7 +670,10 @@ class TransformerCTCEncoder(FairseqEncoder):
"embed_norm": getattr(args, "sae_embed_norm", False), "embed_norm": getattr(args, "sae_embed_norm", False),
"out_norm": getattr(args, "sae_out_norm", False), "out_norm": getattr(args, "sae_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None), "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
"distribution_cutoff": getattr(args, "sae_distribution_cutoff", None), "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
"gumbel": getattr(args, "sae_gumbel", False),
"distribution_hard": getattr(args, "sae_distribution_hard", None),
"drop_prob": getattr(args, "sae_drop_prob", 0), "drop_prob": getattr(args, "sae_drop_prob", 0),
"gt_ratio": self.sae_ground_truth_ratio, "gt_ratio": self.sae_ground_truth_ratio,
} }
...@@ -743,9 +755,6 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -743,9 +755,6 @@ class TransformerCTCEncoder(FairseqEncoder):
return x return x
if len(x.size()) == 3: if len(x.size()) == 3:
# bsz, seq_len, dim = x.size()
# up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
seq_len, bsz, dim = x.size() seq_len, bsz, dim = x.size()
x = x.permute(1, 2, 0) x = x.permute(1, 2, 0)
up_x = self.un_sample(x) up_x = self.un_sample(x)
...@@ -755,20 +764,25 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -755,20 +764,25 @@ class TransformerCTCEncoder(FairseqEncoder):
up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1) up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)
up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1) up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)
# output_length = int(seq_len * ratio * 2/3) perturb = False
# select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device) if perturb:
# select_matrix[:, 1::ratio] = 1 output_length = int(seq_len * ratio * 2/3)
# mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length] select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
# mask = mask.sort(dim=-1)[0] select_matrix[:, 1::ratio] = 1
# mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
# if len(x.size()) == 3: mask = mask.sort(dim=-1)[0]
# out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
# else: if len(x.size()) == 3:
# out_x = torch.gather(up_x, dim=1, index=mask).contiguous() up_x = up_x.transpose(0, 1)
# out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous() out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
out_x = out_x.transpose(0, 1)
out_x = up_x else:
out_padding = up_padding out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
else:
out_x = up_x.contiguous()
out_padding = up_padding.contiguous()
return out_x, out_padding return out_x, out_padding
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None): def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
...@@ -869,8 +883,8 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -869,8 +883,8 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC # CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx: if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask) up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(x.clone(), ctc_padding_mask) ctc_logit = self.ctc(up_x, ctc_padding_mask)
# Interleaved CTC # Interleaved CTC
if layer_idx in self.interleaved_ctc_layers: if layer_idx in self.interleaved_ctc_layers:
...@@ -879,12 +893,10 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -879,12 +893,10 @@ class TransformerCTCEncoder(FairseqEncoder):
if p < self.interleaved_ctc_drop_prob: if p < self.interleaved_ctc_drop_prob:
break break
x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask) up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
norm_x = self.layer_norm(x) norm_x = self.layer_norm(up_x)
logit = self.ctc(norm_x, ctc_padding_mask) logit = self.ctc(norm_x, ctc_padding_mask)
interleaved_ctc_logits.append(logit) interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
# CTC alignment # CTC alignment
oracle = None oracle = None
...@@ -898,7 +910,7 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -898,7 +910,7 @@ class TransformerCTCEncoder(FairseqEncoder):
device=oracle.device) < self.sae_ground_truth_ratio).bool() device=oracle.device) < self.sae_ground_truth_ratio).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1) force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask) x, _ = self.sae([norm_x, logit], ctc_padding_mask, oracle, oracle_mask)
x = x.permute(1, 2, 0) x = x.permute(1, 2, 0)
# x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear") # x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")
...@@ -915,7 +927,8 @@ class TransformerCTCEncoder(FairseqEncoder): ...@@ -915,7 +927,8 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x) x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None: if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x, ctc_padding_mask) up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
ctc_logit = self.ctc(up_x, ctc_padding_mask)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead. # `forward` so we use a dictionary instead.
...@@ -1592,6 +1605,8 @@ def base_architecture(args): ...@@ -1592,6 +1605,8 @@ def base_architecture(args):
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False) args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0) args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None) args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
@register_model_architecture("transformer_ctc", "transformer_ctc_relative") @register_model_architecture("transformer_ctc", "transformer_ctc_relative")
......
...@@ -100,34 +100,45 @@ class Adapter(nn.Module): ...@@ -100,34 +100,45 @@ class Adapter(nn.Module):
if self.cal_context: if self.cal_context:
self.distribution_cutoff = strategy.get("distribution_cutoff", None) self.distribution_cutoff = strategy.get("distribution_cutoff", None)
self.distribution_temperature = strategy.get("ctc_temperature", 1.0)
self.gumbel = strategy.get("gumbel", False)
self.distribution_hard = strategy.get("distribution_hard", False)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.drop_prob = strategy.get("drop_prob", 0)
if self.distribution_cutoff is not None: if self.distribution_cutoff is not None:
self.distribution_cutoff = int(self.distribution_cutoff)
logger.info("Distribution cutoff: %d" % self.distribution_cutoff) logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
if self.distribution_temperature != 1.0:
self.drop_prob = strategy.get("drop_prob", 0) logger.info("Temperature: %f" % self.distribution_temperature)
if self.gumbel:
logger.info("Gumbel softmax.")
if self.distribution_hard:
logger.info("Hard distribution.")
if self.drop_prob != 0: if self.drop_prob != 0:
logger.info("Adapter drop probability: %f" % self.drop_prob) logger.info("Drop probability: %f" % self.drop_prob)
self.ground_truth_ratio = strategy.get("gt_ratio", 0)
self.out_norm = strategy.get("out_norm", False) self.out_norm = strategy.get("out_norm", False)
if self.out_norm: if self.out_norm:
self.out_ln = LayerNorm(dim) self.out_ln = LayerNorm(dim)
def forward(self, x, padding=None, oracle=None, oracle_mask=None): def forward(self, x, padding=None, oracle=None, oracle_mask=None):
representation, logit = x
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size() seq_len, bsz, dim = representation.size()
org_distribution = distribution
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
linear_out = None linear_out = None
soft_out = None soft_out = None
if self.cal_linear: if self.cal_linear:
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
if self.cal_context: if self.cal_context:
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
org_distribution = distribution
if self.distribution_cutoff is not None: if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), vocab_size - 1) cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
...@@ -184,11 +195,15 @@ class Adapter(nn.Module): ...@@ -184,11 +195,15 @@ class Adapter(nn.Module):
out = representation out = representation
elif self.adapter_type == "shrink": elif self.adapter_type == "shrink":
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
lengths = (~padding).long().sum(-1) lengths = (~padding).long().sum(-1)
with torch.no_grad(): with torch.no_grad():
batch_predicted = [] batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D prob_ctc = distribution.transpose(0, 1) # T x B x D -> B x T x D
for b in range(prob_ctc.shape[0]): for b in range(prob_ctc.shape[0]):
predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist() predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)]) batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
......
...@@ -39,18 +39,23 @@ class CTC(nn.Module): ...@@ -39,18 +39,23 @@ class CTC(nn.Module):
self.post_process = "sentencepiece" self.post_process = "sentencepiece"
self.blank_idx = 0 self.blank_idx = 0
def set_infer(self, is_infer, text_post_process, dictionary): def set_infer(self, is_infer, text_post_process, dictionary, path):
self.infer_decoding = is_infer self.infer_decoding = is_infer
self.post_process = text_post_process self.post_process = text_post_process
self.dictionary = dictionary self.dictionary = dictionary
self.path = path
if self.path is not None:
self.save_stream = open(self.path, "a")
else:
self.save_stream = None
def forward(self, x, padding=None, tag=None): def forward(self, x, padding=None, tag=None, is_top=False):
if self.need_layernorm: if self.need_layernorm:
x = self.LayerNorm(x) x = self.LayerNorm(x)
x = self.ctc_projection(self.ctc_dropout_module(x)) x = self.ctc_projection(self.ctc_dropout_module(x))
if not self.training and self.infer_decoding: if not self.training and self.infer_decoding and is_top:
assert self.dictionary is not None assert self.dictionary is not None
input_lengths = (~padding).sum(-1) input_lengths = (~padding).sum(-1)
self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag) self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)
...@@ -79,6 +84,9 @@ class CTC(nn.Module): ...@@ -79,6 +84,9 @@ class CTC(nn.Module):
pred_units = self.dictionary.string(pred_units_arr) pred_units = self.dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split() pred_words_raw = post_process(pred_units, self.post_process).split()
if self.save_stream is not None:
self.save_stream.write(" ".join(pred_words_raw) + "\n")
if tag is not None: if tag is not None:
logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw))) logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
else: else:
......
...@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for model in models: for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"): if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece", model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict) src_dict, tgt_dict, translation_path) # os.path.dirname(translation_path))
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论