Commit 47c31021 by xiaotong

fill the heap with incomplete hypotheses

parent f34f70ed
...@@ -125,9 +125,12 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -125,9 +125,12 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
Collect(next); Collect(next);
} }
delete[] states; /* fill the heap with imcomplete hypotheses if neccesary */
FillHeap(&states[maxLength]);
Dump(output); Dump(output);
delete[] states;
} }
/* /*
...@@ -334,13 +337,20 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -334,13 +337,20 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
for(int i = 0; i < beam->stateNum; i++){ for(int i = 0; i < beam->stateNum; i++){
T2TState & state = states[i]; T2TState & state = states[i];
int offset = id.GetInt(i);
T2TState * last = prev->states + offset;
CheckNTErrors(offset >= 0, "Wrong state index!");
/* pointer to the previous state */ /* pointer to the previous state */
if(prev->isStart) if (prev->isStart) {
state.last = NULL; state.last = NULL;
state.pid = offset;
}
else{ else{
int offset = id.GetInt(i); state.last = last;
state.last = prev->states + offset; state.pid = state.last->pid;
CheckNTErrors(offset >= 0 && offset < prev->stateNum, "Wrong state index!"); CheckNTErrors(offset < prev->stateNum, "Wrong state index!");
} }
/* scores */ /* scores */
...@@ -376,7 +386,6 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -376,7 +386,6 @@ void T2TSearch::Collect(T2TStateBundle * beam)
for (int i = 0; i < beam->stateNum; i++) { for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i]; T2TState & state = states[i];
state.pid = state.last->pid;
CheckNTErrors(state.pid >= 0 && state.pid < batchSize, CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!"); "Invalid sample id!");
...@@ -387,6 +396,32 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -387,6 +396,32 @@ void T2TSearch::Collect(T2TStateBundle * beam)
} }
/* /*
fill the hypotheis heap with incomplete hypothses
>> beam - the beam that keeps a number of states (final)
*/
void T2TSearch::FillHeap(T2TStateBundle * beam)
{
bool * emptyFlags = new bool[batchSize];
for (int i = 0; i < batchSize; i++)
emptyFlags[i] = (fullHypos[i].Count() == 0);
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
/* we push the imcomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
}
delete[] emptyFlags;
}
/*
save the output sequences in a tensor save the output sequences in a tensor
>> output - output sequences (for return) >> output - output sequences (for return)
*/ */
...@@ -404,7 +439,7 @@ void T2TSearch::Dump(XTensor * output) ...@@ -404,7 +439,7 @@ void T2TSearch::Dump(XTensor * output)
XHeap<MIN_HEAP, float> &heap = fullHypos[h]; XHeap<MIN_HEAP, float> &heap = fullHypos[h];
/* for each output in the beam */ /* for each output in the beam */
for(int i = 0; i < beamSize; i++){ for(int i = 0; i < beamSize && heap.Count() > 0; i++){
T2TState * state = (T2TState *)heap.Pop().index; T2TState * state = (T2TState *)heap.Pop().index;
int count = 0; int count = 0;
......
...@@ -87,6 +87,9 @@ public: ...@@ -87,6 +87,9 @@ public:
/* collect hypotheses with ending symbol */ /* collect hypotheses with ending symbol */
void Collect(T2TStateBundle * beam); void Collect(T2TStateBundle * beam);
/* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle * beam);
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void Dump(XTensor * output); void Dump(XTensor * output);
......
...@@ -65,6 +65,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -65,6 +65,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
int wordCount = 0; int wordCount = 0;
int wordCountTotal = 0; int wordCountTotal = 0;
int sentCount = 0; int sentCount = 0;
int batchCount = 0;
float loss = 0; float loss = 0;
/* data files */ /* data files */
...@@ -118,9 +119,18 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -118,9 +119,18 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
float prob = 0; float prob = 0;
loss += -prob; loss += -prob;
wc = batchEnc.GetDim(-1);
wordCount += wc; wordCount += wc;
wordCountTotal += wc; wordCountTotal += wc;
sentCount += 1; sentCount += batchEnc.GetDim(-2);
batchCount += 1;
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
}
} }
fclose(file); fclose(file);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论