Commit 965bf9d6 by 张裕浩

topk BUG fix

parent 6d05d6c8
...@@ -377,8 +377,8 @@ get the top-k items ...@@ -377,8 +377,8 @@ get the top-k items
template<class T> __global__ template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index) void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{ {
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 1024 * sizeof(T)) / sizeof(CudaHeapNode<T>)]; __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 512 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[1024]; __shared__ T eachHeapMaxValue[512];
/*optimization k size the parameter must more than half of k*/ /*optimization k size the parameter must more than half of k*/
int parameter = 0; int parameter = 0;
...@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
} }
__syncthreads(); __syncthreads();
/*to merge the heap use another way*/ /* to merge the heap use another way */
T minData = minValue; T minData = minValue;
int heapLimit = heap.count / 2; int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1; if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
...@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
minData = heap.items[counter].value; minData = heap.items[counter].value;
} }
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData; eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimation //need more optimation
if (i == 0) { if (i == 0) {
int threadLimit = (threadIdx.y + 1) * blockDim.x; int threadLimit = threadIdx.y * blockDim.x + min(blockDim.x,strideNum);
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y)); CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x; int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x + k; ++counter) { for (; counter < threadIdx.y * blockDim.x + min(k, blockDim.x); ++counter) {
chooseHeap.Push(counter, eachHeapMaxValue[counter]); chooseHeap.Push(counter, eachHeapMaxValue[counter]);
} }
for (; counter < threadLimit; ++counter) { for (; counter < threadLimit; ++counter) {
...@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]); chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
} }
} }
int heapNum = chooseHeap.count;
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index); CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter; int miss = parameter;
for (counter = 1; counter < k; ++counter) { for (counter = 1; counter < heapNum; ++counter) {
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1]; chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--; chooseHeap.count--;
chooseHeap.Down(0); chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index); CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0; int cmpHeapLimit = 0;
if (counter + heapLimit <= k - parameter){ if (counter + heapLimit <= k - parameter && heapNum == k){
cmpHeapLimit = heapLimit; cmpHeapLimit = heapLimit;
} }
/* take the max data from the minHeap,so start search from the leaf node */ /* take the max data from the minHeap,so start search from the leaf node */
...@@ -826,7 +828,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -826,7 +828,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
else if (k < 22) workerNum = 128; else if (k < 22) workerNum = 128;
else if (k < 44) workerNum = 64; else if (k < 44) workerNum = 64;
else workerNum = 32; else workerNum = 32;
int cudaGrids[3]; int cudaGrids[3];
int cudaBlocks[3]; int cudaBlocks[3];
...@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
/* we run the kernel if the heaps can fit into the shared memory */ /* we run the kernel if the heaps can fit into the shared memory */
cudaGrids[1] *= cudaBlocks[1]; cudaGrids[1] *= cudaBlocks[1];
cudaBlocks[1] = 1; cudaBlocks[1] = 1;
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) { if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) + (512 * sizeof(int))< SHARED_MEMORY_SIZE) {
if (a->dataType == DEFAULT_DTYPE) { if (a->dataType == DEFAULT_DTYPE) {
KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN, ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论