Commit 94495c05 by xiaotong

bug fixes

parent 16ab02c5
......@@ -110,6 +110,8 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
state->probPath.SetZeroAll();
state->nstep.SetZeroAll();
state->endMark.SetZeroAll();
state->stateNum = 0;
}
/*
......
......@@ -31,6 +31,11 @@ namespace transformer
/* constructor */
T2TSearch::T2TSearch()
{
alpha = 0;
maxLength = 0;
beamSize = 0;
batchSize = 0;
endSymbolNum = 0;
fullHypos = NULL;
endSymbols = new int[32];
}
......@@ -52,6 +57,7 @@ initialize the model
void T2TSearch::Init(int argc, char ** argv)
{
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
......@@ -72,6 +78,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
XTensor encoding;
T2TPredictor predictor;
CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!");
Prepare(input->unitNum/input->GetDim(-1), beamSize);
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
......@@ -83,7 +93,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
maxLength = input->GetDim(-2) * 2;
CheckNTErrors(maxLength > 0, "no max length specified!");
T2TStateBundle * states = new T2TStateBundle[maxLength];
T2TStateBundle * states = new T2TStateBundle[maxLength + 1];
T2TStateBundle * first = states;
/* create the first state */
......@@ -110,6 +120,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* expand the search graph */
Expand(cur, next);
/* push complete hypotheses into the heap */
Collect(next);
}
delete[] states;
......@@ -241,8 +254,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, preID);
preID.Dump(stderr, "preid:");
int sizeVocab = score.GetDim(-1);
......@@ -258,8 +269,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
in the vocabulary by dividing it with vocab-size and computing the remainder. */
Mod(index, sizeVocab);
preID.Dump(stderr, "preid:");
score.Reshape(order, dims);
/* we keep the top-k scores */
......@@ -315,8 +324,6 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
CopyValues(probRef, prob);
CopyValues(probPathRef, probPath);
CopyValues(predictionRef, prediction);
idRef.Dump(stderr, "idref:");
CheckNTErrors(beam->stateNum == id.unitNum, "Errors occur in counting!");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论