Commit b48c002d by xuchen

optimize the logging

parent 4cffbd98
......@@ -711,6 +711,10 @@ def load_pretrained_component_from_model(
missing_keys, unexpected_keys = component.load_state_dict(component_state_dict, strict=strict)
if len(mismatch_keys) > 0:
logger.warning(
'Mismatch key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in mismatch_keys)))
if len(unexpected_keys) > 0:
logger.warning(
'Unexpected key(s) in state_dict: {}. '.format(
......@@ -719,10 +723,6 @@ def load_pretrained_component_from_model(
logger.warning(
'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
if len(mismatch_keys) > 0:
logger.warning(
'Mismatch key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in mismatch_keys)))
return component
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论