Commit 9a4446f1 by xiaotong

stoping critiera

parent a52ba88e
......@@ -141,6 +141,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* push complete hypotheses into the heap */
Collect(next);
/* stop searching when all hypotheses are completed */
if(IsAllCompleted(next))
break;
}
/* fill the heap with imcomplete hypotheses if neccesary */
......@@ -262,15 +266,15 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &prob = beam->prob;
int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i);
dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i);
}
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
......@@ -574,6 +578,23 @@ void T2TSearch::SetEnd(const int * tokens, const int tokenNum)
endSymbolNum = tokenNum;
}
/*
check whether all hypotheses are completed
>> beam - the beam that keeps the searching states
*/
bool T2TSearch::IsAllCompleted(T2TStateBundle * beam)
{
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
if(!state.isCompleted)
return false;
}
return true;
}
/*
make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states
......
......@@ -102,6 +102,9 @@ public:
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
/* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle * beam);
/* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam);
};
......
......@@ -129,7 +129,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
"[INFO] elapsed=%.1fs, sent=%d, sword=%d\n",
elapsed, sentCount, wordCount);
}
}
......@@ -141,8 +141,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss/wordCount));
XPRINT4(0, stderr, "[INFO] test finished (took %.1fs, word=%d, sent=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, sentCount, exp(loss/wordCount));
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论