Commit 504e81af by xuchen

fix bugs: add the --share-xctc-and-ctc option to share the XCTC projection weight with the CTC projection

parent e7422a42
...@@ -618,6 +618,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -618,6 +618,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="share the weight of target ctc and embed", help="share the weight of target ctc and embed",
) )
parser.add_argument( parser.add_argument(
"--share-xctc-and-ctc",
action="store_true",
help="share the weight of target ctc and ctc",
)
parser.add_argument(
"--inter-xctc-layers", "--inter-xctc-layers",
default=None, default=None,
type=str, type=str,
...@@ -1197,7 +1202,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1197,7 +1202,7 @@ class S2TTransformerEncoder(FairseqEncoder):
xctc_norm = LayerNorm(dim) xctc_norm = LayerNorm(dim)
setattr(self, "xctc_norm%d" % layer_idx, xctc_norm) setattr(self, "xctc_norm%d" % layer_idx, xctc_norm)
# consider layer norm # consider layer norm
if not hasattr(self, "xctc"): if not hasattr(self, "xctc"):
self.xctc = CTC( self.xctc = CTC(
dim, dim,
...@@ -1212,6 +1217,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1212,6 +1217,8 @@ class S2TTransformerEncoder(FairseqEncoder):
== embed_tokens.weight.size() == embed_tokens.weight.size()
): ):
self.xctc.ctc_projection.weight = embed_tokens.weight self.xctc.ctc_projection.weight = embed_tokens.weight
elif getattr(self, "share_xctc_and_ctc", False) and hasattr(self, "ctc"):
self.xctc.ctc_projection.weight = self.ctc.ctc_projection.weight
strategy = { strategy = {
"embed_norm": getattr(args, "pae_embed_norm", False), "embed_norm": getattr(args, "pae_embed_norm", False),
...@@ -2365,6 +2372,7 @@ def base_architecture(args): ...@@ -2365,6 +2372,7 @@ def base_architecture(args):
# XCTC # XCTC
args.xctc_layer = getattr(args, "xctc_layer", 0) args.xctc_layer = getattr(args, "xctc_layer", 0)
args.share_xctc_and_embed = getattr(args, "share_xctc_and_embed", False) args.share_xctc_and_embed = getattr(args, "share_xctc_and_embed", False)
args.share_xctc_and_ctc = getattr(args, "share_xctc_and_ctc", False)
args.xctc_pae = getattr(args, "xctc_pae", args.ctc_pae) args.xctc_pae = getattr(args, "xctc_pae", args.ctc_pae)
args.axctc_pae = getattr(args, "axctc_pae", args.xctc_pae) args.axctc_pae = getattr(args, "axctc_pae", args.xctc_pae)
args.share_pae_and_xctc = getattr(args, "share_pae_and_xctc", False) args.share_pae_and_xctc = getattr(args, "share_pae_and_xctc", False)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论