Commit b78c7894 by xuchen

Support loading pre-trained models

parent ea1870e5
@@ -658,7 +658,7 @@ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
 def load_pretrained_component_from_model(
-    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
+    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str, strict: bool = True,
 ):
     """
     Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
@@ -684,7 +684,46 @@ def load_pretrained_component_from_model(
             # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
             component_subkey = key[len(component_type) + 1 :]
             component_state_dict[component_subkey] = state["model"][key]
-    component.load_state_dict(component_state_dict, strict=True)
+    mismatch_keys = []
+    if not strict:
+        # Walk the module tree and record checkpoint keys whose tensor
+        # shape disagrees with the corresponding model parameter.
+        def check(load_state_dict, modules, prefix=""):
+            for name, param in modules._parameters.items():
+                key = prefix + name
+                if key in load_state_dict:
+                    load_param = load_state_dict[key]
+                    if load_param.shape != param.shape:
+                        mismatch_keys.append(key)
+            for name, child in modules._modules.items():
+                if child is not None:
+                    check(load_state_dict, child, prefix + name + ".")
+
+        check(component_state_dict, component)
+        # Drop mismatched keys so that load_state_dict does not raise on them.
+        for key in mismatch_keys:
+            del component_state_dict[key]
+    missing_keys, unexpected_keys = component.load_state_dict(
+        component_state_dict, strict=strict
+    )
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            'Unexpected key(s) in state_dict: {}.'.format(
+                ', '.join('"{}"'.format(k) for k in unexpected_keys)))
+    if len(missing_keys) > 0:
+        logger.warning(
+            'Missing key(s) in state_dict: {}.'.format(
+                ', '.join('"{}"'.format(k) for k in missing_keys)))
+    if len(mismatch_keys) > 0:
+        logger.warning(
+            'Shape-mismatched key(s) in state_dict: {}.'.format(
+                ', '.join('"{}"'.format(k) for k in mismatch_keys)))
     return component
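
The mismatch check above walks the private _parameters/_modules attributes; the same filtering can be written against the public named_parameters() API. A minimal standalone sketch, runnable outside fairseq; the embedding module and shapes below are made up for illustration, not part of this commit:

    import torch
    import torch.nn as nn

    def drop_mismatched_keys(module: nn.Module, state_dict: dict) -> list:
        # Keys present in both module and checkpoint whose shapes disagree.
        mismatched = [
            key
            for key, param in module.named_parameters()
            if key in state_dict and state_dict[key].shape != param.shape
        ]
        for key in mismatched:
            del state_dict[key]
        return mismatched

    # Hypothetical scenario: the checkpoint was trained with a larger vocabulary.
    model = nn.Embedding(100, 16)
    ckpt = {"weight": torch.zeros(200, 16)}
    print(drop_mismatched_keys(model, ckpt))   # ['weight']
    model.load_state_dict(ckpt, strict=False)  # no error: the key was dropped

Note that, like the check in this commit, this inspects parameters only; buffers (e.g. BatchNorm running statistics) would additionally need named_buffers().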
@@ -76,7 +76,7 @@ class S2TConformerModel(S2TTransformerModel):
         encoder = S2TConformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
             encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
             )
             logger.info(
                 f"loaded pretrained encoder from: "
@@ -223,6 +223,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             metavar="STR",
             help="model to take encoder weights from (for initialization)",
         )
+        parser.add_argument(
+            "--load-pretrained-decoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take decoder weights from (for initialization)",
+        )
         pass

     @classmethod
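
The new flag mirrors --load-pretrained-encoder-from, and the builders below gate on the attribute via getattr, so omitting the flag skips loading entirely. A small sketch of how the parsed options surface, with a bare Namespace standing in for fairseq's parsed arguments and placeholder checkpoint paths:

    from argparse import Namespace

    args = Namespace(
        load_pretrained_encoder_from="/path/to/asr_encoder.pt",  # placeholder
        load_pretrained_decoder_from="/path/to/mt_decoder.pt",   # placeholder
    )

    # Same gate as in build_encoder/build_decoder:
    if getattr(args, "load_pretrained_decoder_from", None):
        print(f"decoder weights from {args.load_pretrained_decoder_from}")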
@@ -230,7 +236,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
         encoder = S2TTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
             encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
             )
             logger.info(
                 f"loaded pretrained encoder from: "
@@ -243,7 +249,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
         decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
         if getattr(args, "load_pretrained_decoder_from", None):
             decoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=decoder, checkpoint=args.load_pretrained_encoder_from
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
             )
             logger.info(
                 f"loaded pretrained decoder from: "
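
Taken together, these hunks let the speech translation model initialize its decoder from a separate checkpoint while tolerating shape drift. A miniature end-to-end sketch in plain PyTorch of what the helper does: keep the component's keys, strip the prefix, drop mismatched shapes, then load non-strictly (the module layout and shapes are made up):

    import torch
    import torch.nn as nn

    # Stand-in "checkpoint" with full-model keys, as fairseq stores them.
    state = {"model": {
        "decoder.embed.weight": torch.zeros(200, 16),  # vocab differs: mismatch
        "decoder.out.weight": torch.zeros(8, 16),
        "decoder.out.bias": torch.zeros(8),
    }}
    decoder = nn.ModuleDict({"embed": nn.Embedding(100, 16),
                             "out": nn.Linear(16, 8)})

    # 1) keep only "decoder." keys and strip the prefix
    sub = {k[len("decoder."):]: v for k, v in state["model"].items()
           if k.startswith("decoder.")}
    # 2) drop entries whose shape disagrees with the model
    params = dict(decoder.named_parameters())
    for k in [k for k in sub if k in params and sub[k].shape != params[k].shape]:
        del sub[k]
    # 3) non-strict load: surviving tensors are copied, the rest reported
    missing, unexpected = decoder.load_state_dict(sub, strict=False)
    print(missing)  # ['embed.weight']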