Commit 8eae2dbf by xiaotong

use Sort instead of TopK if K is larger than the size of the dimension we go along with

parent 0aac9d31
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "../../XName.h" #include "../../XName.h"
#include "TopK.h" #include "TopK.h"
#include "TopK.cuh" #include "TopK.cuh"
#include "Sort.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -116,6 +117,9 @@ get the top-k items along a given dimension ...@@ -116,6 +117,9 @@ get the top-k items along a given dimension
*/ */
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k) void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k)
{ {
if(a.dimSize[dim] <= k)
_Sort(&a, &b, &index, dim);
else
_TopK(&a, &b, &index, dim, k); _TopK(&a, &b, &index, dim, k);
/* tensor connection */ /* tensor connection */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论