Commit 2e5f98d6 by libei

Support sharing the decoder input and softmax embeddings

parent 8ab393f2
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
#*.so

# Checkpoints
checkpoints

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Pycharm
.idea

# binary dataset
lib
...
@@ -460,8 +460,8 @@ def wmt_zhen_tokens(model_hparams, wrong_vocab_size):
   """Chinese to English translation benchmark."""
   p = default_problem_hparams()
   # This vocab file must be present within the data directory.
-  if model_hparams.shared_embedding_and_softmax_weights == 1:
-    model_hparams.shared_embedding_and_softmax_weights = 0
+  # if model_hparams.shared_embedding_and_softmax_weights == 1:
+  #   model_hparams.shared_embedding_and_softmax_weights = 0
   source_vocab_filename = os.path.join(model_hparams.data_dir,
                                        "source_dic")
   target_vocab_filename = os.path.join(model_hparams.data_dir,
...
@@ -103,6 +103,8 @@ def basic_params1():
       # by using a problem_hparams that uses the same modality object for
       # the input_modality and target_modality.
       shared_embedding_and_softmax_weights=int(False),
+      # share the decoder input (target) embedding and the softmax weights
+      shared_decoder_input_and_softmax_weights=int(False),
       # For each feature for which you want to override the default input
       # modality, add an entry to this semicolon-separated string. Entries are
       # formatted "feature_name:modality_type:modality_name", e.g.
...
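For context, a minimal usage sketch of the new flag (not part of this commit): a registered hparams set that turns it on, following the registry pattern used elsewhere in this diff. The name transformer_base_shared_target is hypothetical and used only for illustration.

# Illustration only -- hypothetical hparams set enabling the new flag.
# Assumes the tensor2tensor-style registry and the transformer_base() shown below.
@registry.register_hparams
def transformer_base_shared_target():
  hparams = transformer_base()
  # The two sharing modes are mutually exclusive (see the SymbolModality asserts),
  # so only one of them is enabled here.
  hparams.shared_embedding_and_softmax_weights = int(False)
  hparams.shared_decoder_input_and_softmax_weights = int(True)
  return hparams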
@@ -91,8 +91,11 @@ class SymbolModality(modality.Modality):
     return self.bottom_simple(x, "input_emb", reuse=None)

   def targets_bottom(self, x):
+    assert not (self._model_hparams.shared_embedding_and_softmax_weights and self._model_hparams.shared_decoder_input_and_softmax_weights)
     if self._model_hparams.shared_embedding_and_softmax_weights:
       return self.bottom_simple(x, "shared", reuse=True)
+    elif self._model_hparams.shared_decoder_input_and_softmax_weights:
+      return self.bottom_simple(x, "shared_target", reuse=tf.AUTO_REUSE)
     else:
       return self.bottom_simple(x, "target_emb", reuse=None)
...
@@ -105,9 +108,13 @@ class SymbolModality(modality.Modality):
     Returns:
       logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
     """
+    assert not (self._model_hparams.shared_embedding_and_softmax_weights and self._model_hparams.shared_decoder_input_and_softmax_weights)
     if self._model_hparams.shared_embedding_and_softmax_weights:
       scope_name = "shared"
       reuse = True
+    elif self._model_hparams.shared_decoder_input_and_softmax_weights:
+      scope_name = "shared_target"
+      reuse = True
     else:
       scope_name = "softmax"
       reuse = False
...
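The sharing itself works through TF1 variable-scope reuse: targets_bottom() and top() both build their weight matrix inside the "shared_target" scope, so they resolve to the same variable. Below is a minimal standalone sketch of that mechanism, assuming TensorFlow 1.x graph mode; it is simplified and does not reproduce the sharding/reshaping done by the real bottom_simple/top.

import tensorflow as tf  # TF 1.x graph mode assumed

def embedding_weights(scope_name, reuse, vocab_size=32000, hidden_size=512):
  # Both the decoder-input embedding and the softmax projection fetch
  # "weights" from the same scope, so they end up as one shared variable.
  with tf.variable_scope(scope_name, reuse=reuse):
    return tf.get_variable("weights", [vocab_size, hidden_size])

target_emb = embedding_weights("shared_target", reuse=tf.AUTO_REUSE)  # as in targets_bottom
softmax_w = embedding_weights("shared_target", reuse=True)            # as in top (logits)
assert target_emb is softmax_w  # same underlying variable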
@@ -723,7 +723,8 @@ def transformer_base():
   hparams.optimizer_adam_beta2 = 0.98
   hparams.num_sampled_classes = 0
   hparams.label_smoothing = 0.1
-  hparams.shared_embedding_and_softmax_weights = int(True)
+  hparams.shared_embedding_and_softmax_weights = int(False)
+  hparams.shared_decoder_input_and_softmax_weights = int(False)
   hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
   # attention-related flags
...
@@ -797,236 +798,6 @@ def transformer_before_big():


 @registry.register_hparams
-def transformer_big_single_gpu():
-  """HParams for transformer big model for single gpu."""
-  hparams = transformer_big()
-  hparams.residual_dropout = 0.1
-  hparams.learning_rate_warmup_steps = 16000
-  hparams.optimizer_adam_beta2 = 0.998
-  hparams.batching_mantissa_bits = 3
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_single_gpu():
-  """HParams for transformer base model for single gpu."""
-  hparams = transformer_base()
-  hparams.batch_size = 8192
-  hparams.learning_rate_warmup_steps = 16000
-  hparams.batching_mantissa_bits = 2
-  return hparams
-
-
-@registry.register_hparams
-def transformer_big_dr1():
-  hparams = transformer_base()
-  hparams.hidden_size = 1024
-  hparams.filter_size = 4096
-  hparams.num_heads = 16
-  hparams.residual_dropout = 0.1
-  hparams.batching_mantissa_bits = 2
-  return hparams
-
-
-@registry.register_hparams
-def transformer_big_enfr():
-  hparams = transformer_big_dr1()
-  hparams.shared_embedding_and_softmax_weights = int(False)
-  hparams.filter_size = 8192
-  hparams.residual_dropout = 0.1
-  return hparams
-
-
-@registry.register_hparams
-def transformer_big_dr2():
-  hparams = transformer_big_dr1()
-  hparams.residual_dropout = 0.2
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_ldcd():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.learning_rate_decay_scheme = "ld&cd"
-  hparams.learning_rate_ldcd_epoch = 5
-  hparams.learning_rate_warmup_steps = 4000
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_ldcd_n10():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.learning_rate_decay_scheme = "ld&cd"
-  hparams.learning_rate_ldcd_epoch = 10
-  hparams.learning_rate_warmup_steps = 4000
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_ldcd_n2():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.learning_rate_decay_scheme = "ld&cd"
-  hparams.learning_rate_ldcd_epoch = 2
-  hparams.learning_rate_warmup_steps = 4000
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_ldcd_n1():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.learning_rate_decay_scheme = "ld&cd"
-  hparams.learning_rate_ldcd_epoch = 1
-  hparams.learning_rate_warmup_steps = 4000
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_amsgrad():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "AMSGrad"
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_amsgrad_v2():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "AMSGrad"
-  hparams.optimizer_adam_beta1 = 0.9
-  hparams.optimizer_adam_beta2 = 0.999
-  hparams.optimizer_adam_epsilon = 1e-8
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_amsgrad_v3():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "AMSGrad"
-  hparams.optimizer_adam_beta1 = 0.9
-  hparams.optimizer_adam_beta2 = 0.99
-  hparams.optimizer_adam_epsilon = 1e-8
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_amsgrad_v4():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "AMSGrad"
-  hparams.optimizer_adam_beta1 = 0.9
-  hparams.optimizer_adam_beta2 = 0.99
-  hparams.optimizer_adam_epsilon = 1e-9
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_ldrestart_n3():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.learning_rate_decay_scheme = "ld&restart"
-  hparams.learning_rate_ldrestart_epoch = 3
-  hparams.learning_rate_warmup_steps = 4000
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = ""
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign_ld():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = "linear"
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign_cd():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = "cosine"
-  hparams.optimizer_powersign_period = 1
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign_rd():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = "restart"
-  hparams.optimizer_powersign_period = 1
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign_rd_n10():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = "restart"
-  hparams.optimizer_powersign_period = 10
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_powersign_rd_n20():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.optimizer = "PowerSign"
-  hparams.optimizer_powersign_beta = 0.9
-  hparams.optimizer_powersign_decay = "restart"
-  hparams.optimizer_powersign_period = 20
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_swish1():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.ffn_layer = "conv_hidden_swish"
-  hparams.swish_dropout = 0.0
-  hparams.swish_beta = 1.0
-  hparams.swish_beta_is_trainable = False
-  return hparams
-
-
-@registry.register_hparams
-def transformer_base_swish_trainable():
-  """Set of hyperparameters."""
-  hparams = transformer_base()
-  hparams.ffn_layer = "conv_hidden_swish"
-  hparams.swish_dropout = 0.0
-  #hparams.swish_beta = 1.0
-  hparams.swish_beta_is_trainable = True
-  return hparams
-
-
-@registry.register_hparams
-def transformer_big_adafactor():
-  """HParams for transfomer big model on WMT."""
-  hparams = transformer_base()
-  hparams.hidden_size = 1024
-  hparams.filter_size = 4096
-  hparams.num_heads = 16
-  hparams.batching_mantissa_bits = 2
-  hparams.residual_dropout = 0.3
-  hparams.optimizer = "Adafactor"
-  hparams.epsilon = 1e-8
-  hparams.learning_rate_warmup_steps = 16000
-  hparams.optimizer_adafactor_beta2 = 0.997
-  return hparams
-
-
-@registry.register_hparams
 def transformer_base_v2():
   """Set of hyperparameters.
   set relu_dropout and attention_dropout as 0.1
...
@@ -1056,7 +827,6 @@ def transformer_base_v3():
   return hparams


-
 @registry.register_hparams
 def transformer_big_multistep2():
   # new model use optimizer MultistepAdam
...
@@ -1069,13 +839,9 @@ def transformer_big_multistep2():
   return hparams


 @registry.register_hparams
-def transformer_big_adafactor_test():
+def transformer_before_test():
   # new model use optimizer MultistepAdam
-  hparams = transformer_big()
-  hparams.optimizer = "Adafactor"
-  hparams.learning_rate_warmup_steps = 8000
-  hparams.batch_size = 4096
-  hparams.optimizer_adafactor_beta2 = 0.999
+  hparams = transformer_before()
+  hparams.shared_decoder_input_and_softmax_weights = int(False)
   return hparams
...
@@ -396,7 +396,8 @@ def transformer_dla():
   hparams.optimizer_adam_beta2 = 0.98
   hparams.num_sampled_classes = 0
   hparams.label_smoothing = 0.1
-  hparams.shared_embedding_and_softmax_weights = int(True)
+  hparams.shared_embedding_and_softmax_weights = int(False)
+  hparams.shared_decoder_input_and_softmax_weights = int(False)
   hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
   # attention-related flags
...
@@ -437,6 +438,8 @@ def transformer_dla():
 @registry.register_hparams
 def transformer_dla_base():
   hparams = transformer_dla()
+  hparams.encoder_layers = 6
+  hparams.decoder_layers = 6
   hparams.normalize_before = True
   hparams.attention_dropout = 0.1
   hparams.residual_dropout = 0.1
...
@@ -450,10 +453,19 @@ def transformer_dla_base():
 @registry.register_hparams
 def transformer_dla_big():
   """HParams for transfomer big model on WMT."""
-  hparams = transformer_dla_base()
+  hparams = transformer_dla()
   hparams.hidden_size = 1024
   hparams.filter_size = 4096
   hparams.num_heads = 16
   hparams.batching_mantissa_bits = 2
   hparams.residual_dropout = 0.3
   return hparams
+
+
+@registry.register_hparams
+def transformer_dla_base25_shared():
+  hparams = transformer_dla_base()
+  hparams.shared_decoder_input_and_softmax_weights = int(True)
+  hparams.encoder_layers = 25
+  return hparams
\ No newline at end of file
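As a quick sanity check of the new registered set (illustration only, assuming no overrides besides the hunks shown above), transformer_dla_base25_shared() resolves to:

# Illustration only: values implied by the hunks above.
hparams = transformer_dla_base25_shared()
assert hparams.shared_embedding_and_softmax_weights == int(False)     # set in transformer_dla()
assert hparams.shared_decoder_input_and_softmax_weights == int(True)  # enabled here
assert hparams.encoder_layers == 25                                   # deeper encoder; decoder_layers stays 6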