Commit e184cac4 by xiaotong

fix the bug of incorrect indexing in scoring

parent 63f0f6a7
......@@ -254,8 +254,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &probPath = beam->probPath;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 2] % beamSize == 0, "Wrong dimension size!");
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i);
......@@ -263,9 +263,12 @@ void T2TSearch::Generate(T2TStateBundle * beam)
dimsTopK[i] = score.GetDim(i);
}
dimsBeam[order - 2] /= beamSize;
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
dimsBeam[order - 3] /= beamSize;
dimsBeam[order - 1] *= beamSize;
dimsTopK[order - 2] = dimsBeam[order - 2];
dimsTopK[order - 3] = dimsBeam[order - 3];
dimsTopK[order - 1] = beamSize;
InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
......@@ -286,9 +289,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(index, preID);
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
/* "preID" represents the id (or the offset) of previous state used to make the current
hypothesis. Note that we reshape the "score" tensor into a matrix where each
row means a previous state. The column number is size-of-beam * vocab-size. We,
......@@ -327,6 +327,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
indexCPU.Dump(stderr, "indexCPU:");
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
......@@ -391,14 +393,15 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
T2TState & state = states[i + j];
int offset = id.GetInt(i + j);
T2TState * last = prev->states + i * beamSize + offset;
int pid = i / beamSize;
T2TState * last = prev->states + pid * beamSize + offset;
CheckNTErrors(offset >= 0, "Wrong state index!");
/* pointer to the previous state */
if (prev->isStart) {
state.last = NULL;
state.pid = i;
state.pid = pid;
state.nstep = 0;
}
else{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论