Commit a8105353 by xuchen

update the module name during loading pre-trained model

parent 42ea101c
...@@ -377,6 +377,14 @@ def load_model_ensemble_and_task( ...@@ -377,6 +377,14 @@ def load_model_ensemble_and_task(
# build model for ensemble # build model for ensemble
model = task.build_model(cfg.model) model = task.build_model(cfg.model)
new_model = dict()
for key in state["model"].keys():
org_key = key
key = key.replace("transformer_layer", "layer").replace("conformer_layer", "layer")
new_model[key] = state["model"][org_key]
del state["model"]
state["model"] = new_model
model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
# reset state so it gets loaded for the next model in ensemble # reset state so it gets loaded for the next model in ensemble
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论