Commit 94469024 by libei

revise transformer hparams

parent c31fd16c
......@@ -168,4 +168,4 @@ class SqueezeExcitationLayerHistory(BaseLayerHistory):
def clean(self):
self.sum = None
self.count = 0
self.layers = []
self.layers = []
\ No newline at end of file
......@@ -58,11 +58,17 @@ class Transformer(t2t_model.T2TModel):
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
targets, hparams)
if hparams.attention_type == "convolutional_self_attention":
conv_mask_future = get_conv_mask(tf.shape(inputs)[1], hparams, mask_future=True)
else:
conv_mask_future = None
decoder_output = self.decode(decoder_input,
encoder_output,
encoder_attention_bias,
decoder_self_attention_bias,
hparams)
hparams,
conv_mask=conv_mask_future)
decoder_output = tf.expand_dims(decoder_output, 2)
......@@ -79,7 +85,9 @@ class Transformer(t2t_model.T2TModel):
(encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
inputs, target_space, hparams))
if hparams.attention_type == "convolutional_self_attention":
conv_mask = get_conv_mask(inputs, hparams)
conv_mask = get_conv_mask(tf.shape(inputs)[1], hparams)
else:
conv_mask = None
encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
encoder_output = transformer_encoder(encoder_input,
encoder_attention_bias,
......@@ -94,6 +102,7 @@ class Transformer(t2t_model.T2TModel):
encoder_attention_bias,
decoder_self_attention_bias,
hparams,
conv_mask=None,
cache=None):
......@@ -104,6 +113,7 @@ class Transformer(t2t_model.T2TModel):
decoder_self_attention_bias,
encoder_attention_bias,
hparams,
conv_mask=conv_mask,
cache=cache)
return decoder_output
......@@ -451,7 +461,7 @@ def transformer_prepare_encoder(inputs, target_space, hparams):
encoder_input = common_attention.add_timing_signal_1d(encoder_input)
return (encoder_input, encoder_self_attention_bias, encoder_padding)
def get_conv_mask(inputs, hparams):
def get_conv_mask(length, hparams, mask_future=False):
"""Prepare one shard of the model for the encoder.
Args:
......@@ -464,12 +474,14 @@ def get_conv_mask(inputs, hparams):
alignments
"""
# Flatten inputs.
ishape_static = inputs.shape.as_list()
encoder_input = inputs
encoder_padding = common_attention.embedding_to_padding(encoder_input)
encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(
encoder_padding)
return (encoder_input, encoder_self_attention_bias, encoder_padding)
# Half of the attention window lies on each side of the current position.
down_band = hparams.attention_window_size // 2
if mask_future:
# Decoder: only look back within the window; never attend to future positions.
up_band = 0
else:
up_band = down_band
band = tf.matrix_band_part(tf.ones([length, length]), down_band, up_band)
# Positions outside the band get a large negative bias so softmax ignores them.
ret = -1e9 * (1.0 - band)
return tf.reshape(ret, [1, 1, length, length])
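As a reading aid for this hunk: the function above returns an additive bias of shape [1, 1, length, length] that is 0.0 inside the attention window and -1e9 outside it. The following minimal numpy sketch (illustrative only; band_bias is a hypothetical re-implementation, and the window size 5 and length 6 are made-up example values, not taken from this diff) reproduces the same band rule for the encoder case and the mask_future=True decoder case.

import numpy as np

def band_bias(length, window_size, mask_future=False):
    # Same band rule as get_conv_mask: keep positions j with
    # i - down_band <= j <= i + up_band, penalize the rest with -1e9.
    down_band = window_size // 2
    up_band = 0 if mask_future else down_band
    band = np.tril(np.ones((length, length)), k=up_band) * np.triu(
        np.ones((length, length)), k=-down_band)
    return -1e9 * (1.0 - band)

# Encoder-style window: position 2 sees itself and 2 neighbours on each side.
print(band_bias(6, 5)[2])                    # 0.0 at columns 0..4, -1e9 at column 5
# Decoder-style window: position 2 sees only itself and the 2 previous positions.
print(band_bias(6, 5, mask_future=True)[2])  # 0.0 at columns 0..2, -1e9 at columns 3..5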
def transformer_prepare_decoder(targets, hparams):
......@@ -546,7 +558,6 @@ def transformer_encoder(encoder_input,
hparams.attention_dropout,
attention_type=hparams.attention_type,
max_relative_length=hparams.max_relative_length,
window_size=hparams.window_size,
conv_mask=conv_mask,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
summaries=False,
......@@ -577,6 +588,7 @@ def transformer_decoder(decoder_input,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
conv_mask=None,
cache=None,
name="decoder"):
"""A stack of transformer layers.
......@@ -626,7 +638,6 @@ def transformer_decoder(decoder_input,
hparams.attention_dropout,
attention_type=hparams.attention_type,
max_relative_length=hparams.max_relative_length,
window_size=hparams.window_size,
conv_mask=conv_mask,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
cache=layer_cache,
......@@ -801,7 +812,7 @@ def transformer_big():
@registry.register_hparams
def transformer_before():
def transformer_base_before():
"""HParams for transfomer big model on WMT."""
hparams = transformer_base()
hparams.normalize_before = True
......@@ -815,9 +826,9 @@ def transformer_before():
@registry.register_hparams
def transformer_before_big():
def transformer_big_before():
"""HParams for transfomer big model on WMT."""
hparams = transformer_before()
hparams = transformer_base_before()
hparams.hidden_size = 1024
hparams.filter_size = 4096
hparams.num_heads = 16
......@@ -839,7 +850,7 @@ def transformer_base_v2():
@registry.register_hparams
def transformer_rpr_base():
hparams = transformer_before()
hparams = transformer_base_before()
hparams.max_relative_length = 20
hparams.attention_type = "relative_dot_product"
# optimal
......@@ -871,7 +882,7 @@ def transformer_big_multistep2():
@registry.register_hparams
def transformer_before_shared():
# New model; uses the MultistepAdam optimizer.
hparams = transformer_before()
hparams = transformer_base_before()
hparams.shared_decoder_input_and_softmax_weights = int(True)
return hparams
......@@ -886,10 +897,10 @@ def transformer_before_shared25():
hparams.learning_rate = 0.4
hparams.learning_rate_warmup_steps = 8000
hparams.optimizer = "MultistepAdam"
hparams.optimizer_multistep_accumulate_steps = 4
hparams.optimizer_multistep_accumulate_steps = 3
hparams.encoder_layers = 25
# Training a deep pre-norm Transformer with a batch_size of 4096 is likely to OOM.
hparams.batch_size = 2048
hparams.batch_size = 3072
return hparams
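For context on the two changes above: with tensor2tensor-style gradient accumulation, the effective batch per parameter update is roughly the per-step batch size times optimizer_multistep_accumulate_steps (per replica). A quick arithmetic sketch using only the numbers in this diff:

# Effective batch per update = batch_size * accumulate_steps (per replica).
old_effective = 2048 * 4   # previous setting: 8192
new_effective = 3072 * 3   # revised setting:  9216
print(old_effective, new_effective)  # 8192 9216

So the revision keeps the effective batch comparable (in fact slightly larger) while needing one fewer accumulation step per update.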
......@@ -897,11 +908,3 @@ def transformer_before_shared25():
# @registry.register_hparams
# def transformer_new****():
# return hparams
@registry.register_hparams
def transformer_conv_sa():
# New model; uses the MultistepAdam optimizer.
hparams = transformer_before()
hparams.attention_type = "convolutional_self_attention"
hparams.add_hparam("attention_window_size", 5)
return hparams
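One property of the new convolutional self-attention setting worth noting: with attention_window_size = 5, each layer lets a position attend to window // 2 = 2 tokens on either side, so the receptive field grows roughly linearly with depth. A small sketch (the encoder depth of 6 is an assumed value, not specified in this diff):

# Receptive field of stacked windowed self-attention (encoder side):
# each layer extends the visible context by (window - 1) // 2 tokens per side.
window = 5       # hparams.attention_window_size from transformer_conv_sa
num_layers = 6   # assumed encoder depth; not set in this diff
receptive_field = num_layers * (window - 1) + 1
print(receptive_field)  # 25 tokens visible from the top encoder layer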