Commit d6b1aadf by libei

revise some transformer decoding configuration, support relative position representation training, add transformer_rpr_base
parent aaa7a715
......@@ -502,7 +502,7 @@ def _generate_relative_positions_embeddings(length, depth,
with tf.variable_scope(name):
# libei: relative_positions_matrix shape is (L, L), value range is 0 ~ 2K (total 2K+1)
relative_positions_matrix = _generate_relative_positions_matrix(
length, max_relative_position,cache=cache)
length, max_relative_position, cache=cache)
vocab_size = max_relative_position * 2 + 1
# Generates embedding for each relative position of dimension depth.
embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
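For orientation, here is a minimal sketch (not this repo's exact code) of what _generate_relative_positions_matrix and the lookup above typically compute, following the relative position representation scheme named in the commit message; the function names in the sketch are illustrative:

import tensorflow as tf

def relative_positions_matrix_sketch(length, max_relative_position):
  """(length, length) int matrix with values in [0, 2K], K = max_relative_position."""
  range_vec = tf.range(length)
  # Signed relative distance: distance_mat[i, j] = j - i.
  distance_mat = range_vec[None, :] - range_vec[:, None]
  # Clip to [-K, K] so distant positions share one representation, then shift to [0, 2K].
  clipped = tf.clip_by_value(distance_mat, -max_relative_position, max_relative_position)
  return clipped + max_relative_position

def relative_positions_embeddings_sketch(length, depth, max_relative_position):
  """Looks the matrix up in a (2K+1, depth) table, yielding an (L, L, depth) tensor."""
  matrix = relative_positions_matrix_sketch(length, max_relative_position)
  vocab_size = max_relative_position * 2 + 1
  embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
  return tf.gather(embeddings_table, matrix)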
......@@ -602,7 +602,7 @@ def dot_product_attention_relative(q,
# This calculation only works for self attention.
# q, k and v must therefore have the same shape.
# wq: compatible means same shape
# libei: compatible means same shape
if not cache:
q.get_shape().assert_is_compatible_with(k.get_shape())
q.get_shape().assert_is_compatible_with(v.get_shape())
......@@ -610,7 +610,7 @@ def dot_product_attention_relative(q,
# Use separate embeddings suitable for keys and values.
depth = q.get_shape().as_list()[3]
length = common_layers.shape_list(q)[2]
# wq: relations_keys: (L, L, H), where H is hidden size of head
# libei: relations_keys: (L, L, H), where H is hidden size of head
relations_keys = _generate_relative_positions_embeddings(
length, depth, max_relative_position, "relative_positions_keys",
debug_flag=debug_flag+"+rp-key" if debug_flag is not None else None,
......
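The hunks above only touch comments, but for context: relations_keys of shape (L, L, H) enters the attention logits as an extra, position-dependent term. A hedged sketch of that combination, assuming the standard Shaw et al. (2018) formulation rather than this fork's exact code:

import tensorflow as tf

def relative_attention_logits_sketch(q, k, relations_keys):
  """q, k: (B, heads, L, H); relations_keys: (L, L, H). Returns logits of shape (B, heads, L, L)."""
  # Content term: ordinary dot-product attention numerator.
  content_logits = tf.matmul(q, k, transpose_b=True)
  # Position term: for each query position i, dot q[..., i, :] with relations_keys[i, j, :].
  position_logits = tf.einsum("bhid,ijd->bhij", q, relations_keys)
  return content_logits + position_logits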
......@@ -33,7 +33,7 @@ from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import beam_search_fast
from tensorflow.python.util import nest
import tensorflow as tf
......@@ -106,7 +106,7 @@ class Transformer(t2t_model.T2TModel):
return decoder_output
def _beam_decode(self, features, decode_length, beam_size, top_beams, last_position_only, alpha):
def _beam_decode_fast(self, features, decode_length, beam_size, top_beams, last_position_only, alpha):
"""Beam search decoding.
Args:
......@@ -274,7 +274,7 @@ def fast_decode(encoder_output,
beam_size=1,
top_beams=1,
alpha=1.0,
eos_id=beam_search.EOS_ID,
eos_id=beam_search_fast.EOS_ID,
batch_size=None,
force_decode_length=False):
"""Given encoder output and a symbols to logits function, does fast decoding.
......@@ -353,7 +353,7 @@ def fast_decode(encoder_output,
if beam_size > 1: # Beam Search
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
decoded_ids, scores = beam_search.beam_search(
decoded_ids, scores = beam_search_fast.beam_search(
symbols_to_logits_fn,
initial_ids,
beam_size,
......@@ -410,7 +410,7 @@ def fast_decode(encoder_output,
tf.TensorShape([None]),
tf.TensorShape([None, None]),
tf.TensorShape([None, None]),
nest.map_structure(beam_search.get_state_shape_invariants, cache),
nest.map_structure(beam_search_fast.get_state_shape_invariants, cache),
tf.TensorShape([None]),
])
scores = log_prob
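Putting the fast-decode hunks together: fast_decode builds a symbols_to_logits_fn, seeds the beams with zero ids, and hands both to beam_search_fast.beam_search. A standalone sketch of that wiring, using a toy (hypothetical) logits function in place of the transformer decoder step and assuming beam_search_fast keeps the beam_search call signature exercised by the tests below:

import tensorflow as tf
from tensor2tensor.utils import beam_search_fast

batch_size, beam_size, decode_length, vocab_size, alpha = 2, 4, 10, 32, 1.0

def symbols_to_logits_fn(ids):
  # A real model would run one decoder step over `ids` (plus a cache of past
  # keys/values); random logits stand in for that here.
  return tf.random_uniform([tf.shape(ids)[0], vocab_size])

initial_ids = tf.zeros([batch_size], dtype=tf.int32)
decoded_ids, scores = beam_search_fast.beam_search(
    symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha)
# decoded_ids is roughly (batch, beam, <= decode_length + 1); scores is (batch, beam).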
......@@ -810,12 +810,12 @@ def transformer_base_v2():
@registry.register_hparams
def transformer_base_rpr_dropout1():
hparams = transformer_base()
hparams.max_relative_length = 16
def transformer_rpr_base():
hparams = transformer_before()
hparams.max_relative_length = 20
hparams.attention_type = "relative_dot_product"
hparams.relu_dropout = 0.1
hparams.attention_dropout = 0.1
# optimal
hparams.filter_size = 4096
return hparams
......
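How the new transformer_rpr_base settings tie back to the relative-attention code above, as a hedged sketch; the import path assumes this fork keeps T2T's usual tensor2tensor.models.transformer layout:

from tensor2tensor.models.transformer import transformer_rpr_base

hparams = transformer_rpr_base()
assert hparams.attention_type == "relative_dot_product"  # presumably selects dot_product_attention_relative
k = hparams.max_relative_length                          # 20 in this hparams set
print(2 * k + 1)  # 41: rows in the "embeddings" table built above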
......@@ -21,7 +21,7 @@ from __future__ import print_function
# Dependency imports
import numpy as np
from tensor2tensor.utils import beam_search_slow
from tensor2tensor.utils import beam_search
import tensorflow as tf
......@@ -40,7 +40,7 @@ class BeamSearchTest(tf.test.TestCase):
# Just return random logits
return tf.random_uniform((batch_size * beam_size, vocab_size))
final_ids, final_probs = beam_search_slow.beam_search(
final_ids, final_probs = beam_search.beam_search(
symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
0.)
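A standalone version of this test's call, for readers who want to exercise the renamed module directly; output shapes are indicative only, since the decoded length depends on when EOS is emitted:

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 2, 3, 4, 10
initial_ids = tf.zeros([batch_size], dtype=tf.int32)

def symbols_to_logits(ids):
  # The test scores candidates with random logits; a real model would go here.
  return tf.random_uniform((batch_size * beam_size, vocab_size))

final_ids, final_probs = beam_search.beam_search(
    symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size, 0.)

with tf.Session() as sess:
  ids, probs = sess.run([final_ids, final_probs])
  # ids: about (batch_size, beam_size, <= decode_length + 1); probs: (batch_size, beam_size)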
......@@ -60,7 +60,7 @@ class BeamSearchTest(tf.test.TestCase):
flags = tf.constant([[True, False, False, True],
[False, False, False, True]])
topk_seq, topk_scores, topk_flags = beam_search_slow.compute_topk_scores_and_seq(
topk_seq, topk_scores, topk_flags = beam_search.compute_topk_scores_and_seq(
sequences, scores, scores, flags, beam_size, batch_size)
with self.test_session():
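compute_topk_scores_and_seq, exercised by this test, ranks candidate beams by a score tensor, keeps the best beam_size per batch entry, and gathers the matching sequences, scores, and finished-flags. A rough sketch of that idea (illustrative only, not the helper's exact code):

import tensorflow as tf

def topk_gather_sketch(sequences, scores, scores_to_gather, flags, beam_size):
  # Rank the candidates in each batch entry by `scores` and keep the top beam_size.
  _, topk_indexes = tf.nn.top_k(scores, k=beam_size)              # (batch, beam_size)
  batch_size = tf.shape(scores)[0]
  batch_pos = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, beam_size])
  coords = tf.stack([batch_pos, topk_indexes], axis=2)            # (batch, beam_size, 2)
  gather = lambda t: tf.gather_nd(t, coords)
  # Use the same indices to gather the sequences, the scores to report, and the flags.
  return gather(sequences), gather(scores_to_gather), gather(flags)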
......@@ -115,7 +115,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search_slow.beam_search(
final_ids, final_probs = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
......@@ -146,7 +146,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search_slow.beam_search(
final_ids, final_probs = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
......@@ -175,7 +175,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search_slow.beam_search(
final_ids, final_probs = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
......@@ -215,7 +215,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_scores = beam_search_slow.beam_search(
final_ids, final_scores = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
......@@ -258,7 +258,7 @@ class BeamSearchTest(tf.test.TestCase):
return logits
# Disable early stopping
final_ids, final_scores = beam_search_slow.beam_search(
final_ids, final_scores = beam_search.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
......
......@@ -26,7 +26,7 @@ import time
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensor2tensor.utils import beam_search_slow
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import expert_utils as eu
from tensor2tensor.utils import modality
from tensor2tensor.utils import registry
......@@ -234,9 +234,9 @@ class T2TModel(object):
vocab_size = target_modality.top_dimensionality
# Setting decode length to input length + decode_length
decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
ids, scores = beam_search_slow.beam_search(symbols_to_logits_fn, initial_ids,
beam_size, decode_length, vocab_size,
alpha)
ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
beam_size, decode_length, vocab_size,
alpha)
# Set inputs back to the unexpanded inputs to not to confuse the Estimator!
features["inputs"] = inputs_old
......