Commit 66913b25 by libei

support training deep dla model with relative position representation by setting attention_type = "relative_dot_product"
parent d6b1aadf
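
For reference, the hyper-parameter set added at the bottom of this diff is what turns the feature on. A minimal sketch of the equivalent manual configuration (assuming the tensor2tensor-style hparams registry this repository uses; transformer_dla_base is the existing deep DLA baseline shown in the diff):

    # Sketch only: what transformer_dla_rpr_base25 (registered below) amounts to.
    hparams = transformer_dla_base()                  # deep DLA baseline hparams
    hparams.encoder_layers = 25                       # 25-layer encoder
    hparams.attention_type = "relative_dot_product"   # relative-position self-attention
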
@@ -496,7 +496,6 @@ def _generate_relative_positions_matrix(length, max_relative_position,
 def _generate_relative_positions_embeddings(length, depth,
                                             max_relative_position, name,
-                                            debug_flag=None,
                                             cache=False):
   """Generates tensor of size [1 if cache else length, length, depth]."""
   with tf.variable_scope(name):
@@ -506,14 +505,11 @@ def _generate_relative_positions_embeddings(length, depth,
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
-    if debug_flag is not None:
-      embeddings_table = tf.Print(embeddings_table, [embeddings_table], summarize=100, message="%s-rp-emb"%debug_flag)
-      relative_positions_matrix = tf.Print(relative_positions_matrix, [relative_positions_matrix], summarize=100, message="%s-rel-pos"%debug_flag)
     embeddings = tf.gather(embeddings_table, relative_positions_matrix)
     return embeddings


-def _relative_attention_inner(x, y, z, transpose, debug_flag=None):
+def _relative_attention_inner(x, y, z, transpose):
   """Relative position-aware dot-product attention inner calculation.

   This batches matrix multiply calculations to avoid unnecessary broadcasting.
@@ -537,26 +533,16 @@ def _relative_attention_inner(x, y, z, transpose, debug_flag=None):
   # xy_matmul is [batch_size, heads, length, length or depth]
   xy_matmul = tf.matmul(x, y, transpose_b=transpose)
-  if debug_flag is not None:
-    xy_matmul_t = tf.transpose(xy_matmul, [1,0,2,3])
-    xy_matmul = tf.Print(xy_matmul, [xy_matmul_t], message="c1", summarize=100)
   # x_t is [length, batch_size, heads, length or depth]
   x_t = tf.transpose(x, [2, 0, 1, 3])
   # x_t_r is [length, batch_size * heads, length or depth]
   x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
-  if debug_flag is not None:
-    a = tf.transpose(x_t, [0, 2, 1, 3])
-    x_t_r = tf.Print(x_t_r, [tf.reshape(a, [length, heads*batch_size,-1])], summarize=100, message="x_t_r")
-    z = tf.Print(z, [z], summarize=100, message="z")
   # x_tz_matmul is [length, batch_size * heads, length or depth]
   x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
   # x_tz_matmul_r is [length, batch_size, heads, length or depth]
   x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
   # x_tz_matmul_r_t is [batch_size, heads, length, length or depth]
   x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
-  if debug_flag is not None:
-    tmp = tf.transpose(x_tz_matmul_r_t, [1,0,2,3])
-    x_tz_matmul_r_t = tf.Print(x_tz_matmul_r_t, [tmp], summarize=100, message="c2")
   return xy_matmul + x_tz_matmul_r_t


 def dot_product_attention_relative(q,
@@ -569,7 +555,6 @@ def dot_product_attention_relative(q,
                                    name=None,
                                    make_image_summary=False,
                                    dropout_broadcast_dims=None,
-                                   debug_flag=None,
                                    cache=False):
   """Calculate relative position-aware dot-product self-attention.
@@ -613,23 +598,15 @@ def dot_product_attention_relative(q,
     # libei: relations_keys: (L, L, H), where H is hidden size of head
     relations_keys = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_keys",
-        debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None,
         cache=cache)
     relations_values = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_values",
-        debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None,
         cache=cache)
-    if debug_flag is not None:
-      q = tf.Print(q, [q], summarize=100, message="%s+q" % debug_flag)
-      k = tf.Print(k, [k], summarize=100, message="%s+k" % debug_flag)
-      relations_keys = tf.Print(relations_keys, [relations_keys], summarize=1000, message="%s+rk" % debug_flag)

     # Compute self attention considering the relative position embeddings.
-    logits = _relative_attention_inner(q, k, relations_keys, True, debug_flag=debug_flag)
-    if debug_flag is not None:
-      logits = tf.Print(logits, [logits], summarize=100, message="%s+logits"%debug_flag)
+    logits = _relative_attention_inner(q, k, relations_keys, True)
     if bias is not None:
       logits += bias
@@ -639,10 +616,7 @@ def dot_product_attention_relative(q,
     if not tf.get_variable_scope().reuse and make_image_summary:
       attention_image_summary(weights, image_shapes)
-    if debug_flag is not None:
-      weights = tf.Print(weights, [weights], summarize=100, message="%s+weight"%debug_flag)
-      v = tf.Print(v, [v], summarize=100, message="%s+v" % debug_flag)
-      relations_values = tf.Print(relations_values, [relations_values], summarize=100, message="%s+relations_values" % debug_flag)
     return _relative_attention_inner(weights, v, relations_values, False)
@@ -658,8 +632,7 @@ def multihead_attention_relative_pos(query_antecedent,
                                      summaries=False,
                                      image_shapes=None,
                                      cache=None,
-                                     name=None,
-                                     debug_flag=None):
+                                     name=None):
   """Multihead scaled-dot-product attention with input/output transformations.

   Args:
@@ -720,18 +693,12 @@ def multihead_attention_relative_pos(query_antecedent,
     q *= key_depth_per_head**-0.5
     x = dot_product_attention_relative(
-        q, k, v, bias, max_relative_position, dropout_rate, summaries, image_shapes, debug_flag=debug_flag
-    )
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="att-out")
+        q, k, v, bias, max_relative_position, dropout_rate, summaries, image_shapes)
     x = combine_heads(x)
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="att-out-merge")
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
-    if debug_flag is not None:
-      x = tf.Print(x, [x], summarize=100, message="final-att-out (after trans-out)")
     return x


 def ffn_self_attention_layer(x,
...
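
Note on the helper edited above: _relative_attention_inner adds a relative-position term to the usual batched matmul by folding batch and heads into a single axis, so the [length, length, depth] embedding tensor z can be applied with one tf.matmul before being unfolded again. The following NumPy lines re-derive that shape algebra for the transpose=True (logits) case purely as an illustration of the diff; they are not code from this repository:

    import numpy as np

    B, H, L, D = 2, 4, 5, 8                     # batch, heads, length, depth per head
    q = np.random.randn(B, H, L, D)
    k = np.random.randn(B, H, L, D)
    rel_k = np.random.randn(L, L, D)            # relations_keys: one embedding per (query, key) position pair

    # Trick used in the code: fold (B, H), one batched matmul against rel_k, unfold.
    qk = np.matmul(q, k.transpose(0, 1, 3, 2))            # [B, H, L, L]
    q_t = q.transpose(2, 0, 1, 3).reshape(L, B * H, D)    # [L, B*H, D]
    q_rel = np.matmul(q_t, rel_k.transpose(0, 2, 1))      # [L, B*H, L]
    q_rel = q_rel.reshape(L, B, H, L).transpose(1, 2, 0, 3)
    logits = qk + q_rel                                    # [B, H, L, L]

    # Same quantity written directly with einsum, as a check.
    logits_ref = (np.einsum("bhid,bhjd->bhij", q, k)
                  + np.einsum("bhid,ijd->bhij", q, rel_k))
    assert np.allclose(logits, logits_ref)
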
@@ -495,4 +495,13 @@ def transformer_dla_base_v2():
   hparams = transformer_dla_base()
   hparams.learning_rate = 0.4 * (2**0.5)
   hparams.learning_rate_warmup_steps = 16000
+  return hparams
+
+
+@registry.register_hparams
+def transformer_dla_rpr_base25():
+  # share the decoder input and the softmax embedding
+  hparams = transformer_dla_base()
+  hparams.encoder_layers = 25
+  hparams.attention_type = "relative_dot_product"
   return hparams
\ No newline at end of file
@@ -151,7 +151,6 @@ def transformer_encoder(encoder_input,
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       #debug_flag = "%s+layer%d" %(name,layer) if layer == 0 else None
-      debug_flag = None
       with tf.variable_scope("layer_%d" % layer):
         x = residual_fn(
             x,
@@ -166,8 +165,7 @@ def transformer_encoder(encoder_input,
                 hparams.attention_dropout,
                 hparams.max_relative_length,
                 summaries=False,
-                name="encoder_self_attention",
-                debug_flag=debug_flag))
+                name="encoder_self_attention"))
         #x = tf.Print(x, [x], message="%s+layer%d+++before-fnn" % (name, layer), summarize=100)
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
         #x = tf.Print(x, [x], message="%s+layer%d+++out"%(name,layer), summarize=100)
@@ -219,8 +217,7 @@ def transformer_decoder(decoder_input,
                 hparams.attention_dropout,
                 hparams.max_relative_length,
                 summaries=False,
-                name="decoder_self_attention",
-                debug_flag=debug_flag))
+                name="decoder_self_attention"))
         x = residual_fn(
             x,
             common_attention.multihead_attention(
@@ -235,8 +232,6 @@ def transformer_decoder(decoder_input,
                 summaries=False,
                 name="encdec_attention"))
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
-        if debug_flag is not None:
-          x = tf.Print(x, [x], message="%s+layer%d+++out"%(name,layer), summarize=100)
   return x
...
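
One more note for readers of the first hunk: the gather in _generate_relative_positions_embeddings only needs a table of vocab_size = max_relative_position * 2 + 1 rows because relative distances are clipped to the window [-max_relative_position, max_relative_position] and shifted to be non-negative before the lookup. A small NumPy sketch of that indexing scheme (the standard Shaw-style construction, shown for illustration; not the repository's _generate_relative_positions_matrix):

    import numpy as np

    def relative_positions_matrix(length, max_relative_position):
        positions = np.arange(length)
        distance = positions[None, :] - positions[:, None]    # distance[i, j] = j - i
        clipped = np.clip(distance, -max_relative_position, max_relative_position)
        return clipped + max_relative_position                # indices in [0, 2 * max]

    print(relative_positions_matrix(5, 2))   # indexes an embedding table of 2 * 2 + 1 = 5 rows
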