Commit 5e169df8 by xiaotong

improve the design of search state

parent b9871b8d
...@@ -28,18 +28,41 @@ ...@@ -28,18 +28,41 @@
namespace transformer namespace transformer
{ {
/* state in decoder - it keeps all previously-generated words and their /* state for search. It keeps the path (back-pointer), prediction distribution,
hidden states */ and etc. */
class T2TState class T2TState
{ {
/* we assume that the prediction is an integer number */
int prediction;
/* probability of the prediction */
float prob;
/* probability of the path */
float pathProb;
/* pointer to the last state */
T2TState * last;
/* pointers to the following states */
XList * followings;
};
/* a bundle of states */
class T2TStateBundle
{
/* distribution of every prediction (last state of the path) */
XTensor probs;
/* distribution of every path */
XTensor pathProbs;
/* layers on the encoder side. We actually use the encoder output instead /* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */ of all hidden layers. */
XList * encoderLayers; XList encoderLayers;
/* layers on the decoder side */ /* layers on the decoder side */
XList * decoderLayers; XList decoderLayers;
/* */
}; };
/* The predictor reads the current state and then predicts the next. /* The predictor reads the current state and then predicts the next.
...@@ -60,10 +83,10 @@ public: ...@@ -60,10 +83,10 @@ public:
~T2TPredictor(); ~T2TPredictor();
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TState * current); void Read(T2TModel * model, T2TStateBundle * current);
/* predict the next state */ /* predict the next state */
void Predict(T2TState * next); void Predict(T2TStateBundle * next);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论