Commit cc78cf5b by libei

fix tensorflow fast decode bugs

parent 8a8dc876
...@@ -34,6 +34,7 @@ from tensor2tensor.models import common_layers ...@@ -34,6 +34,7 @@ from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import beam_search from tensor2tensor.utils import beam_search
from tensorflow.python.util import nest
import tensorflow as tf import tensorflow as tf
...@@ -409,7 +410,7 @@ def fast_decode(encoder_output, ...@@ -409,7 +410,7 @@ def fast_decode(encoder_output,
tf.TensorShape([None]), tf.TensorShape([None]),
tf.TensorShape([None, None]), tf.TensorShape([None, None]),
tf.TensorShape([None, None]), tf.TensorShape([None, None]),
nest.map_structure(beam_search_slow.get_state_shape_invariants, cache), nest.map_structure(beam_search.get_state_shape_invariants, cache),
tf.TensorShape([None]), tf.TensorShape([None]),
]) ])
scores = log_prob scores = log_prob
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论