Commit cc78cf5b by libei

fix tensorflow fast decode bugs

parent 8a8dc876
......@@ -34,6 +34,7 @@ from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import beam_search
from tensorflow.python.util import nest
import tensorflow as tf
......@@ -409,7 +410,7 @@ def fast_decode(encoder_output,
tf.TensorShape([None]),
tf.TensorShape([None, None]),
tf.TensorShape([None, None]),
nest.map_structure(beam_search_slow.get_state_shape_invariants, cache),
nest.map_structure(beam_search.get_state_shape_invariants, cache),
tf.TensorShape([None]),
])
scores = log_prob
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论