Commit 96bdb988 by xiaotong

generate index of previous states in inference

parent f5958ffa
...@@ -52,6 +52,9 @@ class T2TStateBundle ...@@ -52,6 +52,9 @@ class T2TStateBundle
public: public:
/* predictions */ /* predictions */
XTensor prediction; XTensor prediction;
/* id of the previous state that generates the current one */
XTensor preID;
/* score of every prediction (last state of the path) */ /* score of every prediction (last state of the path) */
XTensor score; XTensor score;
......
...@@ -95,6 +95,7 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -95,6 +95,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor scoreTopK; XTensor scoreTopK;
XTensor &score = beam->score; XTensor &score = beam->score;
XTensor &index = beam->prediction; XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
int order = score.order; int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger."); CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
...@@ -115,11 +116,21 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -115,11 +116,21 @@ void T2TSearch::Generate(T2TStateBundle * beam)
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
InitTensor(&index, order, dimsTopK, X_INT, InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
InitTensor(&preID, order, dimsTopK, X_INT,
1.0F, -1);
score.Reshape(order, dimsBeam); score.Reshape(order, dimsBeam);
/* keep the most promissing candidates in the beam */ /* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, 0, beamSize); TopK(score, scoreTopK, index, 0, beamSize);
CopyValues(scoreTopK, preID);
int sizePredict = score.GetDim(-1);
/* pre id !!! */
/* mod !!! */
score.Reshape(order, dims); score.Reshape(order, dims);
} }
...@@ -134,6 +145,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -134,6 +145,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
T2TState * states = beam->states; T2TState * states = beam->states;
XTensor &predict = beam->prediction; XTensor &predict = beam->prediction;
XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F, XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F,
predict.devID, predict.mem); predict.devID, predict.mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论