Commit e186b036 by libei

Fix shared-embedding bugs and inference bugs

parent b97a0cfd
......@@ -131,36 +131,36 @@ def _upgrade_state_dict(state):
return state
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    Args:
        filenames: iterable of checkpoint paths; one model is built per file.
        task: object whose ``build_model(args)`` constructs a model instance.
        model_arg_overrides: optional dict ``{'arg_name': arg}`` overriding
            model args that were stored in the checkpoint during training.

    Returns:
        ``(ensemble, args)``: the list of models with weights loaded, and the
        (possibly overridden) args taken from the first checkpoint.

    Raises:
        IOError: if any checkpoint file does not exist.
    """
    # Load every checkpoint onto CPU first so a missing/corrupt file fails
    # before any model is constructed.
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cpu'),
        )
        # NOTE(review): _upgrade_state_dict was deliberately disabled here;
        # confirm checkpoints are already in the current on-disk format.
        states.append(state)

    # All ensemble members share the args of the first checkpoint.
    args = states[0]['args']
    if model_arg_overrides is not None:
        args = _override_model_args(args, model_arg_overrides)

    # Build one model per checkpoint; strict loading reports any
    # missing/unexpected parameters immediately.  (Removed leftover
    # debug print of the full model graph.)
    ensemble = []
    for state in states:
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args
# def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
# """Load an ensemble of models for inference.
#
# model_arg_overrides allows you to pass a dictionary model_arg_overrides --
# {'arg_name': arg} -- to override model args that were used during model
# training
# """
# # load model architectures and weights
# states = []
# for filename in filenames:
# if not os.path.exists(filename):
# raise IOError('Model file not found: {}'.format(filename))
# state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
# #state = _upgrade_state_dict(state)
# states.append(state)
# args = states[0]['args']
# if model_arg_overrides is not None:
# args = _override_model_args(args, model_arg_overrides)
#
# # build ensemble
# ensemble = []
# for state in states:
# model = task.build_model(args)
# print(model)
# model.upgrade_state_dict(state['model'])
# model.load_state_dict(state['model'], strict=True)
# ensemble.append(model)
# return ensemble, args
def load_ensemble_for_inference_2(filenames, task, model_arg_overrides=None):
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
model_arg_overrides allows you to pass a dictionary model_arg_overrides --
......
......@@ -91,10 +91,12 @@ def find_useful_param(model_file):
share_decoder_input_and_softmax = True
tgt_vocab_size = int(match.group(1))
match = re.match('symbol_modality_(\d+)_(\d+)/shared/$', key)
match = re.match('symbol_modality_(\d+)_(\d+)/shared/', key)
if match and share_all_embedding is False:
share_all_embedding = True
src_vocab_size = int(match.group(1))
tgt_vocab_size = int(match.group(1))
emb_size = int(match.group(2))
except Exception as e: # pylint: disable=broad-except
print(str(e))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论