Commit 9f91f159 by xiaotong

bug fixes

parent 7c68f1e7
......@@ -440,7 +440,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimization
if (i == 0) {
int threadLimit = (threadIdx.y + 1) * blockDim.x;
int threadLimit = threadIdx.y * blockDim.x + min(blockDim.x, strideNum);
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x + k; ++counter) {
......@@ -888,4 +888,4 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论