Commit 64a9e37e by xuchen

add the parameter "load_pretrained_decoder_from"

parent 81caa4ca
@@ -240,7 +240,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
 
     @classmethod
     def build_decoder(cls, args, task, embed_tokens):
-        return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
+        decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
+        if getattr(args, "load_pretrained_decoder_from", None):
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from
+            )
+            logger.info(
+                f"loaded pretrained decoder from: "
+                f"{args.load_pretrained_decoder_from}"
+            )
+        return decoder
 
     @classmethod
     def build_model(cls, args, task):
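Note that the guard in build_decoder reads getattr(args, "load_pretrained_decoder_from", None), which works because argparse converts the dashed command-line flag into an underscored attribute. A minimal, self-contained sketch of how such a flag could be registered and then consumed follows; the flag spelling mirrors the existing --load-pretrained-encoder-from option commonly exposed by fairseq speech-to-text models, and the help text and checkpoint path are assumptions, not taken from this commit:

import argparse

# Hypothetical sketch: register the new option, then read it back the same
# way build_decoder does with getattr.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--load-pretrained-decoder-from",
    type=str,
    metavar="STR",
    default=None,
    help="model checkpoint to take decoder weights from (for initialization)",
)

# "checkpoints/mt_pretrained.pt" is a made-up example path.
args = parser.parse_args(
    ["--load-pretrained-decoder-from", "checkpoints/mt_pretrained.pt"]
)

# Dashes become underscores, so build_decoder's getattr call finds the value.
print(getattr(args, "load_pretrained_decoder_from", None))
# -> checkpoints/mt_pretrained.pt

If the flag is left unset, getattr returns None and build_decoder simply returns the freshly constructed TransformerDecoderScriptable without loading any pretrained weights.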