Commit a3ae84b5 by xiaotong

fix the bug of wrong indexing in beam search

parent dea97945
...@@ -267,11 +267,13 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -267,11 +267,13 @@ void T2TSearch::Generate(T2TStateBundle * beam)
int dimsTopK[MAX_TENSOR_DIM_NUM]; int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK; XTensor scoreTopK;
XTensor indexCPU;
XTensor &score = beam->modelScore; XTensor &score = beam->modelScore;
XTensor &index = beam->prediction; XTensor &index = beam->prediction;
XTensor &preID = beam->preID; XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath; XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob; XTensor &prob = beam->prob;
int order = score.order; int order = score.order;
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
...@@ -295,6 +297,8 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -295,6 +297,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
InitTensor(&index, order, dimsTopK, X_INT, InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
InitTensor(&indexCPU, order, dimsTopK, X_INT,
1.0F, -1);
InitTensor(&preID, order, dimsTopK, X_INT, InitTensor(&preID, order, dimsTopK, X_INT,
1.0F, -1); 1.0F, -1);
...@@ -302,7 +306,7 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -302,7 +306,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* keep the most promissing candidates in the beam */ /* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, -1, beamSize); TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, indexCPU);
CopyValues(index, preID); CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current /* "preID" represents the id (or the offset) of the previous state used to make the current
...@@ -324,13 +328,11 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -324,13 +328,11 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(scoreTopK, score); CopyValues(scoreTopK, score);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */ /* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexCPU; for (int i = 0; i < indexCPU.unitNum; i += beamSize){
InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1); for (int j = 0; j < beamSize; j++) {
CopyValues(index, indexCPU); indexCPU.SetInt(i * stride + indexCPU.GetInt(i + j), i + j);
}
}
for (int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!"); CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
...@@ -356,7 +358,6 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -356,7 +358,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
probPath.Reshape(order, dims); probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK); probPathTopK.Reshape(order, dimsTopK);
prob.Reshape(order, dims); prob.Reshape(order, dims);
probTopK.Reshape(order, dimsTopK); probTopK.Reshape(order, dimsTopK);
......
...@@ -117,7 +117,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -117,7 +117,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score); seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
Dump(ofile, &output); Dump(ofile, &output);
//score.Dump(ofile, "score:");
float prob = 0; float prob = 0;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论