Commit d2e69e64 by libei

support share_all_embedding and share_decoder_input_and_softmax

parent 37be33b3
......@@ -30,7 +30,8 @@ def find_useful_param(model_file):
global max_relative_length
global normalize_before
global use_dense
global share_all_embedding
global share_decoder_input_and_softmax
trainable_param = dict()
try:
reader = pywrap_tensorflow.NewCheckpointReader(model_file)
......@@ -85,6 +86,16 @@ def find_useful_param(model_file):
if match and use_dense is False:
use_dense = True
match = re.match('symbol_modality_(\d+)_(\d+)/shared_target',key)
if match and share_decoder_input_and_softmax is False:
share_decoder_input_and_softmax = True
tgt_vocab_size = int(match.group(1))
match = re.match('symbol_modality_(\d+)_(\d+)/shared/$', key)
if match and share_all_embedding is False:
share_all_embedding = True
tgt_vocab_size = int(match.group(1))
except Exception as e: # pylint: disable=broad-except
print(str(e))
assert len(trainable_param) > 0, "not found any trainable parameters"
......@@ -99,9 +110,9 @@ def find_useful_param(model_file):
if use_relative_position_representation is None:
use_relative_position_representation = 0
print('find src-vocab:{} tgt-vocab:{} emb-size:{} src-layer:{} tgt-layer:{} activation:{} relative:{} normalize_before:{} use_dense:{}'.
print('find src-vocab:{} tgt-vocab:{} emb-size:{} src-layer:{} tgt-layer:{} activation:{} relative:{} normalize_before:{} use_dense:{} share_decoder_input_and_softmax:{} share_all_embedding:{}'.
format(src_vocab_size, tgt_vocab_size, emb_size, src_layer_num, tgt_layer_num,
activation_function, use_relative_position_representation, normalize_before, use_dense))
activation_function, use_relative_position_representation, normalize_before, use_dense, share_decoder_input_and_softmax, share_all_embedding))
......@@ -127,7 +138,11 @@ def load_param():
return tensor
# source embedding, shape is src_vocab * emb_size
src_emb_tensor = _get_tensor('symbol_modality_{}_{}/input_emb/weights'.format(src_vocab_size, emb_size), shard=shard)
if share_all_embedding:
src_emb_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(src_vocab_size, emb_size),
shard=shard)
else:
src_emb_tensor = _get_tensor('symbol_modality_{}_{}/input_emb/weights'.format(src_vocab_size, emb_size), shard=shard)
param_dict['src_emb'] = src_emb_tensor
# space embedding, we only use target_space_id, so the shape is H
......@@ -231,8 +246,12 @@ def load_param():
# decoder
# target embedding, shape is tgt_vocab * emb_size
tgt_emb_tensor = _get_tensor('symbol_modality_{}_{}/target_emb/weights'.format(tgt_vocab_size, emb_size),
shard=shard)
if share_all_embedding:
tgt_emb_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(tgt_vocab_size, emb_size),shard=shard)
elif share_decoder_input_and_softmax:
tgt_emb_tensor = _get_tensor('symbol_modality_{}_{}/shared_target/weights'.format(tgt_vocab_size, emb_size),shard=shard)
else:
tgt_emb_tensor = _get_tensor('symbol_modality_{}_{}/target_emb/weights'.format(tgt_vocab_size, emb_size),shard=shard)
param_dict['tgt_emb'] = tgt_emb_tensor
param_dict['tgt_pos_emb'] = get_pos_emb(tgt_max_length, emb_size=emb_size)
......@@ -333,7 +352,15 @@ def load_param():
layer_weight = _get_tensor('body/decoder/layer_history/layer_weight')
param_dict['dec_layer_weight'] = layer_weight
softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),
if share_all_embedding:
softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(tgt_vocab_size, emb_size),
shard=shard)
if share_decoder_input_and_softmax:
softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared_target/weights'.format(tgt_vocab_size, emb_size),
shard=shard)
else:
softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),
shard=shard)
# note: we transpose the matrix, from (V,H) to (H,V)
param_dict['softmax_w'] = softmax_w_tensor
......@@ -421,6 +448,11 @@ def convert_settings(settings):
args['decoder_normalize_before'] = True
args['attention_dropout'] = 0.1
args['relu_dropout'] = 0.1
assert share_all_embedding ^ share_decoder_input_and_softmax
if share_all_embedding:
args['share_all_embeddings'] = True
if share_decoder_input_and_softmax:
args['share_decoder_input_output_embed'] = True
return argparse.Namespace(**args)
......@@ -690,6 +722,8 @@ if __name__ == '__main__':
max_relative_length = -1
normalize_before = False
use_dense = False
share_all_embedding=False
share_decoder_input_and_softmax=False
start = time.time()
......
......@@ -52,6 +52,8 @@ def main(args):
# Build model and criterion
model = task.build_model(args)
# for k,v in model.named_parameters():
# print("k:%s"%k)
print(model)
criterion = task.build_criterion(args)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论