Commit d2cb256f by libei

fix bugs

parent 22b357ea
...@@ -188,7 +188,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): ...@@ -188,7 +188,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
for index,state in enumerate(states): for index,state in enumerate(states):
model = task.build_model(args[index]) model = task.build_model(args[index])
model.upgrade_state_dict(state['model']) model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=False) model.load_state_dict(state['model'], strict=True)
ensemble.append(model) ensemble.append(model)
return ensemble, args return ensemble, args
......
...@@ -54,7 +54,6 @@ def main(args): ...@@ -54,7 +54,6 @@ def main(args):
model = task.build_model(args) model = task.build_model(args)
# for k,v in model.named_parameters(): # for k,v in model.named_parameters():
# print("k:%s"%k) # print("k:%s"%k)
print(model)
criterion = task.build_criterion(args) criterion = task.build_criterion(args)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters()))) print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
......
...@@ -3,7 +3,7 @@ set -e ...@@ -3,7 +3,7 @@ set -e
model_root_dir=checkpoints model_root_dir=checkpoints
# set tag # set tag
model_dir_tag=baseline_v2 model_dir_tag=baseline
model_dir=$model_root_dir/$model_dir_tag model_dir=$model_root_dir/$model_dir_tag
ensemble= ensemble=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论