Commit 0704bedb by xuchen

support to load the previous model

parent b48c002d
......@@ -678,6 +678,15 @@ def load_pretrained_component_from_model(
"component to load must be either a FairseqEncoder or "
"FairseqDecoder. Loading other component types are not supported."
)
new_model = dict()
for key in state["model"].keys():
org_key = key
key = key.replace("transformer_layer", "layer").replace("conformer_layer", "layer")
new_model[key] = state["model"][org_key]
del state["model"]
state["model"] = new_model
component_state_dict = OrderedDict()
for key in state["model"].keys():
if key.startswith(component_type):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论