Commit 8eae2dbf by xiaotong

use Sort instead of TopK if K is larger than the size of the dimension we go along with

parent 0aac9d31
......@@ -23,6 +23,7 @@
#include "../../XName.h"
#include "TopK.h"
#include "TopK.cuh"
#include "Sort.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -116,7 +117,10 @@ get the top-k items along a given dimension
*/
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k)
{
_TopK(&a, &b, &index, dim, k);
if(a.dimSize[dim] <= k)
_Sort(&a, &b, &index, dim);
else
_TopK(&a, &b, &index, dim, k);
/* tensor connection */
XList list(2);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论