Commit 868f56a6 by libeineu

add convolutional self attention

parent 8a88d1b1
@@ -382,6 +382,8 @@ def multihead_attention(query_antecedent,
                        dropout_rate,
                        attention_type="dot_product",
                        max_relative_length=16,
                        window_size=-1,
                        conv_mask=None,
                        summaries=False,
                        image_shapes=None,
                        cache=None,
@@ -463,6 +465,9 @@ def multihead_attention(query_antecedent,
    elif attention_type == "relative_dot_product":
      x = dot_product_attention_relative(
          q, k, v, bias, max_relative_length, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "convolutional_self_attention":
      x = convolutional_self_attention(
          q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims, window_size=window_size, conv_mask=conv_mask)
    else:
      raise ValueError
@@ -837,3 +842,48 @@ def parameter_attention(x,
    y.set_shape([None, None, total_value_depth])
    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
    return y

def convolutional_self_attention(q,
                                 k,
                                 v,
                                 bias,
                                 dropout_rate=0.0,
                                 summaries=False,
                                 image_shapes=None,
                                 dropout_broadcast_dims=None,
                                 window_size=-1,
                                 conv_mask=None,
                                 name=None):
"""dot-product attention.
Args:
q: a Tensor with shape [batch, heads, length_q, depth_k]
k: a Tensor with shape [batch, heads, length_kv, depth_k]
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
Returns:
A Tensor.
"""
  with tf.variable_scope(
      name, default_name="convolutional_self_attention", values=[q, k, v]):
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if bias is not None:
      logits += bias
    if window_size != -1:
      # Restrict each query to its local window by adding large negative
      # values to out-of-window positions before the softmax.
      logits += conv_mask
    weights = tf.nn.softmax(logits, name="attention_weights")
    # broadcast dropout can save memory
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if summaries and not tf.get_variable_scope().reuse:
      attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
\ No newline at end of file
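For intuition, here is a minimal NumPy sketch (not part of the commit; the toy shapes and the convention of attending to window_size // 2 neighbors on either side are assumptions) showing how adding a large negative band bias to the logits before the softmax drives out-of-window attention weights to zero:

import numpy as np

def toy_windowed_attention(q, k, v, window_size):
  # q, k, v: [length, depth] toy arrays; window_size: width of the local window.
  length = q.shape[0]
  logits = q @ k.T                                  # [length, length]
  idx = np.arange(length)
  distance = np.abs(idx[:, None] - idx[None, :])
  band_bias = np.where(distance <= window_size // 2, 0.0, -1e9)
  weights = np.exp(logits + band_bias)              # out-of-window terms underflow to 0
  weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the key axis
  return weights @ v                                # [length, depth]

rng = np.random.default_rng(0)
q = k = v = rng.standard_normal((8, 4))
out = toy_windowed_attention(q, k, v, window_size=5)  # each position sees +/- 2 neighbors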
@@ -78,10 +78,13 @@ class Transformer(t2t_model.T2TModel):
    inputs = common_layers.flatten4d3d(inputs)
    (encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
        inputs, target_space, hparams))
    conv_mask = None
    if hparams.attention_type == "convolutional_self_attention":
      conv_mask = get_conv_mask(inputs, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
    encoder_output = transformer_encoder(encoder_input,
                                         encoder_attention_bias,
                                         hparams,
                                         conv_mask=conv_mask)
    return encoder_output, encoder_attention_bias
@@ -448,6 +451,26 @@ def transformer_prepare_encoder(inputs, target_space, hparams):
  encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias, encoder_padding)

def get_conv_mask(inputs, hparams):
  """Build the local-window bias used by convolutional self-attention.

  Args:
    inputs: a Tensor with shape [batch, length, hidden_dim].
    hparams: run hyperparameters.

  Returns:
    conv_mask: a Tensor with shape [1, 1, length, length] containing large
      negative values outside a window of hparams.window_size around each
      position, to implement locally masked attention.
  """
  length = tf.shape(inputs)[1]
  half_window = hparams.window_size // 2
  # Each query position attends to itself and half_window neighbors on
  # either side; everything outside the band gets a large negative bias.
  rows = tf.expand_dims(tf.range(length), 1)
  cols = tf.expand_dims(tf.range(length), 0)
  distance = tf.abs(rows - cols)
  in_window = tf.to_float(tf.less_equal(distance, half_window))
  conv_mask = (1.0 - in_window) * -1e9
  # Add batch and head dimensions so the bias broadcasts over both.
  return tf.expand_dims(tf.expand_dims(conv_mask, 0), 0)
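A quick sanity check of the mask shape (a sketch only: it assumes the get_conv_mask body above, a TF1-style session consistent with the APIs this file already uses, and a hypothetical FakeHParams stand-in for the real hparams object):

import tensorflow as tf

class FakeHParams(object):
  window_size = 5

inputs = tf.zeros([2, 8, 64])              # [batch, length, hidden_dim]
conv_mask = get_conv_mask(inputs, FakeHParams())
with tf.Session() as sess:
  mask = sess.run(conv_mask)               # numpy array of shape (1, 1, 8, 8)
  # Entries are 0.0 within two positions of the diagonal and -1e9 elsewhere.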

def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.
@@ -479,6 +502,7 @@ def may_be_layernorm(input, hparams, before=False, after=False, name=None):
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        conv_mask=None,
                        name="encoder"):
"""A stack of transformer layers. """A stack of transformer layers.
...@@ -522,6 +546,8 @@ def transformer_encoder(encoder_input, ...@@ -522,6 +546,8 @@ def transformer_encoder(encoder_input,
hparams.attention_dropout, hparams.attention_dropout,
attention_type=hparams.attention_type, attention_type=hparams.attention_type,
max_relative_length=hparams.max_relative_length, max_relative_length=hparams.max_relative_length,
window_size=hparams.window_size,
conv_mask=conv_mask,
dropout_broadcast_dims=attention_dropout_broadcast_dims, dropout_broadcast_dims=attention_dropout_broadcast_dims,
summaries=False, summaries=False,
name="encoder_self_attention") name="encoder_self_attention")
@@ -600,6 +626,8 @@ def transformer_decoder(decoder_input,
            hparams.attention_dropout,
            attention_type=hparams.attention_type,
            max_relative_length=hparams.max_relative_length,
            window_size=hparams.window_size,
            conv_mask=conv_mask,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            cache=layer_cache,
            summaries=False,
@@ -868,4 +896,12 @@ def transformer_before_shared25():
# you can define your own hparams like above to fix the target task
# @registry.register_hparams
# def transformer_new****():
#   return hparams
\ No newline at end of file

@registry.register_hparams
def transformer_conv_sa():
  # new model uses the MultistepAdam optimizer
  hparams = transformer_before()
  hparams.attention_type = "convolutional_self_attention"
  hparams.add_hparam("window_size", 5)
  return hparams
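As a quick usage check (a sketch; the expected values follow directly from the definitions above, with the hyperparameter registered as window_size so it matches the hparams.window_size reads in the encoder and decoder), the new hparams set can be built and inspected directly:

hparams = transformer_conv_sa()
assert hparams.attention_type == "convolutional_self_attention"
assert hparams.window_size == 5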