Commit 400e4f5a by libei

fix bugs

parent d2e69e64
@@ -354,14 +354,11 @@ def load_param():
     if share_all_embedding:
-        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(tgt_vocab_size, emb_size),
-                                       shard=shard)
-    if share_decoder_input_and_softmax:
-        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared_target/weights'.format(tgt_vocab_size, emb_size),
-                                       shard=shard)
+        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared/weights'.format(tgt_vocab_size, emb_size),shard=shard)
+    elif share_decoder_input_and_softmax:
+        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/shared_target/weights'.format(tgt_vocab_size, emb_size),shard=shard)
     else:
-        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),
-                                       shard=shard)
+        softmax_w_tensor = _get_tensor('symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),shard=shard)
     # note: we transpose the matrix, from (V,H) to (H,V)
     param_dict['softmax_w'] = softmax_w_tensor
     return param_dict
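Why the change matters: in the old code the second check was an independent if-statement with the else bound to it, so whenever share_decoder_input_and_softmax was false the else branch ran and loaded softmax/weights, clobbering the shared weights the first branch had just selected when share_all_embedding was true. Switching to an if/elif/else chain makes the three cases mutually exclusive. A minimal, self-contained sketch of the buggy control flow (the pick_softmax_weights wrapper, the _get_tensor stub, and the 32000/512 sizes are illustrative, not taken from the repo):

    # Stand-in for the real _get_tensor, which is not shown in this commit:
    # it simply reports which checkpoint tensor would be loaded.
    def _get_tensor(name, shard=None):
        return name

    def pick_softmax_weights(share_all_embedding, share_decoder_input_and_softmax,
                             tgt_vocab_size=32000, emb_size=512, shard=None):
        # Old (buggy) version: two independent if-statements.
        if share_all_embedding:
            w = _get_tensor('symbol_modality_{}_{}/shared/weights'
                            .format(tgt_vocab_size, emb_size), shard=shard)
        if share_decoder_input_and_softmax:
            w = _get_tensor('symbol_modality_{}_{}/shared_target/weights'
                            .format(tgt_vocab_size, emb_size), shard=shard)
        else:
            # Runs whenever the second flag is false, even if the first branch
            # already picked the shared weights, overwriting them.
            w = _get_tensor('symbol_modality_{}_{}/softmax/weights'
                            .format(tgt_vocab_size, emb_size), shard=shard)
        return w

    print(pick_softmax_weights(True, False))
    # prints 'symbol_modality_32000_512/softmax/weights' -- the shared weights
    # are lost; with the commit's if/elif/else chain the same call would yield
    # 'symbol_modality_32000_512/shared/weights'.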