Commit 8ab393f2 by libei

add fast decoding

parent 81667eab
@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Python 3.6 (1)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
...
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="39.104.93.174">
<component name="PublishConfigData" autoUpload="Always" serverName="39.104.62.93">
<serverData>
<paths name="39.104.62.93">
<serverdata>
<mappings>
<mapping deploy="/WMT19" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="39.104.93.174">
<serverdata>
<mappings>
...
<component name="ProjectDictionaryState">
<dictionary name="LiBei" />
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (1)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
@@ -10,6 +10,14 @@
<option name="port" value="22" />
</fileTransfer>
</webServer>
<webServer id="73ec0a32-7aa0-46b4-9d59-59f5be3898b1" name="39.104.62.93" url="http://39.104.62.93">
<fileTransfer host="39.104.62.93" port="22" rootFolder="/wmt/libei" accessType="SFTP">
<advancedOptions>
<advancedOptions dataProtectionLevel="Private" />
</advancedOptions>
<option name="port" value="22" />
</fileTransfer>
</webServer>
</option>
</component>
</project>
\ No newline at end of file
@@ -27,6 +27,12 @@ import tensorflow as tf
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
length = common_layers.shape_list(x)[1]
channels = common_layers.shape_list(x)[2]
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
"""Adds a bunch of sinusoids of different frequencies to a Tensor.
Each channel of the input Tensor is incremented by a sinusoid of a different
@@ -54,8 +60,7 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
Returns:
a Tensor of timing signals with shape [1, length, channels].
"""
length = tf.shape(x)[1]
channels = tf.shape(x)[2]
position = tf.to_float(tf.range(length))
num_timescales = channels // 2
log_timescale_increment = (
@@ -67,7 +72,7 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
signal = tf.reshape(signal, [1, length, channels])
return x + signal
return signal
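For reference, a minimal standalone sketch of the sinusoidal signal that this hunk factors out of add_timing_signal_1d. This is NumPy for illustration only; the max(..., 1) guard against a single timescale is an addition, not part of the original code.

```python
import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    # Produces the same [1, length, channels] signal that get_timing_signal_1d
    # returns: sin on the first half of the channels, cos on the second half.
    position = np.arange(length, dtype=np.float32)
    num_timescales = channels // 2
    log_timescale_increment = (np.log(max_timescale / min_timescale) /
                               max(num_timescales - 1, 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales, dtype=np.float32) * -log_timescale_increment)
    scaled_time = position[:, None] * inv_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]], mode="constant")
    return signal.reshape(1, length, channels)
```

Factoring the signal out of the add function lets fast decoding fetch the timing value for a single position instead of recomputing it for the whole sequence at every step.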
def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
@@ -204,7 +209,7 @@ def attention_bias_ignore_padding(memory_padding):
return tf.expand_dims(tf.expand_dims(ret, 1), 1)
def split_last_dimension(x, n):
def split_last_dimension1(x, n):
"""Reshape x so that the last dimension becomes two dimensions.
The first of these two dimensions is n.
@@ -223,6 +228,23 @@ def split_last_dimension(x, n):
ret.set_shape(new_shape)
return ret
def split_last_dimension(x, n):
"""Reshape x so that the last dimension becomes two dimensions.
The first of these two dimensions is n.
Args:
x: a Tensor with shape [..., m]
n: an integer.
Returns:
a Tensor with shape [..., n, m/n]
"""
x_shape = common_layers.shape_list(x)
m = x_shape[-1]
if isinstance(m, int) and isinstance(n, int):
assert m % n == 0
return tf.reshape(x, x_shape[:-1] + [n, m // n])
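A quick illustration of what the reimplemented helper does, reusing split_last_dimension from the hunk above; the sizes are hypothetical:

```python
import tensorflow as tf

x = tf.zeros([8, 20, 512])      # [batch, length, hidden]
y = split_last_dimension(x, 8)  # [8, 20, 8, 64] == [..., heads, hidden // heads]
```

Because common_layers.shape_list keeps dimensions static where the graph knows them, the m % n == 0 check still runs at graph-construction time when possible, while a fully dynamic last dimension no longer breaks the reshape; that appears to be why the old static-shape version is kept around as split_last_dimension1.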
def combine_last_two_dimensions(x):
"""Reshape x so that the last two dimension become one.
@@ -409,21 +431,32 @@ def multihead_attention(query_antecedent,
q, k, v = tf.split(
combined, [total_key_depth, total_key_depth, total_value_depth],
axis=2)
k = split_heads(k, num_heads)
v = split_heads(v, num_heads)
if cache is not None:
k = cache["k"] = tf.concat([cache["k"], k], axis=2)
v = cache["v"] = tf.concat([cache["v"], v], axis=2)
else:
q = common_layers.conv1d(
query_antecedent, total_key_depth, 1, name="q_transform")
combined = common_layers.conv1d(
if cache is not None:
k = cache["k_encdec"]
v = cache["v_encdec"]
else:
combined = common_layers.conv1d(
memory_antecedent,
total_key_depth + total_value_depth,
1,
name="kv_transform")
k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
k = split_heads(k, num_heads)
v = split_heads(v, num_heads)
q = split_heads(q, num_heads)
k = split_heads(k, num_heads)
v = split_heads(v, num_heads)
key_depth_per_head = total_key_depth // num_heads
q *= key_depth_per_head**-0.5
if attention_type == "dot_product":
x = dot_product_attention(
q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
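This hunk is the core of the fast decoding: during incremental self-attention each step feeds only the newest position through the layer, and the cache accumulates keys and values along the length axis, while encoder-decoder attention reuses k_encdec/v_encdec computed once from the fixed encoder output. A minimal sketch of the cache's life cycle, with hypothetical names and shapes:

```python
import tensorflow as tf

def init_self_attention_cache(batch, num_heads, key_depth, value_depth):
    # Zero-length caches; the length axis (2) grows by one per decode step.
    return {
        "k": tf.zeros([batch, num_heads, 0, key_depth // num_heads]),
        "v": tf.zeros([batch, num_heads, 0, value_depth // num_heads]),
    }

def append_step(cache, k_new, v_new):
    # k_new / v_new cover only the current position: [batch, heads, 1, depth].
    # Mirrors `k = cache["k"] = tf.concat([cache["k"], k], axis=2)` above.
    cache["k"] = tf.concat([cache["k"], k_new], axis=2)
    cache["v"] = tf.concat([cache["v"], v_new], axis=2)
    return cache["k"], cache["v"]
```

With the cache, each decode step costs attention over the positions seen so far instead of recomputing every key and value projection for the whole prefix.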
...
@@ -1668,4 +1668,7 @@ def Linear(input, output_dim, name, activation=None, bias=True):
activation: activation function to apply; defaults to None.
bias: whether to use a bias term.
"""
return tf.layers.dense(input, output_dim, name=name, activation=activation, use_bias=bias)
\ No newline at end of file
return tf.layers.dense(input, output_dim, name=name, activation=activation, use_bias=bias)
def log_prob_from_logits(logits, reduce_axis=-1):
return logits - tf.reduce_logsumexp(logits, axis=reduce_axis, keepdims=True)
\ No newline at end of file
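The new log_prob_from_logits helper is the standard log-softmax identity, log p_i = logit_i - logsumexp(logits): normalizing in log space avoids materializing an explicit softmax and keeps accumulated beam-search scores numerically stable. A small sanity check, reusing the helper defined above (illustrative only):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
log_probs = log_prob_from_logits(logits)
# tf.exp(log_probs) sums to 1.0 along the last axis, i.e. it equals
# tf.nn.softmax(logits), but computed without leaving log space.
```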
@@ -21,7 +21,7 @@ from __future__ import print_function
# Dependency imports
import numpy as np
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import beam_search_slow
import tensorflow as tf
@@ -40,7 +40,7 @@ class BeamSearchTest(tf.test.TestCase):
# Just return random logits
return tf.random_uniform((batch_size * beam_size, vocab_size))
final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs = beam_search_slow.beam_search(
symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
0.)
@@ -60,7 +60,7 @@ class BeamSearchTest(tf.test.TestCase):
flags = tf.constant([[True, False, False, True],
[False, False, False, True]])
topk_seq, topk_scores, topk_flags = beam_search.compute_topk_scores_and_seq(
topk_seq, topk_scores, topk_flags = beam_search_slow.compute_topk_scores_and_seq(
sequences, scores, scores, flags, beam_size, batch_size)
with self.test_session():
@@ -115,7 +115,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs = beam_search_slow.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
@@ -146,7 +146,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs = beam_search_slow.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
@@ -175,7 +175,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_probs = beam_search.beam_search(
final_ids, final_probs = beam_search_slow.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
@@ -215,7 +215,7 @@ class BeamSearchTest(tf.test.TestCase):
logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
return logits
final_ids, final_scores = beam_search.beam_search(
final_ids, final_scores = beam_search_slow.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
@@ -258,7 +258,7 @@ class BeamSearchTest(tf.test.TestCase):
return logits
# Disable early stopping
final_ids, final_scores = beam_search.beam_search(
final_ids, final_scores = beam_search_slow.beam_search(
symbols_to_logits,
initial_ids,
beam_size,
...
@@ -156,3 +156,24 @@ class Modality(object):
weights_fn=weights_fn)
loss = tf.add_n(loss_num) / tf.maximum(1.0, tf.add_n(loss_den))
return sharded_logits, loss
def top_sharded_logits(self,
sharded_body_output,
sharded_targets,
data_parallelism):
"""Transform all shards of targets.
Classes with cross-shard interaction will override this function.
Args:
sharded_body_output: A list of Tensors.
sharded_targets: A list of Tensors.
data_parallelism: a expert_utils.Parallelism object.
weights_fn: function from targets to target weights.
Returns:
shaded_logits: A list of Tensors.
training_loss: a Scalar.
"""
sharded_logits = data_parallelism(self.top, sharded_body_output,
sharded_targets)
return sharded_logits
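Unlike the loss-computing variant whose tail appears above, top_sharded_logits stops at the logits, which is all inference needs. A hedged sketch of the intended call site; the surrounding names are assumed from this file:

```python
# At decode time there is no loss to compute; the modality just projects each
# shard of body output to vocabulary logits in parallel:
sharded_logits = target_modality.top_sharded_logits(
    sharded_body_output,  # list of [batch, length, 1, hidden] Tensors
    sharded_targets,      # placeholder ids at decode time
    data_parallelism)     # expert_utils.Parallelism; maps self.top over shards
```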
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import copy
import time
@@ -25,11 +26,11 @@ import time
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import beam_search_slow
from tensor2tensor.utils import expert_utils as eu
from tensor2tensor.utils import modality
from tensor2tensor.utils import registry
from tensorflow.python.layers import base
import tensorflow as tf
@@ -211,7 +212,7 @@ class T2TModel(object):
logits = sharded_logits[0] # Assuming we have one shard.
if last_position_only:
return tf.squeeze(logits, axis=[1, 2, 3])
current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above.
current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above.
logits = logits[:, current_output_position, :, :]
return tf.squeeze(logits, axis=[1, 2])
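For context, this sits in symbols_to_logits_fn: with last_position_only the body already emits logits for a single position, otherwise the full-length logits are sliced at the final, just-padded position. The shapes involved, with hypothetical sizes:

```python
import tensorflow as tf

logits = tf.zeros([4, 7, 1, 1, 32000])             # [batch, length, 1, 1, vocab]
current_output_position = tf.shape(logits)[1] - 1  # -1 due to the pad above
step = logits[:, current_output_position, :, :]    # [4, 1, 1, 32000]
step = tf.squeeze(step, axis=[1, 2])               # [4, 32000], one step's logits
```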
@@ -233,9 +234,9 @@ class T2TModel(object):
vocab_size = target_modality.top_dimensionality
# Setting decode length to input length + decode_length
decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
beam_size, decode_length, vocab_size,
alpha)
ids, scores = beam_search_slow.beam_search(symbols_to_logits_fn, initial_ids,
beam_size, decode_length, vocab_size,
alpha)
# Set inputs back to the unexpanded inputs so as not to confuse the Estimator!
features["inputs"] = inputs_old
@@ -490,6 +491,8 @@ class T2TModel(object):
"""
raise NotImplementedError("Abstract Method")
@property
def hparams(self):
return self._hparams
...