Commit ad6cc8e3 by libei

support dot product relative fast decode

parent cc78cf5b
@@ -470,13 +470,18 @@ def multihead_attention(query_antecedent,
   x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
   return x
 
 
-def _generate_relative_positions_matrix(length, max_relative_position):
+def _generate_relative_positions_matrix(length, max_relative_position,
+                                        cache=False):
   """Generates matrix of relative positions between inputs."""
-  range_vec = tf.range(length)
-  range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
-  # range_mat = tf.Print(range_mat, [range_mat], message="range_mat:", summarize=100)
-  distance_mat = range_mat - tf.transpose(range_mat)
-  # distance_mat = tf.Print(distance_mat, [distance_mat], message="distance_mat:", summarize=100)
+  if not cache:
+    # training process
+    range_vec = tf.range(length)
+    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
+    # range_mat = tf.Print(range_mat, [range_mat], message="range_mat:", summarize=100)
+    distance_mat = range_mat - tf.transpose(range_mat)
+    # distance_mat = tf.Print(distance_mat, [distance_mat], message="distance_mat:", summarize=100)
+  else:
+    distance_mat = tf.expand_dims(tf.range(-length + 1, 1, 1), 0)
   distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                           max_relative_position)
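Note (illustration, not part of the commit): the cache=True branch produces exactly the last row of the training-time distance matrix, since during incremental decoding only the newest query position attends to the length cached positions. A minimal NumPy sketch of both branches, assuming the usual shift by max_relative_position that turns clipped distances into embedding ids (the final_mat step):

import numpy as np

def full_matrix(length, K):
    # training path: distance[i, j] = j - i, clipped to [-K, K], shifted to [0, 2K]
    r = np.arange(length)
    dist = r[None, :] - r[:, None]
    return np.clip(dist, -K, K) + K            # shape (length, length)

def cached_row(length, K):
    # fast-decode path: distances from the last position only
    dist = np.arange(-length + 1, 1)[None, :]  # shape (1, length)
    return np.clip(dist, -K, K) + K

print(full_matrix(5, 2)[-1])   # [0 0 0 1 2]
print(cached_row(5, 2)[0])     # [0 0 0 1 2]  -> identical to the last row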
@@ -490,12 +495,14 @@ def _generate_relative_positions_matrix(length, max_relative_position):
   return final_mat
 
 
 def _generate_relative_positions_embeddings(length, depth,
-                                            max_relative_position, name, debug_flag=None):
-  """Generates tensor of size [length, length, depth]."""
+                                            max_relative_position, name,
+                                            debug_flag=None,
+                                            cache=False):
+  """Generates tensor of size [1 if cache else length, length, depth]."""
   with tf.variable_scope(name):
-    # wq: relative_positions_matrix shape is (L, L), value range is 0 ~ 2K (total 2K+1)
+    # libei: relative_positions_matrix shape is (L, L), value range is 0 ~ 2K (total 2K+1)
     relative_positions_matrix = _generate_relative_positions_matrix(
-        length, max_relative_position)
+        length, max_relative_position, cache=cache)
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
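Note (illustrative shape check, not from the repository): gathering the (1, length) cache-mode position matrix from the [vocab_size, depth] embeddings table yields a (1, length, depth) tensor instead of the training-time (length, length, depth), which is what the updated docstring describes. A NumPy sketch with assumed toy sizes:

import numpy as np

length, depth, K = 5, 4, 2                      # toy sizes, assumed for illustration
table = np.random.randn(2 * K + 1, depth)       # "embeddings" table, vocab_size = 2K + 1

# training path: (length, length) ids -> (length, length, depth) embeddings
train_ids = np.clip(np.arange(length)[None, :] - np.arange(length)[:, None], -K, K) + K
# fast-decode path: (1, length) ids -> (1, length, depth) embeddings
decode_ids = np.clip(np.arange(-length + 1, 1)[None, :], -K, K) + K

print(table[train_ids].shape)                   # (5, 5, 4)
print(table[decode_ids].shape)                  # (1, 5, 4)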
@@ -562,7 +569,8 @@ def dot_product_attention_relative(q,
                                    name=None,
                                    make_image_summary=False,
                                    dropout_broadcast_dims=None,
-                                   debug_flag=None):
+                                   debug_flag=None,
+                                   cache=False):
   """Calculate relative position-aware dot-product self-attention.
 
   The attention calculation is augmented with learned representations for the
@@ -595,8 +603,9 @@ def dot_product_attention_relative(q,
   # This calculation only works for self attention.
   # q, k and v must therefore have the same shape.
   # wq: compatible means same shape
-  q.get_shape().assert_is_compatible_with(k.get_shape())
-  q.get_shape().assert_is_compatible_with(v.get_shape())
+  if not cache:
+    q.get_shape().assert_is_compatible_with(k.get_shape())
+    q.get_shape().assert_is_compatible_with(v.get_shape())
 
   # Use separate embeddings suitable for keys and values.
   depth = q.get_shape().as_list()[3]
@@ -604,10 +613,12 @@ def dot_product_attention_relative(q,
   # wq: relations_keys: (L, L, H), where H is hidden size of head
   relations_keys = _generate_relative_positions_embeddings(
       length, depth, max_relative_position, "relative_positions_keys",
-      debug_flag=debug_flag + "+rp-key" if debug_flag is not None else None)
+      debug_flag=debug_flag + "+rp-key" if debug_flag is not None else None,
+      cache=cache)
   relations_values = _generate_relative_positions_embeddings(
       length, depth, max_relative_position, "relative_positions_values",
-      debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None)
+      debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None,
+      cache=cache)
   if debug_flag is not None:
     q = tf.Print(q, [q], summarize=100, message="%s+q" % debug_flag)
...
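Note (hypothetical shapes, not the repository's _relative_attention_inner): the compatibility asserts are skipped under cache=True because at a decode step q holds only the current position while k and v hold the full cached prefix, so their static shapes legitimately differ. A NumPy sketch of the resulting attention-logit shapes:

import numpy as np

# Assumed sizes for illustration: batch=2, heads=8, depth=64, 7 positions cached.
B, H, L, D = 2, 8, 7, 64
q = np.random.randn(B, H, 1, D)            # query for the newest position only
k = np.random.randn(B, H, L, D)            # cached keys for the whole prefix
relations_keys = np.random.randn(1, L, D)  # (1, L, D) embeddings from the cache branch

content = np.einsum("bhqd,bhkd->bhqk", q, k)               # ordinary q.k^T term
relative = np.einsum("bhqd,qkd->bhqk", q, relations_keys)  # relative-position term
logits = content + relative
print(logits.shape)                        # (2, 8, 1, 7)

Because the relative-position term only needs the (1, length) row, each decode step costs O(length) for the position bias instead of rebuilding an O(length^2) matrix.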