Commit cbe68108 by 张裕浩

TopK BUG2 fix

parent 965bf9d6
...@@ -871,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -871,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
//delete indexA; //delete indexA;
int workerNum = WORKERSNUM; int workerNum = WORKERSNUM;
GDevs.GetCudaThread2D(a->mem->devID, GDevs.GetCudaThread2D(a->devID,
workerNum, stride * blockNum, MAX_INT, workerNum, stride * blockNum, MAX_INT,
cudaGrids, cudaBlocks); cudaGrids, cudaBlocks);
if (a->dataType == DEFAULT_DTYPE) { if (a->dataType == DEFAULT_DTYPE) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论