Commit d6b1aadf by libei

revise some transformer decoding configuration, support relative position…

revise some transformer decoding configuration, support relative position representation training, add transformer_rpr_base
parent aaa7a715
@@ -502,7 +502,7 @@ def _generate_relative_positions_embeddings(length, depth,
   with tf.variable_scope(name):
     # libei: relative_positions_matrix shape is (L, L), value range is 0 ~ 2K (total 2K+1)
     relative_positions_matrix = _generate_relative_positions_matrix(
-        length, max_relative_position,cache=cache)
+        length, max_relative_position, cache=cache)
     vocab_size = max_relative_position * 2 + 1
     # Generates embedding for each relative position of dimension depth.
     embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
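Note on the "0 ~ 2K" comment in the hunk above: the relative-positions matrix stores the pairwise offset j - i for every query/key pair, clipped to ±max_relative_position (K) and shifted to be non-negative, so it can index an embedding table with 2K + 1 rows. A minimal self-contained numpy sketch of that idea (illustrative only, not the fork's _generate_relative_positions_matrix, which also handles the cache case):

import numpy as np

def relative_positions_matrix(length, max_relative_position):
  # Pairwise offsets j - i for all query/key positions, shape (L, L).
  rng = np.arange(length)
  distance = rng[None, :] - rng[:, None]
  # Clip to [-K, K] so tokens farther apart than K share an embedding.
  clipped = np.clip(distance, -max_relative_position, max_relative_position)
  # Shift into [0, 2K] so the values index a table with 2K + 1 rows.
  return clipped + max_relative_position

print(relative_positions_matrix(4, 2))
# [[2 3 4 4]
#  [1 2 3 4]
#  [0 1 2 3]
#  [0 0 1 2]]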
@@ -602,7 +602,7 @@ def dot_product_attention_relative(q,
   # This calculation only works for self attention.
   # q, k and v must therefore have the same shape.
-  # wq: compatible means same shape
+  # libei: compatible means same shape
   if not cache:
     q.get_shape().assert_is_compatible_with(k.get_shape())
     q.get_shape().assert_is_compatible_with(v.get_shape())
@@ -610,7 +610,7 @@ def dot_product_attention_relative(q,
   # Use separate embeddings suitable for keys and values.
   depth = q.get_shape().as_list()[3]
   length = common_layers.shape_list(q)[2]
-  # wq: relations_keys: (L, L, H), where H is hidden size of head
+  # libei: relations_keys: (L, L, H), where H is hidden size of head
   relations_keys = _generate_relative_positions_embeddings(
       length, depth, max_relative_position, "relative_positions_keys",
       debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None,
...
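For readers unfamiliar with relative position representations (Shaw et al., 2018): the (L, L, H) relations_keys tensor produced above adds a position-dependent term to the ordinary content-based query-key logits. A rough numpy sketch of how the two terms combine; this illustrates the idea only and is not the exact computation inside dot_product_attention_relative:

import numpy as np

def relative_attention_logits(q, k, relations_keys):
  # q, k: (batch, heads, L, depth); relations_keys: (L, L, depth).
  # Content term: ordinary dot-product attention scores, (batch, heads, L, L).
  content = np.einsum("bhid,bhjd->bhij", q, k)
  # Relative term: each query position i also scores the embedding of the
  # (clipped) offset between i and j, giving another (batch, heads, L, L).
  relative = np.einsum("bhid,ijd->bhij", q, relations_keys)
  return content + relative

# Tiny smoke test with random tensors.
b, h, L, d = 2, 4, 5, 8
logits = relative_attention_logits(
    np.random.randn(b, h, L, d),
    np.random.randn(b, h, L, d),
    np.random.randn(L, L, d))
print(logits.shape)  # (2, 4, 5, 5)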
@@ -33,7 +33,7 @@ from tensor2tensor.models import common_hparams
 from tensor2tensor.models import common_layers
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model
-from tensor2tensor.utils import beam_search
+from tensor2tensor.utils import beam_search_fast
 from tensorflow.python.util import nest
 import tensorflow as tf
@@ -106,7 +106,7 @@ class Transformer(t2t_model.T2TModel):
     return decoder_output
-  def _beam_decode(self, features, decode_length, beam_size, top_beams, last_position_only, alpha):
+  def _beam_decode_fast(self, features, decode_length, beam_size, top_beams, last_position_only, alpha):
     """Beam search decoding.
     Args:
@@ -274,7 +274,7 @@ def fast_decode(encoder_output,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0,
-                eos_id=beam_search.EOS_ID,
+                eos_id=beam_search_fast.EOS_ID,
                 batch_size=None,
                 force_decode_length=False):
   """Given encoder output and a symbols to logits function, does fast decoding.
@@ -353,7 +353,7 @@ def fast_decode(encoder_output,
   if beam_size > 1:  # Beam Search
     initial_ids = tf.zeros([batch_size], dtype=tf.int32)
-    decoded_ids, scores = beam_search.beam_search(
+    decoded_ids, scores = beam_search_fast.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
@@ -410,7 +410,7 @@ def fast_decode(encoder_output,
             tf.TensorShape([None]),
             tf.TensorShape([None, None]),
             tf.TensorShape([None, None]),
-            nest.map_structure(beam_search.get_state_shape_invariants, cache),
+            nest.map_structure(beam_search_fast.get_state_shape_invariants, cache),
             tf.TensorShape([None]),
         ])
     scores = log_prob
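The get_state_shape_invariants entries above are needed because tf.while_loop requires explicit shape invariants for loop variables whose shapes change between iterations; during decoding the decoded ids and every tensor in the attention cache grow by one time step per loop. A tiny standalone example of the same mechanism, unrelated to this repository's cache layout:

import tensorflow as tf

# ids grows by one column per step, so its length dimension is declared as None.
i0 = tf.constant(0)
ids0 = tf.zeros([2, 1], dtype=tf.int32)  # (batch, current_length)

def cond(i, ids):
  return i < 5

def body(i, ids):
  # Append one new column of ids each iteration.
  return i + 1, tf.concat([ids, tf.fill([2, 1], i + 1)], axis=1)

_, final_ids = tf.while_loop(
    cond, body, [i0, ids0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, None])])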
@@ -810,12 +810,12 @@ def transformer_base_v2():
 @registry.register_hparams
-def transformer_base_rpr_dropout1():
-  hparams = transformer_base()
-  hparams.max_relative_length = 16
+def transformer_rpr_base():
+  hparams = transformer_before()
+  hparams.max_relative_length = 20
   hparams.attention_type = "relative_dot_product"
-  hparams.relu_dropout = 0.1  # optimal
-  hparams.attention_dropout = 0.1
+  hparams.filter_size = 4096
   return hparams
...
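Assuming the usual tensor2tensor layout where these hparams live in tensor2tensor/models/transformer.py (the file name is not shown in this view), the newly registered set can be sanity-checked directly; the expected values follow from the hunk above. With upstream t2t's trainer the set would then be selected by name, e.g. --hparams_set=transformer_rpr_base.

# Hypothetical quick check of the configuration added in this commit;
# requires this fork on PYTHONPATH.
from tensor2tensor.models import transformer

hparams = transformer.transformer_rpr_base()
print(hparams.max_relative_length)  # 20
print(hparams.attention_type)       # "relative_dot_product"
print(hparams.filter_size)          # 4096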
@@ -21,7 +21,7 @@ from __future__ import print_function
 # Dependency imports
 import numpy as np
-from tensor2tensor.utils import beam_search_slow
+from tensor2tensor.utils import beam_search
 import tensorflow as tf
@@ -40,7 +40,7 @@ class BeamSearchTest(tf.test.TestCase):
       # Just return random logits
       return tf.random_uniform((batch_size * beam_size, vocab_size))
-    final_ids, final_probs = beam_search_slow.beam_search(
+    final_ids, final_probs = beam_search.beam_search(
         symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
         0.)
@@ -60,7 +60,7 @@ class BeamSearchTest(tf.test.TestCase):
     flags = tf.constant([[True, False, False, True],
                          [False, False, False, True]])
-    topk_seq, topk_scores, topk_flags = beam_search_slow.compute_topk_scores_and_seq(
+    topk_seq, topk_scores, topk_flags = beam_search.compute_topk_scores_and_seq(
         sequences, scores, scores, flags, beam_size, batch_size)
     with self.test_session():
@@ -115,7 +115,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
-    final_ids, final_probs = beam_search_slow.beam_search(
+    final_ids, final_probs = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -146,7 +146,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
-    final_ids, final_probs = beam_search_slow.beam_search(
+    final_ids, final_probs = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -175,7 +175,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
-    final_ids, final_probs = beam_search_slow.beam_search(
+    final_ids, final_probs = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -215,7 +215,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
-    final_ids, final_scores = beam_search_slow.beam_search(
+    final_ids, final_scores = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -258,7 +258,7 @@ class BeamSearchTest(tf.test.TestCase):
       return logits
     # Disable early stopping
-    final_ids, final_scores = beam_search_slow.beam_search(
+    final_ids, final_scores = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
...
@@ -26,7 +26,7 @@ import time
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
-from tensor2tensor.utils import beam_search_slow
+from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import modality
 from tensor2tensor.utils import registry
@@ -234,9 +234,9 @@ class T2TModel(object):
       vocab_size = target_modality.top_dimensionality
       # Setting decode length to input length + decode_length
       decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
-      ids, scores = beam_search_slow.beam_search(symbols_to_logits_fn, initial_ids,
-                                                 beam_size, decode_length, vocab_size,
-                                                 alpha)
+      ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
+                                            beam_size, decode_length, vocab_size,
+                                            alpha)
       # Set inputs back to the unexpanded inputs to not to confuse the Estimator!
       features["inputs"] = inputs_old
...
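For reference, the contract exercised both here and in the tests above: symbols_to_logits_fn maps the ids decoded so far (roughly (batch * beam, current_length)) to next-step logits of shape (batch * beam, vocab_size), and beam_search returns decoded ids together with their scores. A minimal sketch mirroring the test file, with illustrative values; the shapes stated here are assumptions read off the call sites, not documented guarantees of this fork:

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 2, 4, 11, 5
initial_ids = tf.zeros([batch_size], dtype=tf.int32)

def symbols_to_logits(ids):
  # ids: the partial hypotheses decoded so far; return next-token logits,
  # here uniform random purely for illustration.
  return tf.random_uniform((batch_size * beam_size, vocab_size))

ids, scores = beam_search.beam_search(
    symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
    0.6)  # final argument is the length-penalty alpha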