Commit c31fd16c by libei

add conv-self attention

parent 868f56a6
@@ -382,7 +382,6 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         attention_type="dot_product",
                         max_relative_length=16,
-                        window_size=-1,
                         conv_mask=None,
                         summaries=False,
                         image_shapes=None,
@@ -467,7 +466,7 @@ def multihead_attention(query_antecedent,
           q, k, v, bias, max_relative_length, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
     elif attention_type == "convolutional_self_attention":
       x = convolutional_self_attention(
-          q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims, window_size=window_size, conv_mask=conv_mask)
+          q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims, conv_mask=conv_mask)
     else:
       raise ValueError
@@ -852,7 +851,6 @@ def convolutional_self_attention(q,
                                  summaries=False,
                                  image_shapes=None,
                                  dropout_broadcast_dims=None,
-                                 window_size=-1,
                                  conv_mask=None,
                                  name=None):
   """dot-product attention.
@@ -877,7 +875,7 @@
     logits = tf.matmul(q, k, transpose_b=True)
     if bias is not None:
       logits += bias
-    if window_size != -1:
+    if conv_mask is not None:
       logits += conv_mask
     weights = tf.nn.softmax(logits, name="attention_weights")
     # broadcast dropout can save memory
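
The change replaces the `window_size != -1` switch with an explicit `conv_mask` bias that the caller adds to the attention logits. As a rough illustration only (not code from this commit), a bias of the expected shape could be precomputed as below; the helper name, `length`, and `window_size` are hypothetical:

# Illustrative sketch, not part of this commit: one way to precompute a
# conv_mask so that `logits += conv_mask` restricts each query position to a
# local window of key positions.
import tensorflow as tf

def build_conv_mask(length, window_size, dtype=tf.float32):
  # |i - j| for every query position i and key position j
  positions = tf.range(length)
  distance = tf.abs(tf.expand_dims(positions, 1) - tf.expand_dims(positions, 0))
  # 0 inside the window, a large negative bias outside, so the softmax
  # drives out-of-window attention weights to zero
  mask = tf.cast(distance > window_size, dtype) * -1e9
  # shape [1, 1, length, length] broadcasts over batch and heads
  return tf.reshape(mask, [1, 1, length, length])

A mask built this way would be passed in by the caller as `conv_mask`, matching the `logits += conv_mask` line in the hunk above.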