Commit 8ab393f2 by libei

add fast decoding

parent 81667eab
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
+    <orderEntry type="jdk" jdkName="Python 3.6 (1)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="TestRunnerService">
......
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>
\ No newline at end of file
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="39.104.93.174">
+  <component name="PublishConfigData" autoUpload="Always" serverName="39.104.62.93">
     <serverData>
+      <paths name="39.104.62.93">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/WMT19" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
       <paths name="39.104.93.174">
         <serverdata>
           <mappings>
......
<component name="ProjectDictionaryState">
<dictionary name="LiBei" />
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (1)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
@@ -10,6 +10,14 @@
         <option name="port" value="22" />
       </fileTransfer>
     </webServer>
+    <webServer id="73ec0a32-7aa0-46b4-9d59-59f5be3898b1" name="39.104.62.93" url="http://39.104.62.93">
+      <fileTransfer host="39.104.62.93" port="22" rootFolder="/wmt/libei" accessType="SFTP">
+        <advancedOptions>
+          <advancedOptions dataProtectionLevel="Private" />
+        </advancedOptions>
+        <option name="port" value="22" />
+      </fileTransfer>
+    </webServer>
     </option>
   </component>
 </project>
\ No newline at end of file
@@ -27,6 +27,12 @@ import tensorflow as tf
 def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+  length = common_layers.shape_list(x)[1]
+  channels = common_layers.shape_list(x)[2]
+  signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+  return x + signal
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
   """Adds a bunch of sinusoids of different frequencies to a Tensor.

   Each channel of the input Tensor is incremented by a sinusoid of a different
@@ -54,8 +60,7 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
   Returns:
     a Tensor the same shape as x.
   """
-  length = tf.shape(x)[1]
-  channels = tf.shape(x)[2]
   position = tf.to_float(tf.range(length))
   num_timescales = channels // 2
   log_timescale_increment = (
@@ -67,7 +72,7 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
   signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
   signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
   signal = tf.reshape(signal, [1, length, channels])
-  return x + signal
+  return signal


 def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
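Taken together, these hunks factor the sinusoid computation out of add_timing_signal_1d into get_timing_signal_1d, so a step-by-step decoder can fetch the positional signal for a single position instead of recomputing it over the whole prefix. A consolidated sketch of the pair after the refactor (assuming TF 1.x APIs; tf.shape is used here in place of common_layers.shape_list to keep the snippet self-contained):

```python
import math
import tensorflow as tf


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
  """Sinusoidal timing signal of shape [1, length, channels]."""
  position = tf.to_float(tf.range(length))
  num_timescales = channels // 2
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (tf.to_float(num_timescales) - 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
  return tf.reshape(signal, [1, length, channels])


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
  """Adds the timing signal to x of shape [batch, length, channels]."""
  length = tf.shape(x)[1]
  channels = tf.shape(x)[2]
  return x + get_timing_signal_1d(length, channels, min_timescale, max_timescale)
```

With the signal exposed on its own, an incremental decoder can add the signal for position i as get_timing_signal_1d(max_length, channels)[:, i:i + 1, :] rather than re-embedding the full prefix.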
@@ -204,7 +209,7 @@ def attention_bias_ignore_padding(memory_padding):
   return tf.expand_dims(tf.expand_dims(ret, 1), 1)


-def split_last_dimension(x, n):
+def split_last_dimension1(x, n):
   """Reshape x so that the last dimension becomes two dimensions.

   The first of these two dimensions is n.
@@ -223,6 +228,23 @@ def split_last_dimension(x, n):
   ret.set_shape(new_shape)
   return ret

+def split_last_dimension(x, n):
+  """Reshape x so that the last dimension becomes two dimensions.
+
+  The first of these two dimensions is n.
+
+  Args:
+    x: a Tensor with shape [..., m]
+    n: an integer.
+
+  Returns:
+    a Tensor with shape [..., n, m/n]
+  """
+  x_shape = common_layers.shape_list(x)
+  m = x_shape[-1]
+  if isinstance(m, int) and isinstance(n, int):
+    assert m % n == 0
+  return tf.reshape(x, x_shape[:-1] + [n, m // n])
+

 def combine_last_two_dimensions(x):
   """Reshape x so that the last two dimension become one.
@@ -409,21 +431,32 @@ def multihead_attention(query_antecedent,
       q, k, v = tf.split(
           combined, [total_key_depth, total_key_depth, total_value_depth],
           axis=2)
+      k = split_heads(k, num_heads)
+      v = split_heads(v, num_heads)
+      if cache is not None:
+        k = cache["k"] = tf.concat([cache["k"], k], axis=2)
+        v = cache["v"] = tf.concat([cache["v"], v], axis=2)
     else:
       q = common_layers.conv1d(
           query_antecedent, total_key_depth, 1, name="q_transform")
+      if cache is not None:
+        k = cache["k_encdec"]
+        v = cache["v_encdec"]
+      else:
       combined = common_layers.conv1d(
           memory_antecedent,
           total_key_depth + total_value_depth,
           1,
           name="kv_transform")
       k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
+    q = split_heads(q, num_heads)
     k = split_heads(k, num_heads)
     v = split_heads(v, num_heads)
-    q = split_heads(q, num_heads)

     key_depth_per_head = total_key_depth // num_heads
     q *= key_depth_per_head**-0.5

     if attention_type == "dot_product":
       x = dot_product_attention(
           q, k, v, bias, dropout_rate, summaries, image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
......
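The cache handling above is the heart of the fast decoding change: for self-attention, the keys and values of previously decoded positions are kept (already split into heads) and only the newest position is concatenated on the length axis, while encoder-decoder attention reuses precomputed "k_encdec"/"v_encdec" entries at every step. Below is a sketch of how such a per-layer cache could be initialized; the layer naming, sizes, and zero-length tf.zeros initializer are illustrative assumptions, not code from this commit:

```python
import tensorflow as tf

num_layers = 6        # illustrative decoder depth
num_heads = 8
key_channels = 512
value_channels = 512
batch_size = 4

cache = {
    "layer_%d" % layer: {
        # Self-attention entries start empty and grow along axis 2, the
        # length axis after split_heads ([batch, heads, length, depth]),
        # matching the tf.concat(..., axis=2) in the hunk above.
        "k": tf.zeros([batch_size, num_heads, 0, key_channels // num_heads]),
        "v": tf.zeros([batch_size, num_heads, 0, value_channels // num_heads]),
        # "k_encdec" / "v_encdec" would be computed once from the encoder
        # output before decoding starts.
    }
    for layer in range(num_layers)
}
```

Each decoding step then passes the matching per-layer dict as the cache argument, so per-step attention cost grows with the current prefix length instead of re-running the decoder over the whole sequence.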
@@ -1669,3 +1669,6 @@ def Linear(input, output_dim, name, activation=None, bias=True):
     bias: a boolean to choose if use bias
   """
   return tf.layers.dense(input, output_dim, name=name, activation=activation, use_bias=bias)
+
+def log_prob_from_logits(logits, reduce_axis=-1):
+  return logits - tf.reduce_logsumexp(logits, axis=reduce_axis, keepdims=True)
\ No newline at end of file
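The new log_prob_from_logits helper converts logits to log-probabilities by subtracting the log-sum-exp over the vocabulary axis, the form in which beam search candidates are typically scored and accumulated. A quick self-contained check (TF 1.x session API; values are illustrative):

```python
import numpy as np
import tensorflow as tf


def log_prob_from_logits(logits, reduce_axis=-1):
  return logits - tf.reduce_logsumexp(logits, axis=reduce_axis, keepdims=True)


logits = tf.constant([[1.0, 2.0, 3.0]])
log_probs = log_prob_from_logits(logits)
with tf.Session() as sess:
  # exp(log_probs) forms a valid distribution: the row sums to ~1.0.
  print(np.exp(sess.run(log_probs)).sum())
```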
@@ -21,7 +21,7 @@ from __future__ import print_function
 # Dependency imports

 import numpy as np

-from tensor2tensor.utils import beam_search
+from tensor2tensor.utils import beam_search_slow

 import tensorflow as tf
@@ -40,7 +40,7 @@ class BeamSearchTest(tf.test.TestCase):
       # Just return random logits
       return tf.random_uniform((batch_size * beam_size, vocab_size))

-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs = beam_search_slow.beam_search(
         symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
         0.)
@@ -60,7 +60,7 @@ class BeamSearchTest(tf.test.TestCase):
     flags = tf.constant([[True, False, False, True],
                          [False, False, False, True]])

-    topk_seq, topk_scores, topk_flags = beam_search.compute_topk_scores_and_seq(
+    topk_seq, topk_scores, topk_flags = beam_search_slow.compute_topk_scores_and_seq(
         sequences, scores, scores, flags, beam_size, batch_size)

     with self.test_session():
@@ -115,7 +115,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits

-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs = beam_search_slow.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -146,7 +146,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits

-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs = beam_search_slow.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -175,7 +175,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits

-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs = beam_search_slow.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -215,7 +215,7 @@ class BeamSearchTest(tf.test.TestCase):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits

-    final_ids, final_scores = beam_search.beam_search(
+    final_ids, final_scores = beam_search_slow.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -258,7 +258,7 @@ class BeamSearchTest(tf.test.TestCase):
       return logits

     # Disable early stopping
-    final_ids, final_scores = beam_search.beam_search(
+    final_ids, final_scores = beam_search_slow.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
......
@@ -156,3 +156,24 @@ class Modality(object):
           weights_fn=weights_fn)
     loss = tf.add_n(loss_num) / tf.maximum(1.0, tf.add_n(loss_den))
     return sharded_logits, loss
+
+  def top_sharded_logits(self,
+                         sharded_body_output,
+                         sharded_targets,
+                         data_parallelism):
+    """Transform all shards of targets into logits, without computing a loss.
+
+    Classes with cross-shard interaction will override this function.
+
+    Args:
+      sharded_body_output: A list of Tensors.
+      sharded_targets: A list of Tensors.
+      data_parallelism: an expert_utils.Parallelism object.
+
+    Returns:
+      sharded_logits: A list of Tensors.
+    """
+    sharded_logits = data_parallelism(self.top, sharded_body_output,
+                                      sharded_targets)
+    return sharded_logits
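top_sharded_logits mirrors the existing top_sharded but stops at the logits, skipping the loss computation that an inference-only decode does not need. A toy illustration of the same data_parallelism call pattern, assuming expert_utils.Parallelism takes a list of device strings as in tensor2tensor; the toy_top function, device, and shapes are stand-ins, not the real modality:

```python
import tensorflow as tf
from tensor2tensor.utils import expert_utils as eu


def toy_top(body_output, _targets):
  # Stand-in for Modality.top: project body output to a 10-way vocabulary.
  return tf.layers.dense(body_output, 10)


dp = eu.Parallelism(["/cpu:0"])                        # a single shard
sharded_body_output = [tf.zeros([2, 5, 8])]            # [batch, length, hidden]
sharded_targets = [tf.zeros([2, 5], dtype=tf.int32)]
sharded_logits = dp(toy_top, sharded_body_output, sharded_targets)
# sharded_logits is a list holding one [2, 5, 10] logits Tensor.
```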
@@ -16,6 +16,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import numpy as np
 import copy
 import time
@@ -25,11 +26,11 @@ import time
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin

-from tensor2tensor.utils import beam_search
+from tensor2tensor.utils import beam_search_slow
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import modality
 from tensor2tensor.utils import registry
-from tensorflow.python.layers import base

 import tensorflow as tf
@@ -233,7 +234,7 @@ class T2TModel(object):
       vocab_size = target_modality.top_dimensionality
       # Setting decode length to input length + decode_length
       decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
-      ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
+      ids, scores = beam_search_slow.beam_search(symbols_to_logits_fn, initial_ids,
                                             beam_size, decode_length, vocab_size,
                                             alpha)
@@ -490,6 +491,8 @@ class T2TModel(object):
     """
     raise NotImplementedError("Abstract Method")

   @property
   def hparams(self):
     return self._hparams
......