Commit 70f3926a by xiaotong

update the code of heap

parent d489680b
...@@ -289,6 +289,7 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -289,6 +289,7 @@ void T2TSearch::Collect(T2TStateBundle * beam)
"Invalid sample id!"); "Invalid sample id!");
if(IsEnd(state.prediction)){ if(IsEnd(state.prediction)){
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
state.isEnd = 1; state.isEnd = 1;
} }
else else
......
...@@ -37,20 +37,20 @@ get the top-k items along a given dimension ...@@ -37,20 +37,20 @@ get the top-k items along a given dimension
*/ */
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
{ {
CheckNTErrors((a->unitSize == b->unitSize), "Unmatched input tensors!"); CheckNTErrors(a->unitSize == b->unitSize, "Unmatched input tensors!");
CheckNTErrors((a->order == b->order), "Unmatched input tensors!"); CheckNTErrors(a->order == b->order, "Unmatched input tensors!");
CheckNTErrors((index == NULL || a->order == index->order), "Unmatched input tensors!"); CheckNTErrors(index == NULL || a->order == index->order, "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!"); CheckNTErrors(index->dataType == X_INT, "Wrong data type!");
int dimRDI = a->order - dim - 1; int dimRDI = a->order - dim - 1;
for (int i = 0; i < a->order; i++) { for (int i = 0; i < a->order; i++) {
if (i == dimRDI) { if (i == dimRDI) {
CheckNTErrors((b->dimSizeRDI[i] == k), "A too large K"); CheckNTErrors(b->dimSizeRDI[i] == k, "A too large K");
CheckNTErrors((index == NULL || index->dimSizeRDI[i] == k), "Wrong size!"); CheckNTErrors(index == NULL || index->dimSizeRDI[i] == k, "Wrong size!");
} }
else { else {
CheckNTErrors((b->dimSizeRDI[i] == a->dimSizeRDI[i]), "Wrong size!"); CheckNTErrors(b->dimSizeRDI[i] == a->dimSizeRDI[i], "Wrong size!");
CheckNTErrors((index == NULL || index->dimSizeRDI[i] == a->dimSizeRDI[i]), "Wrong size!"); CheckNTErrors(index == NULL || index->dimSizeRDI[i] == a->dimSizeRDI[i], "Wrong size!");
} }
} }
...@@ -100,7 +100,7 @@ void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -100,7 +100,7 @@ void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
for (int j = strideNumA >= k ? k - 1 : strideNumA - 1; j >= 0; j--) { for (int j = strideNumA >= k ? k - 1 : strideNumA - 1; j >= 0; j--) {
HeapNode<DTYPE> node = heap.Pop(); HeapNode<DTYPE> node = heap.Pop();
dataB[j * stride] = node.value; dataB[j * stride] = node.value;
indexData[j * stride] = node.index; indexData[j * stride] = (int)node.index;
} }
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论