Commit 70d480ed by xiaotong

fix the bug in heap

parent 47c31021
...@@ -255,6 +255,10 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -255,6 +255,10 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* keep the most promissing candidates in the beam */ /* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, -1, beamSize); TopK(score, scoreTopK, index, -1, beamSize);
score.Dump(stderr, "score:");
scoreTopK.Dump(stderr, "topk:");
index.Dump(stderr, "index:");
CopyValues(index, preID); CopyValues(index, preID);
...@@ -345,7 +349,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -345,7 +349,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
/* pointer to the previous state */ /* pointer to the previous state */
if (prev->isStart) { if (prev->isStart) {
state.last = NULL; state.last = NULL;
state.pid = offset; state.pid = i;
} }
else{ else{
state.last = last; state.last = last;
...@@ -396,7 +400,7 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -396,7 +400,7 @@ void T2TSearch::Collect(T2TStateBundle * beam)
} }
/* /*
fill the hypotheis heap with incomplete hypothses fill the hypotheis heap with incomplete hypotheses
>> beam - the beam that keeps a number of states (final) >> beam - the beam that keeps a number of states (final)
*/ */
void T2TSearch::FillHeap(T2TStateBundle * beam) void T2TSearch::FillHeap(T2TStateBundle * beam)
...@@ -411,7 +415,7 @@ void T2TSearch::FillHeap(T2TStateBundle * beam) ...@@ -411,7 +415,7 @@ void T2TSearch::FillHeap(T2TStateBundle * beam)
T2TState & state = states[i]; T2TState & state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize, CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!"); "Invalid sample id!");
/* we push the imcomplete hypothesis into the heap */ /* we push the imcomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0) if (emptyFlags[state.pid] && state.isEnd == 0)
...@@ -452,7 +456,7 @@ void T2TSearch::Dump(XTensor * output) ...@@ -452,7 +456,7 @@ void T2TSearch::Dump(XTensor * output)
/* dump the sentence to the output tensor */ /* dump the sentence to the output tensor */
for(int w = 0; w < count; w++) for(int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, i, w); output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w);
} }
} }
......
...@@ -102,10 +102,24 @@ _XINLINE_ HeapNode<T> XHeap<hType, T>::End() ...@@ -102,10 +102,24 @@ _XINLINE_ HeapNode<T> XHeap<hType, T>::End()
template<HeapType hType, typename T> template<HeapType hType, typename T>
_XINLINE_ void XHeap<hType, T>::Push(HeapNode<T> node) _XINLINE_ void XHeap<hType, T>::Push(HeapNode<T> node)
{ {
//CheckNTErrors((count < size), "Heap is full!"); if (count < size) {
items[count] = node; items[count] = node;
Up(count); Up(count);
count++; count++;
}
else if(count == size){
HeapNode<T> & item0 = items[0];
if (hType == MIN_HEAP && item0.value >= node.value)
return;
else if (hType == MAX_HEAP && item0.value <= node.value)
return;
items[0] = node;
Down(0);
}
else {
ShowNTErrors("Overflow of the heap!");
}
} }
/* replace the top-most item and update the heap */ /* replace the top-most item and update the heap */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论