Commit ad6cc8e3 by libei

support dot product relative fast decode

parent cc78cf5b
@@ -470,13 +470,18 @@ def multihead_attention(query_antecedent,
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x


-def _generate_relative_positions_matrix(length, max_relative_position):
+def _generate_relative_positions_matrix(length, max_relative_position,
+                                        cache=False):
   """Generates matrix of relative positions between inputs."""
-  range_vec = tf.range(length)
-  range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
-  #range_mat = tf.Print(range_mat, [range_mat], message="range_mat:", summarize=100)
-  distance_mat = range_mat - tf.transpose(range_mat)
-  #distance_mat = tf.Print(distance_mat, [distance_mat], message="distance_mat:", summarize=100)
+  if not cache:
+    # training: build the full (length, length) matrix of pairwise offsets
+    range_vec = tf.range(length)
+    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
+    # range_mat = tf.Print(range_mat, [range_mat], message="range_mat:", summarize=100)
+    distance_mat = range_mat - tf.transpose(range_mat)
+    # distance_mat = tf.Print(distance_mat, [distance_mat], message="distance_mat:", summarize=100)
+  else:
+    distance_mat = tf.expand_dims(tf.range(-length+1, 1, 1), 0)
   distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                           max_relative_position)
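A note on what the cached branch computes: in the training branch, entry (i, j) of distance_mat is j - i, so the row for the last query position is [-length+1, ..., -1, 0]. With cache=True only that single row is needed, since incremental decoding attends from one new position to everything generated so far. A minimal NumPy sketch (not part of the patch; names are illustrative) that checks the equivalence:

import numpy as np

def full_distance_matrix(length):
  # training-time matrix: entry (i, j) = j - i
  range_vec = np.arange(length)
  range_mat = np.tile(range_vec, (length, 1))
  return range_mat - range_mat.T

def cached_distance_row(length):
  # decode-time row: offsets of every earlier position relative to the newest one
  return np.arange(-length + 1, 1, 1)[None, :]   # shape (1, length)

length, max_rel = 5, 2
full = np.clip(full_distance_matrix(length), -max_rel, max_rel)
row = np.clip(cached_distance_row(length), -max_rel, max_rel)
assert (full[-1:] == row).all()   # cached row == last row of the full matrix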
@@ -490,12 +495,14 @@ def _generate_relative_positions_matrix(length, max_relative_position):
   return final_mat

 def _generate_relative_positions_embeddings(length, depth,
-                                            max_relative_position, name, debug_flag=None):
-  """Generates tensor of size [length, length, depth]."""
+                                            max_relative_position, name,
+                                            debug_flag=None,
+                                            cache=False):
+  """Generates tensor of size [1 if cache else length, length, depth]."""
   with tf.variable_scope(name):
-    # wq: relative_positions_matrix shape is (L, L), value range is 0 ~ 2K (total 2K+1)
+    # libei: relative_positions_matrix shape is (L, L), or (1, L) when cache=True; values range over 0 ~ 2K (2K+1 total)
     relative_positions_matrix = _generate_relative_positions_matrix(
-        length, max_relative_position)
+        length, max_relative_position, cache=cache)
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
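With cache=True the matrix fed into this lookup has shape (1, length) instead of (length, length), so the lookup on embeddings_table (presumably a tf.gather; that part of the function sits outside this hunk) returns (1, length, depth) rather than (length, length, depth), which is what the updated docstring states. A NumPy stand-in for the gather with made-up sizes, just to show the shapes:

import numpy as np

max_relative_position, depth, length = 2, 4, 5
vocab_size = max_relative_position * 2 + 1          # 2K + 1 relative offsets
embeddings_table = np.random.randn(vocab_size, depth)

train_ids = np.random.randint(0, vocab_size, size=(length, length))   # (L, L)
decode_ids = np.random.randint(0, vocab_size, size=(1, length))       # (1, L)

print(embeddings_table[train_ids].shape)    # (5, 5, 4) -> (length, length, depth)
print(embeddings_table[decode_ids].shape)   # (1, 5, 4) -> (1, length, depth)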
@@ -562,7 +569,8 @@ def dot_product_attention_relative(q,
                                    name=None,
                                    make_image_summary=False,
                                    dropout_broadcast_dims=None,
-                                   debug_flag=None):
+                                   debug_flag=None,
+                                   cache=False):
   """Calculate relative position-aware dot-product self-attention.

   The attention calculation is augmented with learned representations for the
@@ -595,8 +603,9 @@ def dot_product_attention_relative(q,
     # This calculation only works for self attention.
     # q, k and v must therefore have the same shape.
     # wq: compatible means same shape
-    q.get_shape().assert_is_compatible_with(k.get_shape())
-    q.get_shape().assert_is_compatible_with(v.get_shape())
+    if not cache:
+      q.get_shape().assert_is_compatible_with(k.get_shape())
+      q.get_shape().assert_is_compatible_with(v.get_shape())

     # Use separate embeddings suitable for keys and values.
     depth = q.get_shape().as_list()[3]
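Skipping the assertion under cache=True matches how incremental decoding drives self-attention: the query covers only the newest time step while the cached keys and values grow by one position per step, so q, k and v no longer share a shape. The shapes below are illustrative (batch and head sizes are assumptions, not taken from the patch):

import numpy as np

B, H, D, t = 2, 4, 8, 6          # batch, heads, head depth, steps decoded so far
q = np.zeros((B, H, 1, D))       # only the newest query position
k = np.zeros((B, H, t, D))       # cached keys
v = np.zeros((B, H, t, D))       # cached values

print(q.shape == k.shape)                    # False -> the training-time assertion cannot hold
print((q @ np.swapaxes(k, -1, -2)).shape)    # (2, 4, 1, 6): one logits row per head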
@@ -604,10 +613,12 @@ def dot_product_attention_relative(q,
     # wq: relations_keys: (L, L, H), where H is hidden size of head
     relations_keys = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_keys",
-        debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None)
+        debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None,
+        cache=cache)
     relations_values = _generate_relative_positions_embeddings(
         length, depth, max_relative_position, "relative_positions_values",
-        debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None)
+        debug_flag=debug_flag + "+rp-value" if debug_flag is not None else None,
+        cache=cache)

     if debug_flag is not None:
       q = tf.Print(q, [q], summarize=100, message="%s+q" % debug_flag)
......
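Taken together, one decode step of the relative-position attention these functions implement reduces to a single query row: the content term q·kᵀ and the position term q·a_kᵀ are summed, softmaxed over the cached positions, and the weights are applied to v plus the value-side position embeddings. A one-head NumPy sketch under those assumptions (the usual query scaling and attention bias are omitted; array names are illustrative):

import numpy as np

def relative_attention_step(q, k, v, a_key, a_value):
  # q: (depth,)                 query for the newest position
  # k, v: (t, depth)            cached keys / values
  # a_key, a_value: (t, depth)  relative-position embeddings for offsets -t+1 .. 0
  logits = k @ q + a_key @ q              # content term + relative-position term
  weights = np.exp(logits - logits.max())
  weights /= weights.sum()                # softmax over the t cached positions
  return weights @ (v + a_value)          # (depth,)

t, depth = 4, 8
rng = np.random.default_rng(0)
out = relative_attention_step(rng.standard_normal(depth),
                              rng.standard_normal((t, depth)),
                              rng.standard_normal((t, depth)),
                              rng.standard_normal((t, depth)),
                              rng.standard_normal((t, depth)))
print(out.shape)   # (8,)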