Commit 0a70c5c5 by xuchen

add settings for weight sharing of the interleaved CTC modules

parent 21734086
@@ -426,6 +426,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        parser.add_argument(
+            "--share-interleaved-ctc",
+            action="store_true",
+            help="share the weights of all interleaved CTC modules",
+        )
 
         # Semantics-augmented Encoding (SAE)
         parser.add_argument(
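
Because the new option uses action="store_true", it defaults to False, and the encoder later reads it via getattr(args, "share_interleaved_ctc", False), so existing configs that never pass the flag keep the previous per-layer behavior. A minimal self-contained sketch of that default handling (a hypothetical stand-in parser, not the full add_args):

import argparse

# Hypothetical stand-in for S2TTransformerModel.add_args(); only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--share-interleaved-ctc",
    action="store_true",
    help="share the weights of all interleaved CTC modules",
)

args = parser.parse_args([])  # flag omitted: store_true defaults to False
assert getattr(args, "share_interleaved_ctc", False) is False

args = parser.parse_args(["--share-interleaved-ctc"])  # flag present
assert args.share_interleaved_ctc is True
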
@@ -672,6 +677,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
         self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
+        self.share_interleaved_ctc = getattr(args, "share_interleaved_ctc", False)
         self.interleaved_ctc_layers = []
         self.use_inter_ctc = False
         if args.interleaved_ctc_layers is not None:
@@ -685,15 +691,25 @@ class S2TTransformerEncoder(FairseqEncoder):
                 logger.info("Interleaved CTC loss in layer %d" % layer_idx)
 
-            if not self.use_ctc:
-                self.ctc = CTC(dim,
-                               dictionary_size=len(task.source_dictionary),
-                               dropout=args.dropout,
-                               )
-
-                if getattr(args, "share_ctc_and_embed", False) and \
-                        task.source_dictionary == task.target_dictionary and \
-                        embed_tokens is not None and dim == embed_tokens.embedding_dim:
-                    self.ctc.ctc_projection.weight = embed_tokens.weight
+            if not (self.use_ctc and self.share_interleaved_ctc):
+                if not self.share_interleaved_ctc:
+                    for layer_idx in self.interleaved_ctc_layers:
+                        inter_layer_norm = LayerNorm(dim)
+                        inter_ctc = CTC(dim,
+                                        dictionary_size=len(task.source_dictionary),
+                                        dropout=args.dropout,
+                                        )
+                        setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)
+                        setattr(self, "inter_layer_norm%d" % layer_idx, inter_layer_norm)
+                else:
+                    self.ctc = CTC(dim,
+                                   dictionary_size=len(task.source_dictionary),
+                                   dropout=args.dropout,
+                                   )
+
+                    if getattr(args, "share_ctc_and_embed", False) and \
+                            task.source_dictionary == task.target_dictionary and \
+                            embed_tokens is not None and dim == embed_tokens.embedding_dim:
+                        self.ctc.ctc_projection.weight = embed_tokens.weight
 
             strategy = {
                 "embed_norm": getattr(args, "sae_embed_norm", False),
@@ -707,12 +723,24 @@ class S2TTransformerEncoder(FairseqEncoder):
                 "drop_prob": getattr(args, "sae_drop_prob", 0),
             }
-            self.sae = Adapter(dim, args.sae_adapter,
-                               len(task.source_dictionary),
-                               strategy=strategy,
-                               )
-            if args.share_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
-                self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
+            if not self.share_interleaved_ctc:
+                for layer_idx in self.interleaved_ctc_layers:
+                    sae = Adapter(dim, args.sae_adapter,
+                                  len(task.source_dictionary),
+                                  strategy=strategy,
+                                  )
+                    inter_ctc = getattr(self, "inter_ctc%d" % layer_idx)
+                    if args.share_sae_and_ctc and hasattr(sae, "embed_adapter"):
+                        sae.embed_adapter.weight = inter_ctc.ctc_projection.weight
+                    setattr(self, "sae%d" % layer_idx, sae)
+            else:
+                self.sae = Adapter(dim, args.sae_adapter,
+                                   len(task.source_dictionary),
+                                   strategy=strategy,
+                                   )
+                if args.share_sae_and_ctc and hasattr(self.sae, "embed_adapter"):
+                    self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
 
         # mixup
         self.mixup = getattr(args, "inter_mixup", False)
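
Both branches tie weights by assigning one module's Parameter to another (sae.embed_adapter.weight = inter_ctc.ctc_projection.weight), which makes the two modules share a single tensor: one copy in the checkpoint, and gradients from both uses accumulate into it. A small sketch of that tying, using nn.Linear stand-ins for ctc_projection and embed_adapter:

import torch
import torch.nn as nn

# Stand-ins: both map the model dim (8) to the vocabulary size (16).
ctc_projection = nn.Linear(8, 16, bias=False)
embed_adapter = nn.Linear(8, 16, bias=False)

# Tie: after assignment both attributes reference the same Parameter object.
embed_adapter.weight = ctc_projection.weight
assert embed_adapter.weight is ctc_projection.weight

# Gradients from both forward paths accumulate in the shared tensor.
loss = ctc_projection(torch.randn(2, 8)).sum() + embed_adapter(torch.randn(2, 8)).sum()
loss.backward()
print(ctc_projection.weight.grad.shape)  # torch.Size([16, 8])
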
@@ -753,8 +781,8 @@ class S2TTransformerEncoder(FairseqEncoder):
         if hasattr(self, "ctc"):
             return self.ctc.valid(lprobs, targets, input_lengths,
                                   dictionary)
-        else:
-            logger.error("No ctc module in textual encoder")
+
+        logger.error("No ctc module in textual encoder")
 
     def set_debug_var(self, debug_var_flag):
         self.debug_var = debug_var_flag
@@ -927,8 +955,17 @@ class S2TTransformerEncoder(FairseqEncoder):
                 if p < self.interleaved_ctc_drop_prob:
                     break
 
-                norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
+                if self.share_interleaved_ctc:
+                    inter_ctc = self.ctc
+                    sae = self.sae
+                    layer_norm = self.layer_norm
+                else:
+                    inter_ctc = getattr(self, "inter_ctc%d" % layer_idx)
+                    sae = getattr(self, "sae%d" % layer_idx)
+                    layer_norm = getattr(self, "inter_layer_norm%d" % layer_idx)
+
+                norm_x = layer_norm(x)
+                logit = inter_ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
                 interleaved_ctc_logits.append(logit)
 
                 # CTC alignment
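
At forward time the change is pure dispatch: with sharing on, every tapped layer reuses self.ctc, self.sae, and self.layer_norm; with sharing off, each tapped layer looks up its own modules by the names set in __init__. A sketch of that selection as a helper (hypothetical; the diff inlines it):

import torch.nn as nn

def select_inter_modules(encoder: nn.Module, layer_idx: int, share: bool):
    """Return (ctc, sae, layer_norm) for one tapped layer.

    Assumes `encoder` carries either shared `ctc`/`sae`/`layer_norm`
    attributes or per-layer `inter_ctc{i}`/`sae{i}`/`inter_layer_norm{i}`
    ones, matching the constructor changes above.
    """
    if share:
        return encoder.ctc, encoder.sae, encoder.layer_norm
    return (
        getattr(encoder, "inter_ctc%d" % layer_idx),
        getattr(encoder, "sae%d" % layer_idx),
        getattr(encoder, "inter_layer_norm%d" % layer_idx),
    )
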
@@ -943,8 +980,8 @@ class S2TTransformerEncoder(FairseqEncoder):
                                           device=oracle.device) < self.sae_ground_truth_ratio).bool()
                     force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
 
-                if self.sae.adapter_type != "none":
-                    x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
+                if sae.adapter_type != "none":
+                    x, encoder_padding_mask = sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
                 self.show_debug(x, "x after sae")
 
                 # gather cosine similarity