Commit 70f3926a by xiaotong

update the code of heap

parent d489680b
......@@ -289,6 +289,7 @@ void T2TSearch::Collect(T2TStateBundle * beam)
"Invalid sample id!");
if(IsEnd(state.prediction)){
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
state.isEnd = 1;
}
else
......
......@@ -37,20 +37,20 @@ get the top-k items along a given dimension
*/
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
{
CheckNTErrors((a->unitSize == b->unitSize), "Unmatched input tensors!");
CheckNTErrors((a->order == b->order), "Unmatched input tensors!");
CheckNTErrors((index == NULL || a->order == index->order), "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!");
CheckNTErrors(a->unitSize == b->unitSize, "Unmatched input tensors!");
CheckNTErrors(a->order == b->order, "Unmatched input tensors!");
CheckNTErrors(index == NULL || a->order == index->order, "Unmatched input tensors!");
CheckNTErrors(index->dataType == X_INT, "Wrong data type!");
int dimRDI = a->order - dim - 1;
for (int i = 0; i < a->order; i++) {
if (i == dimRDI) {
CheckNTErrors((b->dimSizeRDI[i] == k), "A too large K");
CheckNTErrors((index == NULL || index->dimSizeRDI[i] == k), "Wrong size!");
CheckNTErrors(b->dimSizeRDI[i] == k, "A too large K");
CheckNTErrors(index == NULL || index->dimSizeRDI[i] == k, "Wrong size!");
}
else {
CheckNTErrors((b->dimSizeRDI[i] == a->dimSizeRDI[i]), "Wrong size!");
CheckNTErrors((index == NULL || index->dimSizeRDI[i] == a->dimSizeRDI[i]), "Wrong size!");
CheckNTErrors(b->dimSizeRDI[i] == a->dimSizeRDI[i], "Wrong size!");
CheckNTErrors(index == NULL || index->dimSizeRDI[i] == a->dimSizeRDI[i], "Wrong size!");
}
}
......@@ -100,7 +100,7 @@ void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
for (int j = strideNumA >= k ? k - 1 : strideNumA - 1; j >= 0; j--) {
HeapNode<DTYPE> node = heap.Pop();
dataB[j * stride] = node.value;
indexData[j * stride] = node.index;
indexData[j * stride] = (int)node.index;
}
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论