Commit 94495c05 by xiaotong

bug fixes

parent 16ab02c5
@@ -110,6 +110,8 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
     state->probPath.SetZeroAll();
     state->nstep.SetZeroAll();
     state->endMark.SetZeroAll();
+    state->stateNum = 0;
 }
 /*
...
@@ -31,6 +31,11 @@ namespace transformer
 /* constructor */
 T2TSearch::T2TSearch()
 {
+    alpha = 0;
+    maxLength = 0;
+    beamSize = 0;
+    batchSize = 0;
+    endSymbolNum = 0;
     fullHypos = NULL;
     endSymbols = new int[32];
 }
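For reference, a minimal sketch of the same defaults expressed as in-class initializers. The member names are taken from this hunk; the class form and comments are illustrative only, not the project's actual header.

/* sketch only: default-initialize every scalar member so that no code path
   can read an indeterminate value before Init() runs */
class T2TSearchSketch {
public:
    float alpha = 0.0F;          /* length-penalty weight ("lenalpha")            */
    int maxLength = 0;           /* maximum output length                         */
    int beamSize = 0;            /* beam width ("beamsize")                       */
    int batchSize = 0;           /* sentences per batch ("batchsize")             */
    int endSymbolNum = 0;        /* number of end symbols                         */
    int * endSymbols = nullptr;  /* end-symbol ids, allocated in the constructor  */
    void * fullHypos = nullptr;  /* heap of finished hypotheses (type simplified) */
};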
@@ -52,6 +57,7 @@ initialize the model
 void T2TSearch::Init(int argc, char ** argv)
 {
     LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
+    LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
     LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
     LoadParamInt(argc, argv, "endid", endSymbols, -1);
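Assuming LoadParamInt/LoadParamFloat match "-name value" pairs on the command line (an assumption about the option prefix, not shown in this diff), the new batchsize option would reach the searcher like the existing ones:

/* hypothetical argument list; only the options are shown and the '-' prefix
   is assumed, not taken from this commit */
char * args[] = { (char *)"-beamsize", (char *)"4",
                  (char *)"-batchsize", (char *)"32",
                  (char *)"-lenalpha", (char *)"0.2" };
T2TSearch searcher;
searcher.Init(6, args);   /* beamSize = 4, batchSize = 32, alpha = 0.2F */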
@@ -72,6 +78,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     XTensor encoding;
     T2TPredictor predictor;
+    CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!");
+    Prepare(input->unitNum/input->GetDim(-1), beamSize);
     /* encoder mask */
     model->MakeMTMaskEnc(*input, *padding, maskEnc);
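The first argument to Prepare derives the number of source sentences from the flat element count. A short sketch, assuming the input tensor is 2-D with shape (batch, srcLen); the concrete numbers are illustrative only:

/* sketch: batch size recovered from a flat element count */
int srcLen = input->GetDim(-1);        /* e.g. 20 tokens per sentence  */
int batch  = input->unitNum / srcLen;  /* e.g. 640 / 20 = 32 sentences */
Prepare(batch, beamSize);              /* reserve beams for the batch  */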
@@ -83,7 +93,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     maxLength = input->GetDim(-2) * 2;
     CheckNTErrors(maxLength > 0, "no max length specified!");
-    T2TStateBundle * states = new T2TStateBundle[maxLength];
+    T2TStateBundle * states = new T2TStateBundle[maxLength + 1];
     T2TStateBundle * first = states;
     /* create the first state */
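The "+ 1" fixes an out-of-bounds access: the decoding loop reads one bundle as the current step and writes the following one, so the last iteration touches one slot past maxLength - 1. A sketch of the indexing that motivates the larger allocation (the loop structure is inferred from the surrounding code, not shown in this hunk):

/* sketch: why maxLength + 1 bundles are needed */
for (int i = 0; i < maxLength; i++) {
    T2TStateBundle * cur  = states + i;      /* bundle being expanded  */
    T2TStateBundle * next = states + i + 1;  /* bundle being generated */
    /* ... predict, Expand(cur, next), Collect(next) ... */
}
/* the final iteration writes states[maxLength], hence maxLength + 1 entries */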
@@ -110,6 +120,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
         /* expand the search graph */
         Expand(cur, next);
+        /* push complete hypotheses into the heap */
+        Collect(next);
     }
     delete[] states;
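A sketch of the kind of work the new Collect(next) call is expected to do, based only on the "push complete hypotheses into the heap" comment above; the member names (states, stateNum, isEnd) and the PushToHeap helper are assumptions, not taken from this commit:

/* sketch only: gather finished hypotheses from a bundle */
void CollectSketch(T2TStateBundle * beam)
{
    for (int i = 0; i < beam->stateNum; i++) {
        T2TState * state = beam->states + i;
        if (state->isEnd)                 /* hypothesis emitted an end symbol    */
            PushToHeap(fullHypos, state); /* hypothetical helper: keep it for    */
    }                                     /* final ranking after the search loop */
}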
@@ -241,8 +254,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
     TopK(score, scoreTopK, index, -1, beamSize);
     CopyValues(index, preID);
-    preID.Dump(stderr, "preid:");
     int sizeVocab = score.GetDim(-1);
@@ -258,8 +269,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
        in the vocabulary by dividing it with vocab-size and computing the remainder. */
     Mod(index, sizeVocab);
-    preID.Dump(stderr, "preid:");
     score.Reshape(order, dims);
     /* we keep the top-k scores */
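As the comment above explains, a flat top-k index over the concatenated (beam x vocab) scores encodes both the source hypothesis and the word: division by the vocabulary size gives the beam, the remainder gives the word id. A worked example with an assumed vocabulary size of 30000:

/* illustrative numbers only */
int beamId = 61234 / 30000;   /* = 2    -> the hypothesis being extended */
int wordId = 61234 % 30000;   /* = 1234 -> the predicted word id         */
/* Mod(index, sizeVocab) applies the remainder step to the whole tensor  */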
@@ -315,8 +324,6 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
     CopyValues(probRef, prob);
     CopyValues(probPathRef, probPath);
     CopyValues(predictionRef, prediction);
-    idRef.Dump(stderr, "idref:");
     CheckNTErrors(beam->stateNum == id.unitNum, "Errors occur in counting!");
...