Commit ac6ed3a1 by huchi

add support for beam search (tested with beam size 2)

parent 99097e41
...@@ -101,25 +101,10 @@ void T2TPredictor::Create(T2TModel* model, XTensor* top, const XTensor* input, i ...@@ -101,25 +101,10 @@ void T2TPredictor::Create(T2TModel* model, XTensor* top, const XTensor* input, i
InitTensorV2(&state->nstep, input->order, dims, X_FLOAT, 1.0F, input->devID); InitTensorV2(&state->nstep, input->order, dims, X_FLOAT, 1.0F, input->devID);
InitTensorV2(&state->endMark, input->order, dims, X_INT, 1.0F, input->devID); InitTensorV2(&state->endMark, input->order, dims, X_INT, 1.0F, input->devID);
/*float* data = new float[state->probPath.unitNum]; state->probPath.SetZeroAll();
for (int i = 0; i < state->probPath.unitNum; ++i) {
data[i] = -1e20F;
if (i % beamSize == 0)
data[i] = 0;
}
state->probPath.SetData(data, state->probPath.unitNum);
delete[] data;*/
SetDataFixed(state->probPath, -1e9F);
for (int i = 0; i < state->probPath.unitNum; ++i) {
if (i % beamSize == 0)
state->probPath.Set(0.0F, i);
}
state->nstep.SetZeroAll(); state->nstep.SetZeroAll();
state->endMark.SetZeroAll(); state->endMark.SetZeroAll();
state->stateNum = 0; state->stateNum = 0;
} }
......
...@@ -285,11 +285,11 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -285,11 +285,11 @@ void T2TSearch::Generate(T2TStateBundle* beam)
XTensor mask; XTensor mask;
InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1); InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1);
mask.SetZeroAll(); mask.SetZeroAll();
mask.Set1D(-1e20F, 0); mask.Set1D(-1e9F, 0);
mask.Set1D(-1e20F, 1); mask.Set1D(-1e9F, 1);
mask.SetDevice(score.devID, score.mem); mask.SetDevice(score.devID, score.mem);
//_SumDim(&score, &mask, 2); _SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam); score.Reshape(order, dimsBeam);
/* keep the most promising candidates in the beam */ /* keep the most promising candidates in the beam */
...@@ -324,9 +324,6 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -324,9 +324,6 @@ void T2TSearch::Generate(T2TStateBundle* beam)
for (int j = 0; j < beamSize; j++) for (int j = 0; j < beamSize; j++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j); indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
} }
/*for (int i = 0; i < indexGPU.unitNum; i++) {
indexGPU.SetInt(i + indexGPU.GetInt(i), i);
}*/
CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!"); CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论