Commit 66913b25 by libei

Support training the deep DLA model with relative position representation by setting attention_type = "relative_dot_product"
parent d6b1aadf
@@ -496,7 +496,6 @@ def _generate_relative_positions_matrix(length, max_relative_position,
 def _generate_relative_positions_embeddings(length, depth,
                                             max_relative_position, name,
-                                            debug_flag=None,
                                             cache=False):
   """Generates tensor of size [1 if cache else length, length, depth]."""
   with tf.variable_scope(name):
@@ -506,14 +505,11 @@ def _generate_relative_positions_embeddings(length, depth,
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
-    if debug_flag is not None:
-      embeddings_table = tf.Print(embeddings_table, [embeddings_table], summarize=100, message="%s-rp-emb"%debug_flag)
-      relative_positions_matrix = tf.Print(relative_positions_matrix, [relative_positions_matrix], summarize=100, message="%s-rel-pos"%debug_flag)
     embeddings = tf.gather(embeddings_table, relative_positions_matrix)
     return embeddings
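For reference, the index matrix gathered from embeddings_table above follows the clipping scheme of Shaw et al. (2018). _generate_relative_positions_matrix itself is not shown in this diff, so the NumPy sketch below is only a hedged reconstruction of what it produces.

# Sketch (not part of the commit): relative-position indices with clipping.
import numpy as np

def relative_positions_matrix_sketch(length, max_relative_position):
  rng = np.arange(length)
  distance = rng[None, :] - rng[:, None]          # (length, length), entry [i, j] = j - i
  clipped = np.clip(distance, -max_relative_position, max_relative_position)
  return clipped + max_relative_position          # shift into [0, 2 * max_relative_position]

# relative_positions_matrix_sketch(4, 2) ->
# [[2 3 4 4]
#  [1 2 3 4]
#  [0 1 2 3]
#  [0 0 1 2]]
# tf.gather(embeddings_table, indices) then yields the (length, length, depth) tensor
# described in the docstring.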
-def _relative_attention_inner(x, y, z, transpose, debug_flag=None):
+def _relative_attention_inner(x, y, z, transpose):
   """Relative position-aware dot-product attention inner calculation.

   This batches matrix multiply calculations to avoid unnecessary broadcasting.
@@ -537,26 +533,16 @@ def _relative_attention_inner(x, y, z, transpose, debug_flag=None):
   # xy_matmul is [batch_size, heads, length, length or depth]
   xy_matmul = tf.matmul(x, y, transpose_b=transpose)
-  if debug_flag is not None:
-    xy_matmul_t = tf.transpose(xy_matmul, [1,0,2,3])
-    xy_matmul = tf.Print(xy_matmul, [xy_matmul_t], message="c1", summarize=100)
   # x_t is [length, batch_size, heads, length or depth]
   x_t = tf.transpose(x, [2, 0, 1, 3])
   # x_t_r is [length, batch_size * heads, length or depth]
   x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
-  if debug_flag is not None:
-    a = tf.transpose(x_t, [0, 2, 1, 3])
-    x_t_r = tf.Print(x_t_r, [tf.reshape(a, [length, heads*batch_size,-1])], summarize=100, message="x_t_r")
-    z = tf.Print(z, [z], summarize=100, message="z")
   # x_tz_matmul is [length, batch_size * heads, length or depth]
   x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
   # x_tz_matmul_r is [length, batch_size, heads, length or depth]
   x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
   # x_tz_matmul_r_t is [batch_size, heads, length, length or depth]
   x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
-  if debug_flag is not None:
-    tmp = tf.transpose(x_tz_matmul_r_t, [1,0,2,3])
-    x_tz_matmul_r_t = tf.Print(x_tz_matmul_r_t, [tmp], summarize=100, message="c2")
   return xy_matmul + x_tz_matmul_r_t


 def dot_product_attention_relative(q,
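The batched matmul trick in _relative_attention_inner above avoids materializing a (batch, heads, length, length, depth) broadcast. As an illustration only (not part of the commit), it is equivalent to the following NumPy einsum reference:

import numpy as np

def relative_attention_inner_ref(x, y, z, transpose):
  """NumPy reference for _relative_attention_inner (illustration only)."""
  if transpose:
    # logits[b, h, i, j] = sum_d x[b, h, i, d] * (y[b, h, j, d] + z[i, j, d])
    return np.einsum("bhid,bhjd->bhij", x, y) + np.einsum("bhid,ijd->bhij", x, z)
  # out[b, h, i, d] = sum_j x[b, h, i, j] * (y[b, h, j, d] + z[i, j, d])
  return np.einsum("bhij,bhjd->bhid", x, y) + np.einsum("bhij,ijd->bhid", x, z)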
@@ -569,7 +555,6 @@ def dot_product_attention_relative(q,
                                    name=None,
                                    make_image_summary=False,
                                    dropout_broadcast_dims=None,
-                                   debug_flag=None,
                                    cache=False):
   """Calculate relative position-aware dot-product self-attention.
@@ -613,23 +598,15 @@
     # libei: relations_keys: (L, L, H), where H is hidden size of head
     relations_keys = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_keys",
-        debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None,
         cache=cache)
     relations_values = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_values",
-        debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None,
         cache=cache)
-    if debug_flag is not None:
-      q = tf.Print(q, [q], summarize=100, message="%s+q" % debug_flag)
-      k = tf.Print(k, [k], summarize=100, message="%s+k" % debug_flag)
-      relations_keys = tf.Print(relations_keys, [relations_keys], summarize=1000, message="%s+rk" % debug_flag)
     # Compute self attention considering the relative position embeddings.
-    logits = _relative_attention_inner(q, k, relations_keys, True, debug_flag=debug_flag)
+    logits = _relative_attention_inner(q, k, relations_keys, True)
-    if debug_flag is not None:
-      logits = tf.Print(logits, [logits], summarize=100, message="%s+logits"%debug_flag)
     if bias is not None:
       logits += bias
@@ -639,10 +616,7 @@
     if not tf.get_variable_scope().reuse and make_image_summary:
       attention_image_summary(weights, image_shapes)
-    if debug_flag is not None:
-      weights = tf.Print(weights, [weights], summarize=100, message="%s+weight"%debug_flag)
-      v = tf.Print(v, [v], summarize=100, message="%s+v" % debug_flag)
-      relations_values = tf.Print(relations_values, [relations_values], summarize=100, message="%s+relations_values" % debug_flag)
     return _relative_attention_inner(weights, v, relations_values, False)
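Putting the two _relative_attention_inner calls together, the computation above follows the relative position representations of Shaw et al. (2018). The sketch below (reusing relative_attention_inner_ref from the earlier sketch, dropout omitted) is only meant to make the data flow explicit; it is not the library code.

import numpy as np

def dot_product_attention_relative_ref(q, k, v, relations_keys, relations_values, bias=None):
  """NumPy data-flow reference for dot_product_attention_relative (illustration only).

  q, k, v:                           (batch, heads, length, depth_per_head)
  relations_keys, relations_values:  (length, length, depth_per_head)
  """
  logits = relative_attention_inner_ref(q, k, relations_keys, transpose=True)
  if bias is not None:
    logits = logits + bias
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights = weights / weights.sum(axis=-1, keepdims=True)      # softmax over the key axis
  return relative_attention_inner_ref(weights, v, relations_values, transpose=False)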
@@ -658,8 +632,7 @@ def multihead_attention_relative_pos(query_antecedent,
                                      summaries=False,
                                      image_shapes=None,
                                      cache=None,
-                                     name=None,
-                                     debug_flag=None):
+                                     name=None):
   """Multihead scaled-dot-product attention with input/output transformations.

   Args:
@@ -720,18 +693,12 @@
     q *= key_depth_per_head**-0.5
     x = dot_product_attention_relative(
-        q, k, v, bias, max_relative_position, dropout_rate, summaries, image_shapes, debug_flag=debug_flag
-    )
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="att-out")
+        q, k, v, bias, max_relative_position, dropout_rate, summaries, image_shapes)
     x = combine_heads(x)
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="att-out-merge")
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="final-att-out (after trans-out)")
     return x


 def ffn_self_attention_layer(x,
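The wrapper above only adds the max_relative_position plumbing; the head bookkeeping is the usual tensor2tensor split_heads/combine_heads pattern. The NumPy sketch below illustrates that reshaping (it is not the library code):

import numpy as np

def split_heads_ref(x, num_heads):
  """(batch, length, depth) -> (batch, num_heads, length, depth // num_heads)."""
  b, l, d = x.shape
  return x.reshape(b, l, num_heads, d // num_heads).transpose(0, 2, 1, 3)

def combine_heads_ref(x):
  """(batch, num_heads, length, depth_per_head) -> (batch, length, num_heads * depth_per_head)."""
  b, h, l, d = x.shape
  return x.transpose(0, 2, 1, 3).reshape(b, l, h * d)

# In multihead_attention_relative_pos, q is additionally scaled by
# key_depth_per_head ** -0.5 before dot_product_attention_relative is called,
# and the combined heads go through the 1x1 "output_transform" convolution.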
@@ -495,4 +495,13 @@ def transformer_dla_base_v2():
   hparams = transformer_dla_base()
   hparams.learning_rate = 0.4 * (2**0.5)
   hparams.learning_rate_warmup_steps = 16000
   return hparams
+
+
+@registry.register_hparams
+def transformer_dla_rpr_base25():
+  # deep (25-layer encoder) DLA model with relative position representation
+  hparams = transformer_dla_base()
+  hparams.encoder_layers = 25
+  hparams.attention_type = "relative_dot_product"
+  return hparams
\ No newline at end of file
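A quick way to sanity-check the new hparams set is to call it directly. In the snippet below the import path is only an assumption about where the DLA hparams live in this repository; the registered name and the two fields come from the hunk above. Training would then select this set through the trainer's hparams_set flag (the exact CLI depends on this fork).

# Illustrative check only; the module path is a placeholder for wherever
# transformer_dla_rpr_base25 is actually defined in this repository.
from tensor2tensor.models.transformer_dla import transformer_dla_rpr_base25

hparams = transformer_dla_rpr_base25()
assert hparams.encoder_layers == 25
assert hparams.attention_type == "relative_dot_product"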
@@ -151,7 +151,6 @@ def transformer_encoder(encoder_input,
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       #debug_flag = "%s+layer%d" %(name,layer) if layer == 0 else None
-      debug_flag = None
       with tf.variable_scope("layer_%d" % layer):
         x = residual_fn(
             x,
@@ -166,8 +165,7 @@ def transformer_encoder(encoder_input,
                 hparams.attention_dropout,
                 hparams.max_relative_length,
                 summaries=False,
-                name="encoder_self_attention",
-                debug_flag=debug_flag))
+                name="encoder_self_attention"))
         #x = tf.Print(x, [x], message="%s+layer%d+++before-fnn" % (name, layer), summarize=100)
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
         #x = tf.Print(x, [x], message="%s+layer%d+++out"%(name,layer), summarize=100)
@@ -219,8 +217,7 @@ def transformer_decoder(decoder_input,
                 hparams.attention_dropout,
                 hparams.max_relative_length,
                 summaries=False,
-                name="decoder_self_attention",
-                debug_flag=debug_flag))
+                name="decoder_self_attention"))
         x = residual_fn(
             x,
             common_attention.multihead_attention(
@@ -235,8 +232,6 @@
                 summaries=False,
                 name="encdec_attention"))
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
-        if debug_flag is not None:
-          x = tf.Print(x, [x], message="%s+layer%d+++out"%(name,layer), summarize=100)
   return x
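The encoder and decoder hunks above hard-wire multihead_attention_relative_pos with hparams.max_relative_length. The dispatch on hparams.attention_type = "relative_dot_product" happens elsewhere in the model and is not part of this diff; a hedged sketch of how such a switch typically looks (argument order copied from the calls above, assuming transformer.py's existing imports) is:

# Hedged sketch; the actual selection logic is not shown in this commit.
def self_attention_layer(x, bias, hparams, name):
  if hparams.attention_type == "relative_dot_product":
    return common_attention.multihead_attention_relative_pos(
        x, None, bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout,
        hparams.max_relative_length, summaries=False, name=name)
  return common_attention.multihead_attention(
      x, None, bias,
      hparams.attention_key_channels or hparams.hidden_size,
      hparams.attention_value_channels or hparams.hidden_size,
      hparams.hidden_size, hparams.num_heads, hparams.attention_dropout,
      summaries=False, name=name)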