Commit 96bdb988 by xiaotong

generate index of previous states in inference

parent f5958ffa
......@@ -52,6 +52,9 @@ class T2TStateBundle
public:
/* predictions */
XTensor prediction;
/* id of the previous state that generates the current one */
XTensor preID;
/* score of every prediction (last state of the path) */
XTensor score;
......
......@@ -95,6 +95,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor scoreTopK;
XTensor &score = beam->score;
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
......@@ -115,11 +116,21 @@ void T2TSearch::Generate(T2TStateBundle * beam)
1.0F, score.devID, score.mem);
InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem);
InitTensor(&preID, order, dimsTopK, X_INT,
1.0F, -1);
score.Reshape(order, dimsBeam);
/* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, 0, beamSize);
CopyValues(scoreTopK, preID);
int sizePredict = score.GetDim(-1);
/* pre id !!! */
/* mod !!! */
score.Reshape(order, dims);
}
......@@ -134,6 +145,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
T2TState * states = beam->states;
XTensor &predict = beam->prediction;
XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F,
predict.devID, predict.mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论