Commit d583bfcb by libeineu

fix the share-all-embedding and share-decoder-embedding-with-softmax bugs

parent 400e4f5a
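Note on the fix (hedged, not stated in the commit itself): in fairseq's TransformerDecoder the separate decoder.embed_out projection only exists when the decoder input and output embeddings are not tied; with --share-decoder-input-output-embed or --share-all-embeddings the softmax reuses decoder.embed_tokens.weight, so a converted state dict that still carries decoder.embed_out would not match a model built with those flags. A minimal sanity-check sketch mirroring the guards added below (the helper name is illustrative, not from this repository):

def assert_tied_keys_absent(model, share_all_embedding, share_decoder_input_and_softmax):
    # Illustrative only: with tied embeddings fairseq has no decoder.embed_out
    # parameter, so the converted checkpoint must not provide one.
    if share_all_embedding or share_decoder_input_and_softmax:
        assert 'decoder.embed_out' not in model, 'unexpected key: decoder.embed_out'
    return model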
@@ -7,7 +7,7 @@ set -e
 # more device will not be used. e.g. you set device=(0 1 2 3), but you only choose three evalset, the gpu=3 will not be used
 device=(0 1 2 3 4 5 6 7)
 # your tag, must set!
-tag=baseline
+tag=base25_shared
 # you should select model params for choosing the correct num heads
 params=transformer_base
 model_dir=t2tmodel/$tag/ensemble15
...
@@ -139,8 +139,7 @@ def load_param():
     # source embedding, shape is src_vocab * emb_size
     if share_all_embedding:
-        src_emb_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(src_vocab_size, emb_size),
-                                     shard=shard)
+        src_emb_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(src_vocab_size, emb_size), shard=shard)
     else:
         src_emb_tensor = _get_tensor('symbol_modality_{}_{}/input_emb/weights'.format(src_vocab_size, emb_size), shard=shard)
     param_dict['src_emb'] = src_emb_tensor
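_get_tensor itself is not shown in this diff; it is the repository's helper for reading a variable out of the tensor2tensor TensorFlow checkpoint, with shard indicating how many pieces a symbol-modality weight matrix was split into. A rough sketch of such a helper, under the assumption that shards are stored as '<name>_0' ... '<name>_{shard-1}' and concatenated along the vocabulary axis (function name, signature, and naming scheme are assumptions, not the repo's actual code):

import numpy as np
import tensorflow as tf

def get_tensor_from_checkpoint(checkpoint_path, name, shard=1):
    # Read one (possibly sharded) variable from a TF checkpoint.
    reader = tf.train.load_checkpoint(checkpoint_path)
    if shard <= 1:
        return reader.get_tensor(name)
    parts = [reader.get_tensor('{}_{}'.format(name, i)) for i in range(shard)]
    return np.concatenate(parts, axis=0)  # concatenate along the vocab axis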
@@ -241,7 +240,6 @@ def load_param():
     # step 8. dense transformer weight matrix
     if use_dense:
         layer_weight = _get_tensor('body/encoder/layer_history/layer_weight')
-        print(type(layer_weight))
         param_dict['enc_layer_weight'] = layer_weight
     # decoder
@@ -552,12 +550,14 @@ def convert_param():
         model['decoder.embed_tokens.weight'] = torch.cat([lua_row, pad_row, eos_row, unk_row, embed[4:, :]], dim=0)
         """
         lua_row = torch.zeros(1, embed.size(1))
-        model['decoder.embed_tokens.weight'] = torch.cat([lua_row, embed], dim=0)
+        if not share_all_embedding:
+            model['decoder.embed_tokens.weight'] = torch.cat([lua_row, embed], dim=0)
         # model['decoder.embed_positions._float_tensor'] = torch.Tensor([])
         # in fairseq, pos-index is from 2
         pos_emb = _get_param_numpy("tgt_pos_emb", sess)
         pos_emb = torch.cat([torch.zeros(2, pos_emb.size(1)), pos_emb[:-2, :]], dim=0)
         model['decoder.embed_positions.weight'] = pos_emb
         for layer_id in range(int(tgt_layer_num)):
             p1 = 'decoder.layers.%d.self_attn' % layer_id
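The two zero rows prepended to the target positional table correspond to fairseq's convention of padding_idx = 1 for positional embeddings, so the first real position lands at index 2 (the "pos-index is from 2" comment above). An isolated illustration of that shift (not part of the commit):

import torch

def shift_positions_for_fairseq(t2t_pos_emb: torch.Tensor) -> torch.Tensor:
    # Reserve indices 0 and 1, and drop the last two T2T positions so the
    # table keeps its original length.
    pad_rows = torch.zeros(2, t2t_pos_emb.size(1))
    return torch.cat([pad_rows, t2t_pos_emb[:-2, :]], dim=0)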
@@ -624,7 +624,8 @@ def convert_param():
         """
         # lua_row = torch.zeros(1, softmax_w_tensor.size(1)).fill_(float('-inf'))
         lua_row = torch.zeros(1, softmax_w_tensor.size(1))
-        model['decoder.embed_out'] = torch.cat([lua_row, softmax_w_tensor], dim=0)
+        if share_all_embedding == False and share_decoder_input_and_softmax == False:
+            model['decoder.embed_out'] = torch.cat([lua_row, softmax_w_tensor], dim=0)
     return model
 def write_vocab():
...
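Downstream, the dict returned by convert_param() presumably becomes the 'model' entry of a fairseq-style checkpoint, and the sharing flags used when building the fairseq model have to match the flags used during conversion, otherwise strict state-dict loading reports missing or unexpected embedding keys. A hedged usage sketch (the output path and the consuming side are illustrative):

import torch

model = convert_param()                       # the dict built above
torch.save({'model': model}, 'converted.pt')  # hypothetical output file

# consuming side, sketched:
# fairseq_model.load_state_dict(torch.load('converted.pt')['model'], strict=True)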