Commit a3ae84b5 by xiaotong

fix the bug of wrong indexing in beam search

parent dea97945
......@@ -267,11 +267,13 @@ void T2TSearch::Generate(T2TStateBundle * beam)
int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK;
XTensor indexCPU;
XTensor &score = beam->modelScore;
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order;
for (int i = 0; i < order; i++) {
......@@ -295,6 +297,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
1.0F, score.devID, score.mem);
InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem);
InitTensor(&indexCPU, order, dimsTopK, X_INT,
1.0F, -1);
InitTensor(&preID, order, dimsTopK, X_INT,
1.0F, -1);
......@@ -302,7 +306,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* keep the most promising candidates in the beam */
TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, indexCPU);
CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current
......@@ -323,14 +327,12 @@ void T2TSearch::Generate(T2TStateBundle * beam)
InitTensor(&score, &scoreTopK);
CopyValues(scoreTopK, score);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexCPU;
InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
CopyValues(index, indexCPU);
for (int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
for (int i = 0; i < indexCPU.unitNum; i += beamSize){
for (int j = 0; j < beamSize; j++) {
indexCPU.SetInt(i * stride + indexCPU.GetInt(i + j), i + j);
}
}
CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
......@@ -356,7 +358,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
prob.Reshape(order, dims);
probTopK.Reshape(order, dimsTopK);
......
......@@ -117,7 +117,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
Dump(ofile, &output);
//score.Dump(ofile, "score:");
float prob = 0;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论