Commit 47c31021 by xiaotong

fill the heap with incomplete hypotheses

parent f34f70ed
......@@ -124,10 +124,13 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* push complete hypotheses into the heap */
Collect(next);
}
delete[] states;
/* fill the heap with imcomplete hypotheses if neccesary */
FillHeap(&states[maxLength]);
Dump(output);
delete[] states;
}
/*
......@@ -333,14 +336,21 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
it needs much more coding work and the speed-up is not obvious. */
for(int i = 0; i < beam->stateNum; i++){
T2TState & state = states[i];
int offset = id.GetInt(i);
T2TState * last = prev->states + offset;
CheckNTErrors(offset >= 0, "Wrong state index!");
/* pointer to the previous state */
if(prev->isStart)
state.last = NULL;
if (prev->isStart) {
state.last = NULL;
state.pid = offset;
}
else{
int offset = id.GetInt(i);
state.last = prev->states + offset;
CheckNTErrors(offset >= 0 && offset < prev->stateNum, "Wrong state index!");
state.last = last;
state.pid = state.last->pid;
CheckNTErrors(offset < prev->stateNum, "Wrong state index!");
}
/* scores */
......@@ -376,7 +386,6 @@ void T2TSearch::Collect(T2TStateBundle * beam)
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
state.pid = state.last->pid;
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
......@@ -387,6 +396,32 @@ void T2TSearch::Collect(T2TStateBundle * beam)
}
/*
fill the hypotheis heap with incomplete hypothses
>> beam - the beam that keeps a number of states (final)
*/
void T2TSearch::FillHeap(T2TStateBundle * beam)
{
bool * emptyFlags = new bool[batchSize];
for (int i = 0; i < batchSize; i++)
emptyFlags[i] = (fullHypos[i].Count() == 0);
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
/* we push the imcomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
}
delete[] emptyFlags;
}
/*
save the output sequences in a tensor
>> output - output sequences (for return)
*/
......@@ -404,7 +439,7 @@ void T2TSearch::Dump(XTensor * output)
XHeap<MIN_HEAP, float> &heap = fullHypos[h];
/* for each output in the beam */
for(int i = 0; i < beamSize; i++){
for(int i = 0; i < beamSize && heap.Count() > 0; i++){
T2TState * state = (T2TState *)heap.Pop().index;
int count = 0;
......
......@@ -87,6 +87,9 @@ public:
/* collect hypotheses with ending symbol */
void Collect(T2TStateBundle * beam);
/* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void Dump(XTensor * output);
......
......@@ -65,6 +65,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
int batchCount = 0;
float loss = 0;
/* data files */
......@@ -118,9 +119,18 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
float prob = 0;
loss += -prob;
wc = batchEnc.GetDim(-1);
wordCount += wc;
wordCountTotal += wc;
sentCount += 1;
sentCount += batchEnc.GetDim(-2);
batchCount += 1;
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
}
}
fclose(file);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论