Commit 868f56a6 by libeineu

add convolutional self attention

parent 8a88d1b1
@@ -382,6 +382,8 @@ def multihead_attention(query_antecedent,
dropout_rate,
attention_type="dot_product",
max_relative_length=16,
window_size=-1,
conv_mask=None,
summaries=False,
image_shapes=None,
cache=None,
@@ -463,6 +465,9 @@ def multihead_attention(query_antecedent,
elif attention_type == "relative_dot_product":
x = dot_product_attention_relative(
q, k, v, bias, max_relative_length, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
elif attention_type == "convolutional_self_attention":
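# Convolutional self attention: dot-product attention whose logits are
# additionally masked to a local window around each query position.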
x = convolutional_self_attention(
q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims, window_size=window_size, conv_mask=conv_mask)
else:
raise ValueError
@@ -837,3 +842,48 @@ def parameter_attention(x,
y.set_shape([None, None, total_value_depth])
y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
return y
def convolutional_self_attention(q,
k,
v,
bias,
dropout_rate=0.0,
summaries=False,
image_shapes=None,
dropout_broadcast_dims=None,
window_size=-1,
conv_mask=None,
name=None):
"""dot-product attention.
Args:
q: a Tensor with shape [batch, heads, length_q, depth_k]
k: a Tensor with shape [batch, heads, length_kv, depth_k]
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
Returns:
A Tensor.
"""
with tf.variable_scope(
name, default_name="dot_product_attention", values=[q, k, v]):
# [batch, num_heads, query_length, memory_length]
logits = tf.matmul(q, k, transpose_b=True)
if bias is not None:
logits += bias
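# When a local window is used, conv_mask holds large negative values for
# positions outside the window, so the softmax assigns them (near) zero weight.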
if window_size != -1:
logits += conv_mask
weights = tf.nn.softmax(logits, name="attention_weights")
# broadcast dropout can save memory
weights = common_layers.dropout_with_broadcast_dims(
weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
if summaries and not tf.get_variable_scope().reuse:
attention_image_summary(weights, image_shapes)
return tf.matmul(weights, v)
\ No newline at end of file
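For intuition, here is a minimal standalone sketch (NumPy only; toy_band_mask and softmax are illustrative helpers, not part of this commit) of how adding large negative values outside a window drives the corresponding attention weights to zero after the softmax:

import numpy as np

def toy_band_mask(length, window_size):
  # 0.0 inside the window around each query position, -1e9 outside it.
  idx = np.arange(length)
  distance = np.abs(idx[:, None] - idx[None, :])
  return np.where(distance <= window_size // 2, 0.0, -1e9)

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

logits = np.random.randn(6, 6)  # [length_q, length_kv] for one head
weights = softmax(logits + toy_band_mask(6, window_size=3))
# Each row of `weights` is now zero outside a 3-wide band around the diagonal.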
@@ -78,10 +78,13 @@ class Transformer(t2t_model.T2TModel):
inputs = common_layers.flatten4d3d(inputs)
(encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
inputs, target_space, hparams))
if hparams.attention_type == "convoluntional_self_attention":
conv_mask = get_conv_mask(inputs, hparams)
encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
encoder_output = transformer_encoder(encoder_input,
encoder_attention_bias,
hparams)
hparams,
conv_mask=conv_mask)
return encoder_output, encoder_attention_bias
@@ -448,6 +451,26 @@ def transformer_prepare_encoder(inputs, target_space, hparams):
encoder_input = common_attention.add_timing_signal_1d(encoder_input)
return (encoder_input, encoder_self_attention_bias, encoder_padding)
def get_conv_mask(inputs, hparams):
"""Prepare one shard of the model for the encoder.
Args:
inputs: a Tensor.
hparams: run hyperparameters
Returns:
conv_mask: a Tensor, bcontaining large negative values
to implement masked attention and possibly baises for diagonal
alignments
"""
# Band mask over key positions (assumes hparams.window_size is the number of
# positions each query may attend to, centered on the query).
length = tf.shape(inputs)[1]
positions = tf.range(length)
distance = tf.abs(tf.expand_dims(positions, 1) - tf.expand_dims(positions, 0))
outside = tf.cast(tf.greater(distance, hparams.window_size // 2), tf.float32)
# Large negative values outside the window; shape [1, 1, length, length]
# broadcasts over [batch, heads, length_q, length_kv] when added to logits.
conv_mask = outside * -1e9
return tf.expand_dims(tf.expand_dims(conv_mask, 0), 0)
def transformer_prepare_decoder(targets, hparams):
"""Prepare one shard of the model for the decoder.
@@ -479,6 +502,7 @@ def may_be_layernorm(input, hparams, before=False, after=False, name=None):
def transformer_encoder(encoder_input,
encoder_self_attention_bias,
hparams,
conv_mask=None,
name="encoder"):
"""A stack of transformer layers.
@@ -522,6 +546,8 @@ def transformer_encoder(encoder_input,
hparams.attention_dropout,
attention_type=hparams.attention_type,
max_relative_length=hparams.max_relative_length,
window_size=hparams.window_size,
conv_mask=conv_mask,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
summaries=False,
name="encoder_self_attention")
@@ -600,6 +626,8 @@ def transformer_decoder(decoder_input,
hparams.attention_dropout,
attention_type=hparams.attention_type,
max_relative_length=hparams.max_relative_length,
window_size=hparams.window_size,
conv_mask=conv_mask,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
cache=layer_cache,
summaries=False,
@@ -868,4 +896,12 @@ def transformer_before_shared25():
# you can define your own hparams like above to fix the target task
# @registry.register_hparams
# def transformer_new****():
# return hparams
\ No newline at end of file
# return hparams
@registry.register_hparams
def transformer_conv_sa():
# new model uses the MultistepAdam optimizer
hparams = transformer_before()
hparams.attention_type = "convolutional_self_attention"
hparams.add_hparam("attention_window_size", 5)
return hparams
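As a usage note, further variants can be registered in the same file following the same pattern; for example (a sketch only, transformer_conv_sa_window9 is a hypothetical name, not part of this commit):

@registry.register_hparams
def transformer_conv_sa_window9():
  # Same setup as transformer_conv_sa, with a wider local attention window.
  hparams = transformer_conv_sa()
  hparams.window_size = 9
  return hparams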