Commit 94469024 by libei

revise transformer hparams

parent c31fd16c
@@ -168,4 +168,4 @@ class SqueezeExcitationLayerHistory(BaseLayerHistory):
     def clean(self):
         self.sum = None
         self.count = 0
-        self.layers = []
\ No newline at end of file
+        self.layers = []
@@ -58,11 +58,17 @@ class Transformer(t2t_model.T2TModel):
     decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
         targets, hparams)
+    if hparams.attention_type == "convoluntional_self_attention":
+      conv_mask_future = get_conv_mask(tf.shape(inputs)[1], hparams, mask_future=True)
+    else:
+      conv_mask_future = None
     decoder_output = self.decode(decoder_input,
                                  encoder_output,
                                  encoder_attention_bias,
                                  decoder_self_attention_bias,
-                                 hparams)
+                                 hparams,
+                                 conv_mask=conv_mask_future)
     decoder_output = tf.expand_dims(decoder_output, 2)
@@ -79,7 +85,9 @@ class Transformer(t2t_model.T2TModel):
     (encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
         inputs, target_space, hparams))
     if hparams.attention_type == "convoluntional_self_attention":
-      conv_mask = get_conv_mask(inputs, hparams)
+      conv_mask = get_conv_mask(tf.shape(inputs)[1], hparams)
+    else:
+      conv_mask = None
     encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
     encoder_output = transformer_encoder(encoder_input,
                                          encoder_attention_bias,
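
The conv_mask built here is an additive bias of shape [1, 1, length, length] (see get_conv_mask further down), so it is presumably added to the attention logits inside common_attention before the softmax; because its leading two dimensions are 1, one mask broadcasts across batch and heads. A minimal sketch of that pattern, assuming TF 1.x, with made-up shapes and a stand-in window size (this is not the actual common_attention code):

import tensorflow as tf

batch, heads, length, depth = 2, 4, 6, 8
half = 5 // 2                                    # stand-in for hparams.attention_window_size // 2
band = tf.matrix_band_part(tf.ones([length, length]), half, half)
conv_mask = tf.reshape(-1e9 * (1.0 - band), [1, 1, length, length])

q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
logits = tf.matmul(q, k, transpose_b=True)       # [batch, heads, length, length]
weights = tf.nn.softmax(logits + conv_mask)      # out-of-window positions get ~0 weight
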
@@ -94,6 +102,7 @@ class Transformer(t2t_model.T2TModel):
              encoder_attention_bias,
              decoder_self_attention_bias,
              hparams,
+             conv_mask=None,
              cache=None):
@@ -104,6 +113,7 @@ class Transformer(t2t_model.T2TModel):
                                        decoder_self_attention_bias,
                                        encoder_attention_bias,
                                        hparams,
+                                       conv_mask=conv_mask,
                                        cache=cache)
     return decoder_output
@@ -451,7 +461,7 @@ def transformer_prepare_encoder(inputs, target_space, hparams):
   encoder_input = common_attention.add_timing_signal_1d(encoder_input)
   return (encoder_input, encoder_self_attention_bias, encoder_padding)


-def get_conv_mask(inputs, hparams):
+def get_conv_mask(length, hparams, mask_future=False):
   """Prepare one shard of the model for the encoder.
   Args:
@@ -464,12 +474,14 @@ def get_conv_mask(inputs, hparams):
       alignments
   """
   # Flatten inputs.
-  ishape_static = inputs.shape.as_list()
-  encoder_input = inputs
-  encoder_padding = common_attention.embedding_to_padding(encoder_input)
-  encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(
-      encoder_padding)
-  return (encoder_input, encoder_self_attention_bias, encoder_padding)
+  down_bang = hparams.attention_window_size // 2
+  if mask_future:
+    up_bang = 0
+  else:
+    up_bang = down_bang
+  lower_triangle = tf.matrix_band_part(tf.ones([length, length]), down_bang, up_bang)
+  ret = -1e9 * (1.0 - lower_triangle)
+  return tf.reshape(ret, [1, 1, length, length])


 def transformer_prepare_decoder(targets, hparams):
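
For intuition, here is the band that tf.matrix_band_part produces for the two branches of get_conv_mask, using length=5 and an example attention_window_size of 5 (so down_bang = 2); TF 1.x session API assumed:

import tensorflow as tf

length, half = 5, 5 // 2
sym = tf.matrix_band_part(tf.ones([length, length]), half, half)    # mask_future=False (encoder)
causal = tf.matrix_band_part(tf.ones([length, length]), half, 0)    # mask_future=True (decoder)
with tf.Session() as sess:
    print(sess.run(sym))
    # [[1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 0.]
    #  [1. 1. 1. 1. 1.]
    #  [0. 1. 1. 1. 1.]
    #  [0. 0. 1. 1. 1.]]
    print(sess.run(causal))
    # [[1. 0. 0. 0. 0.]
    #  [1. 1. 0. 0. 0.]
    #  [1. 1. 1. 0. 0.]
    #  [0. 1. 1. 1. 0.]
    #  [0. 0. 1. 1. 1.]]

get_conv_mask turns the zeros of the band into a -1e9 bias and reshapes it to [1, 1, length, length], so with mask_future=True each position can attend only to itself and the previous attention_window_size // 2 positions, which is why the decoder path earlier in this diff requests mask_future=True.
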
@@ -546,7 +558,6 @@ def transformer_encoder(encoder_input,
                 hparams.attention_dropout,
                 attention_type=hparams.attention_type,
                 max_relative_length=hparams.max_relative_length,
-                window_size=hparams.window_size,
                 conv_mask=conv_mask,
                 dropout_broadcast_dims=attention_dropout_broadcast_dims,
                 summaries=False,
@@ -577,6 +588,7 @@ def transformer_decoder(decoder_input,
                         decoder_self_attention_bias,
                         encoder_decoder_attention_bias,
                         hparams,
+                        conv_mask=None,
                         cache=None,
                         name="decoder"):
   """A stack of transformer layers.
@@ -626,7 +638,6 @@ def transformer_decoder(decoder_input,
                 hparams.attention_dropout,
                 attention_type=hparams.attention_type,
                 max_relative_length=hparams.max_relative_length,
-                window_size=hparams.window_size,
                 conv_mask=conv_mask,
                 dropout_broadcast_dims=attention_dropout_broadcast_dims,
                 cache=layer_cache,
@@ -801,7 +812,7 @@ def transformer_big():
 @registry.register_hparams
-def transformer_before():
+def transformer_base_before():
   """HParams for Transformer base model with pre-norm on WMT."""
   hparams = transformer_base()
   hparams.normalize_before = True
@@ -815,9 +826,9 @@ def transformer_before():
 @registry.register_hparams
-def transformer_before_big():
+def transformer_big_before():
   """HParams for Transformer big model on WMT."""
-  hparams = transformer_before()
+  hparams = transformer_base_before()
   hparams.hidden_size = 1024
   hparams.filter_size = 4096
   hparams.num_heads = 16
@@ -839,7 +850,7 @@ def transformer_base_v2():
 @registry.register_hparams
 def transformer_rpr_base():
-  hparams = transformer_before()
+  hparams = transformer_base_before()
   hparams.max_relative_length = 20
   hparams.attention_type = "relative_dot_product"
   # optimal
@@ -871,7 +882,7 @@ def transformer_big_multistep2():
 @registry.register_hparams
 def transformer_before_shared():
   # the new model uses the MultistepAdam optimizer
-  hparams = transformer_before()
+  hparams = transformer_base_before()
   hparams.shared_decoder_input_and_softmax_weights = int(True)
   return hparams
@@ -886,10 +897,10 @@ def transformer_before_shared25():
   hparams.learning_rate = 0.4
   hparams.learning_rate_warmup_steps = 8000
   hparams.optimizer = "MultistepAdam"
-  hparams.optimizer_multistep_accumulate_steps = 4
+  hparams.optimizer_multistep_accumulate_steps = 3
   hparams.encoder_layers = 25
   # training a deep pre-norm transformer with batch_size 4096 is likely to OOM
-  hparams.batch_size = 2048
+  hparams.batch_size = 3072
   return hparams
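
Assuming batch_size is measured in tokens per replica (the tensor2tensor convention) and that MultistepAdam applies one parameter update per optimizer_multistep_accumulate_steps accumulated batches, the effective tokens per update change from 4 × 2048 = 8192 to 3 × 3072 = 9216, i.e. a comparable (slightly larger) effective batch, while each individual step stays below the 4096-token size that the comment above flags as OOM-prone for the deep pre-norm model.
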
@@ -897,11 +908,3 @@ def transformer_before_shared25():
 # @registry.register_hparams
 # def transformer_new****():
 # return hparams
-
-@registry.register_hparams
-def transformer_conv_sa():
-  # new model use optimizer MultistepAdam
-  hparams = transformer_before()
-  hparams.attention_type = "convolutional_self_attention"
-  hparams.add_hparam("attention_window_size", 5)
-  return hparams