Commit e184cac4 by xiaotong

fix the bug of incorrect indexing in scoring

parent 63f0f6a7
...@@ -254,8 +254,8 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -254,8 +254,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &probPath = beam->probPath; XTensor &probPath = beam->probPath;
int order = score.order; int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger."); CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 2] % beamSize == 0, "Wrong dimension size!"); CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i); dims[i] = score.GetDim(i);
...@@ -263,9 +263,12 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -263,9 +263,12 @@ void T2TSearch::Generate(T2TStateBundle * beam)
dimsTopK[i] = score.GetDim(i); dimsTopK[i] = score.GetDim(i);
} }
dimsBeam[order - 2] /= beamSize; int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
dimsBeam[order - 3] /= beamSize;
dimsBeam[order - 1] *= beamSize; dimsBeam[order - 1] *= beamSize;
dimsTopK[order - 2] = dimsBeam[order - 2]; dimsTopK[order - 3] = dimsBeam[order - 3];
dimsTopK[order - 1] = beamSize; dimsTopK[order - 1] = beamSize;
InitTensor(&scoreTopK, order, dimsTopK, score.dataType, InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
...@@ -286,9 +289,6 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -286,9 +289,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(index, preID); CopyValues(index, preID);
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
/* "preID" represents the id (or the offset) of previous state used to make the current /* "preID" represents the id (or the offset) of previous state used to make the current
hypothesis. Note that we reshape the "score" tensor into a matrix where each hypothesis. Note that we reshape the "score" tensor into a matrix where each
row means a previous state. The column number is size-of-beam * vocab-size. We, row means a previous state. The column number is size-of-beam * vocab-size. We,
...@@ -327,6 +327,8 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -327,6 +327,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
order = probPath.order; order = probPath.order;
probPath.Reshape(1, probPath.unitNum); probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum); probPathTopK.Reshape(1, probPathTopK.unitNum);
indexCPU.Dump(stderr, "indexCPU:");
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum); _Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
...@@ -391,14 +393,15 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -391,14 +393,15 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
T2TState & state = states[i + j]; T2TState & state = states[i + j];
int offset = id.GetInt(i + j); int offset = id.GetInt(i + j);
T2TState * last = prev->states + i * beamSize + offset; int pid = i / beamSize;
T2TState * last = prev->states + pid * beamSize + offset;
CheckNTErrors(offset >= 0, "Wrong state index!"); CheckNTErrors(offset >= 0, "Wrong state index!");
/* pointer to the previous state */ /* pointer to the previous state */
if (prev->isStart) { if (prev->isStart) {
state.last = NULL; state.last = NULL;
state.pid = i; state.pid = pid;
state.nstep = 0; state.nstep = 0;
} }
else{ else{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论