Commit 8feddb2d by libei

support rpr model convert

parent ed1f4232
@@ -195,6 +195,16 @@ def load_param():
param_dict['enc_%d_out_trans_w' % layer_id] = output_tarns_w
param_dict['enc_%d_out_trans_b' % layer_id] = tf.reshape(output_trans_b, [1, -1])
# step 10. optional. used for relative position representation
if use_relative_position_representation:
rp_key = _get_tensor(
'body/encoder/layer_%d/encoder_self_attention/dot_product_attention_relative/relative_positions_keys/embeddings' % layer_id)
rp_value = _get_tensor(
'body/encoder/layer_%d/encoder_self_attention/dot_product_attention_relative/relative_positions_values/embeddings' % layer_id)
param_dict['enc_%d_rp_key' % layer_id] = rp_key
param_dict['enc_%d_rp_value' % layer_id] = rp_value
# step 3. layer normalization
layer_normal_bias = _get_tensor('body/encoder/layer_%d/layer_norm/layer_norm_bias' % layer_id)
layer_normal_scale = _get_tensor('body/encoder/layer_%d/layer_norm/layer_norm_scale' % layer_id)
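Note on the two tensors loaded above: in tensor2tensor-style relative attention (dot_product_attention_relative), relative_positions_keys/embeddings and relative_positions_values/embeddings are lookup tables with one row per clipped relative offset, i.e. shape [2 * max_relative_length + 1, head_dim]. A minimal self-contained sketch of what such a table looks like; the variable names and values below are illustrative, not read from this checkpoint:

import numpy as np

max_relative_length = 16   # assumed hyper-parameter, not taken from this script
head_dim = 64              # assumed per-head key/value dimension
# one row per relative offset in [-max_relative_length, +max_relative_length]
rp_key_table = np.zeros((2 * max_relative_length + 1, head_dim))
rp_value_table = np.zeros((2 * max_relative_length + 1, head_dim))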
@@ -273,6 +283,14 @@ def load_param():
param_dict['dec_%d_decatt_out_trans_w' % layer_id] = output_tarns_w
param_dict['dec_%d_decatt_out_trans_b' % layer_id] = tf.reshape(output_trans_b, [1, -1])
if use_relative_position_representation:
rp_key = _get_tensor(
'body/decoder/layer_%d/decoder_self_attention/dot_product_attention_relative/relative_positions_keys/embeddings' % layer_id)
rp_value = _get_tensor(
'body/decoder/layer_%d/decoder_self_attention/dot_product_attention_relative/relative_positions_values/embeddings' % layer_id)
param_dict['dec_%d_rp_key' % layer_id] = rp_key
param_dict['dec_%d_rp_value' % layer_id] = rp_value
# layer normalization
layer_normal_bias = _get_tensor('body/decoder/layer_%d/layer_norm/layer_norm_bias' % layer_id)
layer_normal_scale = _get_tensor('body/decoder/layer_%d/layer_norm/layer_norm_scale' % layer_id)
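The decoder self-attention stores the same pair of tables. At attention time, relative position representations are looked up by clipped offset: for query position i and key position j, the offset j - i is clipped to [-max_relative_length, max_relative_length] and shifted to a non-negative row index. A small numpy sketch of that index computation (illustrative only, not code from this converter):

import numpy as np

def relative_position_indices(length, max_relative_length):
    # offsets j - i for every (query i, key j) pair, clipped and shifted into
    # [0, 2 * max_relative_length] so they can index the embedding tables
    offsets = np.arange(length)[None, :] - np.arange(length)[:, None]
    clipped = np.clip(offsets, -max_relative_length, max_relative_length)
    return clipped + max_relative_length

# relative_position_indices(4, 2) ->
# [[2 3 4 4]
#  [1 2 3 4]
#  [0 1 2 3]
#  [0 0 1 2]]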
@@ -450,6 +468,10 @@ def convert_settings(settings):
if share_decoder_input_and_softmax:
args['share_decoder_input_output_embed'] = True
if use_relative_position_representation:
args['max_relative_length'] = int(settings['max_relative_length'])
args['arch'] = 'relative_transformer'
return argparse.Namespace(**args)
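For reference, a hedged sketch of the namespace this branch produces; apart from max_relative_length and arch, which the diff sets explicitly, the other keys and values below are assumptions for illustration:

import argparse

settings = {'max_relative_length': '16'}           # assumed settings entry
args = {'share_decoder_input_output_embed': True}  # assumed result of the earlier branch
args['max_relative_length'] = int(settings['max_relative_length'])
args['arch'] = 'relative_transformer'
converted = argparse.Namespace(**args)
print(converted.arch, converted.max_relative_length)  # relative_transformer 16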
@@ -509,6 +531,9 @@ def convert_param():
model['%s.in_proj_bias' % p1] = _get_param_numpy('enc_%d_qvk_trans_b' % p2, sess, transpose=True)
model['%s.out_proj.weight' % p1] = _get_param_numpy('enc_%d_out_trans_w' % p2, sess, transpose=True)
model['%s.out_proj.bias' % p1] = _get_param_numpy('enc_%d_out_trans_b' % p2, sess, transpose=True)
if use_relative_position_representation:
model['%s.relative_position_keys' % p1] = _get_param_numpy('enc_%d_rp_key' % p2, sess, transpose=False)
model['%s.relative_position_values' % p1] = _get_param_numpy('enc_%d_rp_value' % p2, sess, transpose=False)
p1 = 'encoder.layers.%d' % layer_id
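_get_param_numpy itself is not part of this hunk; a plausible sketch of such a helper, assuming the script keeps TensorFlow tensors in a module-level param_dict, fetches them through a session, and transposes when TF's [in, out] weight layout has to match PyTorch's [out, in] Linear layout (the fake session below only exists to keep the sketch executable on its own):

import numpy as np

param_dict = {'enc_0_out_trans_w': np.arange(6).reshape(2, 3)}  # stand-in for the real dict

class _FakeSession:
    def run(self, fetch):
        # the real script would call tf.Session.run on a graph tensor here
        return fetch

def _get_param_numpy(name, sess, transpose=False):
    value = np.asarray(sess.run(param_dict[name]))
    return value.T if transpose else value

print(_get_param_numpy('enc_0_out_trans_w', _FakeSession(), transpose=True).shape)  # (3, 2)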
@@ -567,6 +592,9 @@ def convert_param():
model['%s.in_proj_bias' % p1] = _get_param_numpy('dec_%d_decatt_qvk_trans_b' % p2, sess, transpose=True)
model['%s.out_proj.weight' % p1] = _get_param_numpy('dec_%d_decatt_out_trans_w' % p2, sess, transpose=True)
model['%s.out_proj.bias' % p1] = _get_param_numpy('dec_%d_decatt_out_trans_b' % p2, sess, transpose=True)
if use_relative_position_representation:
model['%s.relative_position_keys' % p1] = _get_param_numpy('dec_%d_rp_key' % p2, sess, transpose=False)
model['%s.relative_position_values' % p1] = _get_param_numpy('dec_%d_rp_value' % p2, sess, transpose=False)
p1 = 'decoder.layers.%d.encoder_attn' % layer_id
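Once every entry, including the new relative_position_keys / relative_position_values tensors, has been written into model, a converter of this kind typically turns the numpy arrays into torch tensors and saves them next to the converted args. A hedged sketch; the fairseq-style checkpoint layout below is an assumption, not something shown in this diff:

import collections
import numpy as np
import torch

def save_checkpoint(model_numpy, args, path):
    # convert each numpy parameter to a torch tensor and store it together with
    # the converted hyper-parameters (fairseq-style {'args': ..., 'model': ...})
    state_dict = collections.OrderedDict(
        (name, torch.from_numpy(np.ascontiguousarray(value)))
        for name, value in model_numpy.items()
    )
    torch.save({'args': args, 'model': state_dict}, path)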