Commit 2e5f98d6 by libei

support sharing decoder input and softmax embeddings

parent 8ab393f2
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
#*.so
# Checkpoints
checkpoints
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Pycharm
.idea
# binary dataset
lib
......@@ -460,8 +460,8 @@ def wmt_zhen_tokens(model_hparams, wrong_vocab_size):
"""Chinese to English translation benchmark."""
p = default_problem_hparams()
# This vocab file must be present within the data directory.
if model_hparams.shared_embedding_and_softmax_weights == 1:
model_hparams.shared_embedding_and_softmax_weights = 0
# if model_hparams.shared_embedding_and_softmax_weights == 1:
# model_hparams.shared_embedding_and_softmax_weights = 0
source_vocab_filename = os.path.join(model_hparams.data_dir,
"source_dic")
target_vocab_filename = os.path.join(model_hparams.data_dir,
......
......@@ -103,6 +103,8 @@ def basic_params1():
      # by using a problem_hparams that uses the same modality object for
      # the input_modality and target_modality.
      shared_embedding_and_softmax_weights=int(False),
      # share the decoder input (target embedding) and softmax weights
      shared_decoder_input_and_softmax_weights=int(False),
      # For each feature for which you want to override the default input
      # modality, add an entry to this semicolon-separated string. Entries are
      # formatted "feature_name:modality_type:modality_name", e.g.
......
......@@ -91,8 +91,11 @@ class SymbolModality(modality.Modality):
      return self.bottom_simple(x, "input_emb", reuse=None)

  def targets_bottom(self, x):
    assert not (self._model_hparams.shared_embedding_and_softmax_weights and
                self._model_hparams.shared_decoder_input_and_softmax_weights)
    if self._model_hparams.shared_embedding_and_softmax_weights:
      return self.bottom_simple(x, "shared", reuse=True)
    elif self._model_hparams.shared_decoder_input_and_softmax_weights:
      return self.bottom_simple(x, "shared_target", reuse=tf.AUTO_REUSE)
    else:
      return self.bottom_simple(x, "target_emb", reuse=None)
......@@ -105,9 +108,13 @@ class SymbolModality(modality.Modality):
    Returns:
      logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
    """
    assert not (self._model_hparams.shared_embedding_and_softmax_weights and
                self._model_hparams.shared_decoder_input_and_softmax_weights)
    if self._model_hparams.shared_embedding_and_softmax_weights:
      scope_name = "shared"
      reuse = True
    elif self._model_hparams.shared_decoder_input_and_softmax_weights:
      scope_name = "shared_target"
      reuse = True
    else:
      scope_name = "softmax"
      reuse = False
......
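For context, the sharing in SymbolModality above works through TensorFlow 1.x variable scoping: targets_bottom and top both request a variable under the same scope name ("shared_target"), so the decoder input embedding and the softmax projection end up backed by one matrix. Below is a minimal, self-contained sketch of that mechanism only; the names shared_weights, targets_bottom, top and the vocab/hidden sizes are illustrative assumptions, not the repository's bottom_simple/top code.

# Illustrative sketch (assumed TF 1.x-style API), not the repo's bottom_simple/top code.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

VOCAB_SIZE, HIDDEN_SIZE = 32000, 512  # made-up sizes for the example


def shared_weights(scope_name="shared_target"):
  # AUTO_REUSE creates the variable on the first call and returns the same
  # variable on every later call with the same scope/variable name.
  with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
    return tf.get_variable(
        "weights", [VOCAB_SIZE, HIDDEN_SIZE],
        initializer=tf.random_normal_initializer(0.0, HIDDEN_SIZE ** -0.5))


def targets_bottom(ids):
  # Decoder input embedding: rows of the shared matrix.
  return tf.gather(shared_weights(), ids)


def top(decoder_output):
  # Softmax logits: multiply by the same matrix, transposed.
  return tf.matmul(decoder_output, shared_weights(), transpose_b=True)


ids = tf.placeholder(tf.int32, [None, None])      # [batch, length]
emb = targets_bottom(ids)                         # [batch, length, hidden]
logits = top(tf.reshape(emb, [-1, HIDDEN_SIZE]))  # [batch*length, vocab]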
......@@ -723,7 +723,8 @@ def transformer_base():
  hparams.optimizer_adam_beta2 = 0.98
  hparams.num_sampled_classes = 0
  hparams.label_smoothing = 0.1
  hparams.shared_embedding_and_softmax_weights = int(True)
  hparams.shared_embedding_and_softmax_weights = int(False)
  hparams.shared_decoder_input_and_softmax_weights = int(False)
  hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
  # attention-related flags
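With this change transformer_base leaves both sharing flags off by default, so an experiment has to opt in through its own registered hparams set (transformer_dla_base25_shared at the end of this commit does exactly that for the DLA model). A minimal sketch of such an opt-in follows; the function name is hypothetical and not part of this commit.

# Hypothetical hparams set (name made up for illustration), following the
# registration pattern used throughout this file: start from transformer_base
# and turn on the new decoder-input/softmax weight sharing.
@registry.register_hparams
def transformer_base_shared_decoder_softmax():
  hparams = transformer_base()
  hparams.shared_decoder_input_and_softmax_weights = int(True)
  return hparams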
......@@ -797,236 +798,6 @@ def transformer_before_big():
@registry.register_hparams
def transformer_big_single_gpu():
  """HParams for transformer big model for single gpu."""
  hparams = transformer_big()
  hparams.residual_dropout = 0.1
  hparams.learning_rate_warmup_steps = 16000
  hparams.optimizer_adam_beta2 = 0.998
  hparams.batching_mantissa_bits = 3
  return hparams


@registry.register_hparams
def transformer_base_single_gpu():
  """HParams for transformer base model for single gpu."""
  hparams = transformer_base()
  hparams.batch_size = 8192
  hparams.learning_rate_warmup_steps = 16000
  hparams.batching_mantissa_bits = 2
  return hparams


@registry.register_hparams
def transformer_big_dr1():
  hparams = transformer_base()
  hparams.hidden_size = 1024
  hparams.filter_size = 4096
  hparams.num_heads = 16
  hparams.residual_dropout = 0.1
  hparams.batching_mantissa_bits = 2
  return hparams


@registry.register_hparams
def transformer_big_enfr():
  hparams = transformer_big_dr1()
  hparams.shared_embedding_and_softmax_weights = int(False)
  hparams.filter_size = 8192
  hparams.residual_dropout = 0.1
  return hparams


@registry.register_hparams
def transformer_big_dr2():
  hparams = transformer_big_dr1()
  hparams.residual_dropout = 0.2
  return hparams


@registry.register_hparams
def transformer_base_ldcd():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.learning_rate_decay_scheme = "ld&cd"
  hparams.learning_rate_ldcd_epoch = 5
  hparams.learning_rate_warmup_steps = 4000
  return hparams


@registry.register_hparams
def transformer_base_ldcd_n10():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.learning_rate_decay_scheme = "ld&cd"
  hparams.learning_rate_ldcd_epoch = 10
  hparams.learning_rate_warmup_steps = 4000
  return hparams


@registry.register_hparams
def transformer_base_ldcd_n2():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.learning_rate_decay_scheme = "ld&cd"
  hparams.learning_rate_ldcd_epoch = 2
  hparams.learning_rate_warmup_steps = 4000
  return hparams


@registry.register_hparams
def transformer_base_ldcd_n1():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.learning_rate_decay_scheme = "ld&cd"
  hparams.learning_rate_ldcd_epoch = 1
  hparams.learning_rate_warmup_steps = 4000
  return hparams


@registry.register_hparams
def transformer_base_amsgrad():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "AMSGrad"
  return hparams


@registry.register_hparams
def transformer_base_amsgrad_v2():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "AMSGrad"
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.999
  hparams.optimizer_adam_epsilon = 1e-8
  return hparams


@registry.register_hparams
def transformer_base_amsgrad_v3():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "AMSGrad"
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.99
  hparams.optimizer_adam_epsilon = 1e-8
  return hparams


@registry.register_hparams
def transformer_base_amsgrad_v4():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "AMSGrad"
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.99
  hparams.optimizer_adam_epsilon = 1e-9
  return hparams


@registry.register_hparams
def transformer_base_ldrestart_n3():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.learning_rate_decay_scheme = "ld&restart"
  hparams.learning_rate_ldrestart_epoch = 3
  hparams.learning_rate_warmup_steps = 4000
  return hparams


@registry.register_hparams
def transformer_base_powersign():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = ""
  return hparams


@registry.register_hparams
def transformer_base_powersign_ld():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = "linear"
  return hparams


@registry.register_hparams
def transformer_base_powersign_cd():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = "cosine"
  hparams.optimizer_powersign_period = 1
  return hparams


@registry.register_hparams
def transformer_base_powersign_rd():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = "restart"
  hparams.optimizer_powersign_period = 1
  return hparams


@registry.register_hparams
def transformer_base_powersign_rd_n10():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = "restart"
  hparams.optimizer_powersign_period = 10
  return hparams


@registry.register_hparams
def transformer_base_powersign_rd_n20():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.optimizer = "PowerSign"
  hparams.optimizer_powersign_beta = 0.9
  hparams.optimizer_powersign_decay = "restart"
  hparams.optimizer_powersign_period = 20
  return hparams


@registry.register_hparams
def transformer_base_swish1():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.ffn_layer = "conv_hidden_swish"
  hparams.swish_dropout = 0.0
  hparams.swish_beta = 1.0
  hparams.swish_beta_is_trainable = False
  return hparams


@registry.register_hparams
def transformer_base_swish_trainable():
  """Set of hyperparameters."""
  hparams = transformer_base()
  hparams.ffn_layer = "conv_hidden_swish"
  hparams.swish_dropout = 0.0
  # hparams.swish_beta = 1.0
  hparams.swish_beta_is_trainable = True
  return hparams


@registry.register_hparams
def transformer_big_adafactor():
  """HParams for transfomer big model on WMT."""
  hparams = transformer_base()
  hparams.hidden_size = 1024
  hparams.filter_size = 4096
  hparams.num_heads = 16
  hparams.batching_mantissa_bits = 2
  hparams.residual_dropout = 0.3
  hparams.optimizer = "Adafactor"
  hparams.epsilon = 1e-8
  hparams.learning_rate_warmup_steps = 16000
  hparams.optimizer_adafactor_beta2 = 0.997
  return hparams


@registry.register_hparams
def transformer_base_v2():
  """Set of hyperparameters.
  set relu_dropout and attention_dropout as 0.1
......@@ -1056,7 +827,6 @@ def transformer_base_v3():
  return hparams


@registry.register_hparams
def transformer_big_multistep2():
  # new model use optimizer MultistepAdam
......@@ -1069,13 +839,9 @@ def transformer_big_multistep2():
  return hparams


@registry.register_hparams
def transformer_big_adafactor_test():
def transformer_before_test():
  # new model use optimizer MultistepAdam
  hparams = transformer_big()
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_warmup_steps = 8000
  hparams.batch_size = 4096
  hparams.optimizer_adafactor_beta2 = 0.999
  hparams = transformer_before()
  hparams.shared_decoder_input_and_softmax_weights = int(False)
  return hparams
......@@ -396,7 +396,8 @@ def transformer_dla():
  hparams.optimizer_adam_beta2 = 0.98
  hparams.num_sampled_classes = 0
  hparams.label_smoothing = 0.1
  hparams.shared_embedding_and_softmax_weights = int(True)
  hparams.shared_embedding_and_softmax_weights = int(False)
  hparams.shared_decoder_input_and_softmax_weights = int(False)
  hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
  # attention-related flags
......@@ -437,6 +438,8 @@ def transformer_dla():
@registry.register_hparams
def transformer_dla_base():
  hparams = transformer_dla()
  hparams.encoder_layers = 6
  hparams.decoder_layers = 6
  hparams.normalize_before = True
  hparams.attention_dropout = 0.1
  hparams.residual_dropout = 0.1
......@@ -450,10 +453,19 @@ def transformer_dla_base():
@registry.register_hparams
def transformer_dla_big():
  """HParams for transfomer big model on WMT."""
  hparams = transformer_dla_base()
  hparams = transformer_dla()
  hparams.hidden_size = 1024
  hparams.filter_size = 4096
  hparams.num_heads = 16
  hparams.batching_mantissa_bits = 2
  hparams.residual_dropout = 0.3
  return hparams


@registry.register_hparams
def transformer_dla_base25_shared():
  hparams = transformer_dla_base()
  hparams.shared_decoder_input_and_softmax_weights = int(True)
  hparams.encoder_layers = 25
  return hparams
\ No newline at end of file