Commit 504e81af by xuchen

fix bugs: add the --share-xctc-and-ctc option to share the XCTC projection weight with the CTC projection

parent e7422a42
......@@ -618,6 +618,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="share the weight of target ctc and embed",
)
parser.add_argument(
"--share-xctc-and-ctc",
action="store_true",
help="share the weight of target ctc and ctc",
)
parser.add_argument(
"--inter-xctc-layers",
default=None,
type=str,
......@@ -1197,7 +1202,7 @@ class S2TTransformerEncoder(FairseqEncoder):
xctc_norm = LayerNorm(dim)
setattr(self, "xctc_norm%d" % layer_idx, xctc_norm)
# consider layer norm
# consider layer norm
if not hasattr(self, "xctc"):
self.xctc = CTC(
dim,
......@@ -1212,6 +1217,8 @@ class S2TTransformerEncoder(FairseqEncoder):
== embed_tokens.weight.size()
):
self.xctc.ctc_projection.weight = embed_tokens.weight
elif getattr(self, "share_xctc_and_ctc", False) and hasattr(self, "ctc"):
self.xctc.ctc_projection.weight = self.ctc.ctc_projection.weight
strategy = {
"embed_norm": getattr(args, "pae_embed_norm", False),
......@@ -2365,6 +2372,7 @@ def base_architecture(args):
# XCTC
args.xctc_layer = getattr(args, "xctc_layer", 0)
args.share_xctc_and_embed = getattr(args, "share_xctc_and_embed", False)
args.share_xctc_and_ctc = getattr(args, "share_xctc_and_ctc", False)
args.xctc_pae = getattr(args, "xctc_pae", args.ctc_pae)
args.axctc_pae = getattr(args, "axctc_pae", args.xctc_pae)
args.share_pae_and_xctc = getattr(args, "share_pae_and_xctc", False)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论