Commit 9a4446f1 by xiaotong

stoping critiera

parent a52ba88e
...@@ -141,6 +141,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -141,6 +141,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* push complete hypotheses into the heap */ /* push complete hypotheses into the heap */
Collect(next); Collect(next);
/* stop searching when all hypotheses are completed */
if(IsAllCompleted(next))
break;
} }
/* fill the heap with imcomplete hypotheses if neccesary */ /* fill the heap with imcomplete hypotheses if neccesary */
...@@ -262,15 +266,15 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -262,15 +266,15 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &prob = beam->prob; XTensor &prob = beam->prob;
int order = score.order; int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i); dims[i] = score.GetDim(i);
dimsBeam[i] = score.GetDim(i); dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i); dimsTopK[i] = score.GetDim(i);
} }
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
int sizeVocab = score.GetDim(-1); int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1); int stride = score.GetDim(-1);
...@@ -574,6 +578,23 @@ void T2TSearch::SetEnd(const int * tokens, const int tokenNum) ...@@ -574,6 +578,23 @@ void T2TSearch::SetEnd(const int * tokens, const int tokenNum)
endSymbolNum = tokenNum; endSymbolNum = tokenNum;
} }
/*
check whether all hypotheses are completed
>> beam - the beam that keeps the searching states
*/
bool T2TSearch::IsAllCompleted(T2TStateBundle * beam)
{
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
if(!state.isCompleted)
return false;
}
return true;
}
/* /*
make a mask to prevent duplicated entries in beam expansion for the first position make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states >> beam - the beam that keeps the searching states
......
...@@ -102,6 +102,9 @@ public: ...@@ -102,6 +102,9 @@ public:
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum); void SetEnd(const int * tokens, const int tokenNum);
/* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle * beam);
/* make a mask to prevent duplicated entries in beam expansion for the first position */ /* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam); XTensor MakeFirstMask(T2TStateBundle * beam);
}; };
......
...@@ -129,7 +129,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -129,7 +129,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", "[INFO] elapsed=%.1fs, sent=%d, sword=%d\n",
elapsed, sentCount, wordCount); elapsed, sentCount, wordCount);
} }
} }
...@@ -141,8 +141,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -141,8 +141,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n", XPRINT4(0, stderr, "[INFO] test finished (took %.1fs, word=%d, sent=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss/wordCount)); elapsed,wordCountTotal, sentCount, exp(loss/wordCount));
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论