Commit 5e169df8 by xiaotong

improve the design of search state

parent b9871b8d
......@@ -28,18 +28,41 @@
namespace transformer
{
/* state in decoder - it keeps all previously-generated words and their
hidden states */
/* state for search. It keeps the path (back-pointer), prediction distribution,
and etc. */
class T2TState
{
/* we assume that the prediction is an integer number */
int prediction;
/* probability of the prediction */
float prob;
/* probability of the path */
float pathProb;
/* pointer to the last state */
T2TState * last;
/* pointers to the following states */
XList * followings;
};
/* a bundle of states */
class T2TStateBundle
{
/* distribution of every prediction (last state of the path) */
XTensor probs;
/* distribution of every path */
XTensor pathProbs;
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
XList * encoderLayers;
XList encoderLayers;
/* layers on the decoder side */
XList * decoderLayers;
/* */
XList decoderLayers;
};
/* The predictor reads the current state and then predicts the next.
......@@ -60,10 +83,10 @@ public:
~T2TPredictor();
/* read a state */
void Read(T2TModel * model, T2TState * current);
void Read(T2TModel * model, T2TStateBundle * current);
/* predict the next state */
void Predict(T2TState * next);
void Predict(T2TStateBundle * next);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论