Commit e186b036 by libei

Fix share-all-embedding bugs and inference bugs

parent b97a0cfd
...@@ -131,36 +131,36 @@ def _upgrade_state_dict(state): ...@@ -131,36 +131,36 @@ def _upgrade_state_dict(state):
return state return state
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): # def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference. # """Load an ensemble of models for inference.
#
model_arg_overrides allows you to pass a dictionary model_arg_overrides -- # model_arg_overrides allows you to pass a dictionary model_arg_overrides --
{'arg_name': arg} -- to override model args that were used during model # {'arg_name': arg} -- to override model args that were used during model
training # training
""" # """
# load model architectures and weights # # load model architectures and weights
states = [] # states = []
for filename in filenames: # for filename in filenames:
if not os.path.exists(filename): # if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename)) # raise IOError('Model file not found: {}'.format(filename))
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu')) # state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
#state = _upgrade_state_dict(state) # #state = _upgrade_state_dict(state)
states.append(state) # states.append(state)
args = states[0]['args'] # args = states[0]['args']
if model_arg_overrides is not None: # if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides) # args = _override_model_args(args, model_arg_overrides)
#
# build ensemble # # build ensemble
ensemble = [] # ensemble = []
for state in states: # for state in states:
model = task.build_model(args) # model = task.build_model(args)
print(model) # print(model)
model.upgrade_state_dict(state['model']) # model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True) # model.load_state_dict(state['model'], strict=True)
ensemble.append(model) # ensemble.append(model)
return ensemble, args # return ensemble, args
def load_ensemble_for_inference_2(filenames, task, model_arg_overrides=None): def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference. """Load an ensemble of models for inference.
model_arg_overrides allows you to pass a dictionary model_arg_overrides -- model_arg_overrides allows you to pass a dictionary model_arg_overrides --
......
...@@ -91,10 +91,12 @@ def find_useful_param(model_file): ...@@ -91,10 +91,12 @@ def find_useful_param(model_file):
share_decoder_input_and_softmax = True share_decoder_input_and_softmax = True
tgt_vocab_size = int(match.group(1)) tgt_vocab_size = int(match.group(1))
match = re.match('symbol_modality_(\d+)_(\d+)/shared/$', key) match = re.match('symbol_modality_(\d+)_(\d+)/shared/', key)
if match and share_all_embedding is False: if match and share_all_embedding is False:
share_all_embedding = True share_all_embedding = True
src_vocab_size = int(match.group(1))
tgt_vocab_size = int(match.group(1)) tgt_vocab_size = int(match.group(1))
emb_size = int(match.group(2))
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
print(str(e)) print(str(e))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论