Commit fa2ed07c by 张裕浩

clean code

parent ec71b1a9
...@@ -33,7 +33,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -33,7 +33,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* /*
use PTX code to reduce float data use PTX code to reduce float data
*/ */
__device__ __forceinline__
float shflDownReduceMax(float input)
{ {
float output; float output;
asm volatile( asm volatile(
...@@ -61,6 +62,38 @@ __device__ __forceinline__ float shfl_down_reduce_max(float input) ...@@ -61,6 +62,38 @@ __device__ __forceinline__ float shfl_down_reduce_max(float input)
return output; return output;
} }
/*
use PTX code to reduce int data
*/
__device__ __forceinline__
int shflDownReduceMax(int input)
{
int output;
asm volatile(
"{"
".reg .s32 r0;"
".reg .pred p;"
"shfl.down.b32 r0, %1, 0x10, 0x1f;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.down.b32 r0, %1, 0x8, 0xf;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.down.b32 r0, %1, 0x4, 0x7;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.down.b32 r0, %1, 0x2, 0x3;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.down.b32 r0, %1, 0x1, 0x1;"
"setp.lt.s32 p, %1, r0; "
"@p mov.s32 %1,r0;"
"mov.s32 %0,%1;"
"}"
: "=r"(output) : "r"(input));
return output;
}
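/* Editor's sketch (not part of the original commit): an intrinsic-based equivalent of the
   warp max reduction that the PTX above implements, for readers less familiar with inline asm.
   The name shflDownReduceMaxIntrinsic is hypothetical; the offsets 16/8/4/2/1 mirror the
   shfl.down steps in the asm, and lane 0 ends up holding the warp maximum. */
__device__ __forceinline__
int shflDownReduceMaxIntrinsic(int input)
{
    for (int offset = 16; offset > 0; offset >>= 1) {
        /* read the value of the lane whose id is `offset` higher and keep the larger one */
        int other = __shfl_down_sync(0xffffffff, input, offset, 32);
        input = max(input, other);
    }
    return input;
}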
/* /*
reduce a tensor to another that keeps the max value along a dimension - slow version reduce a tensor to another that keeps the max value along a dimension - slow version
Given a block of data, we go over each dimension i in the stride and we have Given a block of data, we go over each dimension i in the stride and we have
...@@ -224,39 +257,18 @@ void KernelReduceMaxFast(DTYPE * input, DTYPE * output, ...@@ -224,39 +257,18 @@ void KernelReduceMaxFast(DTYPE * input, DTYPE * output,
        DTYPE value2 = j + blockDim.y < strideNum ? inputData[(j + blockDim.y) * stride + iOffset] : FLOAT_MIN;
        value = MAX(value, value2);

        value = shflDownReduceMax(value);
        if ((tid & 0x1f) == 0) { data[tid / 32] = value; }
        __syncthreads();
        if (tid < 32) {
            if (tid < blockDim.y / 32)
                value = data[tid];
            else value = FLOAT_MIN;
            value = shflDownReduceMax(value);
            if (tid == 0 && blockIdx.y < reducedStrideNum)
                output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value;
        }
///* load data into the shared mem */
//data[tid] = MAX(value, value2);
//__syncthreads();
///* unroll the warp */
//if(goodSize >= 512) {if(tid < 256) {if(data[tid] < data[tid + 256]) data[tid] = data[tid + 256];} __syncthreads();}
//if(goodSize >= 256) {if(tid < 128) {if(data[tid] < data[tid + 128]) data[tid] = data[tid + 128];} __syncthreads();}
//if(goodSize >= 128) {if(tid < 64) {if(data[tid] < data[tid + 64]) data[tid] = data[tid + 64];} __syncthreads();}
//if(goodSize >= 64) {if(tid < 32) {if(data[tid] < data[tid + 32]) data[tid] = data[tid + 32];} __syncthreads();}
//if(goodSize >= 32) {if(tid < 16) {if(data[tid] < data[tid + 16]) data[tid] = data[tid + 16];} __syncthreads();}
//if(goodSize >= 16) {if(tid < 8) {if(data[tid] < data[tid + 8]) data[tid] = data[tid + 8];} __syncthreads();}
//if(goodSize >= 8) {if(tid < 4) {if(data[tid] < data[tid + 4]) data[tid] = data[tid + 4];} __syncthreads();}
//if(goodSize >= 4) {if(tid < 2) {if(data[tid] < data[tid + 2]) data[tid] = data[tid + 2];} __syncthreads();}
//if(goodSize >= 2) {if(tid < 1) {if(data[tid] < data[tid + 1]) data[tid] = data[tid + 1];} __syncthreads();}
///* write result for this block to the output array */
//if(threadIdx.y == 0 && blockIdx.y < reducedStrideNum)
// output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = data[0];
} }
/* /*
...@@ -373,14 +385,15 @@ void KernelReduceMaxSimpleFast(DTYPE * input, DTYPE * output, ...@@ -373,14 +385,15 @@ void KernelReduceMaxSimpleFast(DTYPE * input, DTYPE * output,
op[offset] = max; op[offset] = max;
} }
/*
allocate the warp number according to the GPU's SM count
*/
inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long vectorNum, int vectorSize)
{
    int warpNum = 4;
    if (vectorNum < 20 * 8) {
        warpNum = 8;
        if (vectorNum < 20 * 4) {
            warpNum = 16;
            if (vectorNum < 20 * 2)
                warpNum = 32;
...@@ -389,6 +402,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -389,6 +402,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
int minWarpNum = vectorSize / 32; int minWarpNum = vectorSize / 32;
if (vectorSize % 32 != 0) minWarpNum++; if (vectorSize % 32 != 0) minWarpNum++;
warpNum = min(warpNum, minWarpNum); warpNum = min(warpNum, minWarpNum);
grid.x = vectorNum; grid.x = vectorNum;
grid.y = 1; grid.y = 1;
grid.z = 1; grid.z = 1;
...@@ -397,39 +411,44 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -397,39 +411,44 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
block.z = 1; block.z = 1;
} }
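/* Editor's sketch (assumption, not part of the original commit): the warp-count heuristic
   above, isolated as a plain host helper so the thresholds are easy to read and test.
   The constant 20 is treated as the SM count of the target GPU (20 on a GTX 1080 / 1080 Ti),
   and the result is capped by the number of warps actually needed to cover one vector. */
static int chooseWarpNum(long long vectorNum, int vectorSize)
{
    int warpNum = 4;
    if (vectorNum < 20 * 8) {
        warpNum = 8;
        if (vectorNum < 20 * 4) {
            warpNum = 16;
            if (vectorNum < 20 * 2)
                warpNum = 32;
        }
    }
    int minWarpNum = (vectorSize + 31) / 32;            /* ceil(vectorSize / 32) */
    return warpNum < minWarpNum ? warpNum : minWarpNum;
}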
/*
fold threads.x into the grid so that the warp-level optimization can be used
*/
inline void adjustThreadForUseWarpOptimization(dim3& blocks, dim3& threads)
{
    if (threads.x > 1) {
        blocks.x *= threads.x;
        threads.x = 1;
    }
    if (threads.y < 32)
        threads.y = 32;
}
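/* Editor's sketch (assumption, not part of the original commit): a worked example of the
   adjustment above. exampleAdjust is a hypothetical helper used only for illustration. */
inline void exampleAdjust()
{
    dim3 blocks(13, 4);
    dim3 threads(16, 8);
    adjustThreadForUseWarpOptimization(blocks, threads);
    /* now blocks = (208, 4) and threads = (1, 32):
       threads.x has been folded into blocks.x and threads.y padded to a full warp,
       so the y dimension can be reduced with the warp shuffle above */
}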
/*
in some cases we use fewer blocks to improve efficiency
*/
__global__
void KernelReduceMaxOpLessBlocks(DTYPE * input, DTYPE * output, int strideNum, int blockNum)
{
    int idx = threadIdx.x % 32;
    int idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32;

    int startIndex = idy * strideNum;
    DTYPE threadMax = FLOAT_MIN;
    for (int i = idx; i < strideNum; i += 32) {
        threadMax = max(input[startIndex + i], threadMax);
    }
    threadMax = shflDownReduceMax(threadMax);
    if (idx == 0)
        output[idy] = threadMax;
}
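/* Editor's note (added explanation, not in the original commit): with the host launch
   KernelReduceMaxOpLessBlocks<<<blockNum / 4, 128>>>(...), each block holds 128 threads,
   i.e. 4 warps, so warp w of block b handles the vector idy = b * 4 + w; lane idx walks
   that row of length strideNum in steps of 32 and the warp maximum is produced by
   shflDownReduceMax. */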
/*
reduce using the PTX-based warp shuffle
*/
__global__
void KernelReduceMaxOp(DTYPE * input, DTYPE * output, int stride, int strideNum,
                       int reducedStrideNum, int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32]; __shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32];
...@@ -447,26 +466,19 @@ void KernelReduceMaxOp(DTYPE * input, DTYPE * output, ...@@ -447,26 +466,19 @@ void KernelReduceMaxOp(DTYPE * input, DTYPE * output,
DTYPE * data = iData + threadIdx.x * blockDim.y; DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize; DTYPE * inputData = input + k * blockSize;
    for (int it = j; it < strideNum; it += blockDim.y){
        threadMax = max(inputData[it * stride + iOffset], threadMax);
    }

    __syncthreads();
    threadMax = shflDownReduceMax(threadMax);
    if ((tid & 0x1f) == 0) { data[tid / 32] = threadMax; }
    __syncthreads();
    /* use one warp to reduce the remaining data */
    if (tid < 32){
        if (tid < blockDim.y / 32)
            threadMax = data[tid];
        else threadMax = FLOAT_MIN;
        threadMax = shflDownReduceMax(threadMax);
        if (tid == 0 && blockIdx.y < reducedStrideNum)
            output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadMax;
    }
...@@ -528,20 +540,18 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -528,20 +540,18 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
int devIDBackup; int devIDBackup;
ProtectCudaDev(input->devID, devIDBackup); ProtectCudaDev(input->devID, devIDBackup);
    if (stride == 1 && blockNum >= 10) {
        dim3 grids;
        dim3 blocks;
        continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum);
        if (blocks.y > 128) {
            KernelReduceMaxOp <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum);
        }
        else {
            KernelReduceMaxOpLessBlocks <<<blockNum / 4, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum);
        }
    }
    else {
do { do {
if (input->dataType == DEFAULT_DTYPE) { if (input->dataType == DEFAULT_DTYPE) {
DTYPE * iData = NULL; DTYPE * iData = NULL;
...@@ -565,7 +575,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -565,7 +575,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1) if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
KernelReduceMax << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); KernelReduceMax <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
} }
else if (strideNum < 128) { else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -574,7 +584,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -574,7 +584,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 64), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 64), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<64> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); KernelReduceMaxFast<64> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
} }
else if (strideNum < 256) { else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -583,7 +593,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -583,7 +593,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 128), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 128), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<128> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); KernelReduceMaxFast<128> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
} }
else if (strideNum < 512) { else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -592,7 +602,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -592,7 +602,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 256), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 256), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<256> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); KernelReduceMaxFast<256> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
} }
else { else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -601,7 +611,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -601,7 +611,7 @@ void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 512), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 512), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<512> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); KernelReduceMaxFast<512> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
} }
} }
else if (input->dataType == X_FLOAT16) { else if (input->dataType == X_FLOAT16) {
......
...@@ -30,7 +30,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -30,7 +30,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* /*
use PTX code to reduce float data use PTX code to reduce float data
*/ */
__device__ __forceinline__
float shflDownReduceSum(float input)
{ {
float output; float output;
asm volatile( asm volatile(
...@@ -54,7 +55,8 @@ __device__ __forceinline__ float shfl_down_reduce_sum(float input) ...@@ -54,7 +55,8 @@ __device__ __forceinline__ float shfl_down_reduce_sum(float input)
/* /*
use PTX code to reduce int data use PTX code to reduce int data
*/ */
__device__ __forceinline__
int shflDownReduceSum(int input)
{ {
int output; int output;
asm volatile( asm volatile(
...@@ -326,47 +328,17 @@ void KernelReduceSumFast(DTYPE * input, DTYPE * output, ...@@ -326,47 +328,17 @@ void KernelReduceSumFast(DTYPE * input, DTYPE * output,
value = value + value2; value = value + value2;
__syncthreads(); __syncthreads();
    value = shflDownReduceSum(value);
    if ((tid & 0x1f) == 0) { data[tid / 32] = value; }
    __syncthreads();
    if (tid < 32){
        if (tid < blockDim.y / 32)
            value = data[tid];
        else value = 0;
        value = shflDownReduceSum(value);
        if (tid == 0 && blockIdx.y < reducedStrideNum)
            output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value;
    }
/*if (blockDim.y / 32 >= 16) { if (tid < 8) { data[tid] += data[tid + 8]; } __syncthreads(); }
if (blockDim.y / 32 >= 8) { if (tid < 4) { data[tid] += data[tid + 4]; } __syncthreads(); }
if (blockDim.y / 32 >= 4) { if (tid < 2) { data[tid] += data[tid + 2]; } __syncthreads(); }
if (blockDim.y / 32 >= 2) { if (tid < 1) { data[tid] += data[tid + 1]; } __syncthreads(); }*/
///* load data into the shared mem */
//data[tid] = value + value2;
//__syncthreads();
///* unroll the warp */
//if(goodSize >= 512) {if(tid < 256) {data[tid] += data[tid + 256];} __syncthreads();}
//if(goodSize >= 256) {if(tid < 128) {data[tid] += data[tid + 128];} __syncthreads();}
//if(goodSize >= 128) {if(tid < 64) {data[tid] += data[tid + 64];} __syncthreads();}
//if(goodSize >= 64) {if(tid < 32) {data[tid] += data[tid + 32];} __syncthreads();}
//if(goodSize >= 32) {if(tid < 16) {data[tid] += data[tid + 16];} __syncthreads();}
//if(goodSize >= 16) {if(tid < 8) {data[tid] += data[tid + 8];} __syncthreads();}
//if(goodSize >= 8) {if(tid < 4) {data[tid] += data[tid + 4];} __syncthreads();}
//if(goodSize >= 4) {if(tid < 2) {data[tid] += data[tid + 2];} __syncthreads();}
//if(goodSize >= 2) {if(tid < 1) {data[tid] += data[tid + 1];} __syncthreads();}
///* write result for this block to the output array */
//if(threadIdx.y == 0 && blockIdx.y < reducedStrideNum)
// output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = data[0];
} }
/* /*
...@@ -502,9 +474,12 @@ void KernelReduceSumFast(__half * input, __half * output, ...@@ -502,9 +474,12 @@ void KernelReduceSumFast(__half * input, __half * output,
#endif #endif
} }
/*
if the data storage is discontinuous, use this way to reduce
*/
__global__
void KernelReduceSumDiscontinuousStorage(DTYPE * input, DTYPE * output, int stride,
                                         int strideNum, DTYPE * shift, DTYPE power, bool isExp)
{ {
//int idx = blockIdx.x * blockDim.x + threadIdx.x; //int idx = blockIdx.x * blockDim.x + threadIdx.x;
//int endIndex = (idx+1) * strideNum; //int endIndex = (idx+1) * strideNum;
...@@ -515,11 +490,9 @@ __global__ void KernelReduceSumDiscontinuousStorage(DTYPE * input, DTYPE * outpu ...@@ -515,11 +490,9 @@ __global__ void KernelReduceSumDiscontinuousStorage(DTYPE * input, DTYPE * outpu
#pragma unroll #pragma unroll
for (int i = stride * strideNum * blockIndex + offsetInBlock; for (int i = stride * strideNum * blockIndex + offsetInBlock;
i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum; i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum;
i += stride) i += stride){
{
ans += input[i]; ans += input[i];
} }
if (threadIdx.x == 0 && blockIdx.x == 0) printf("%d ", stride);
output[idx] = ans; output[idx] = ans;
} }
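/* Editor's sketch (assumption, not part of the original commit): a serial CPU reference for
   the memory layout reduced by KernelReduceSumDiscontinuousStorage. Element (block, i, offset)
   lives at input[block * stride * strideNum + i * stride + offset] and one output value sums
   over i; the shift/power/isExp handling of the real kernel is omitted here. The name
   reduceSumDiscontinuousRef is hypothetical. */
void reduceSumDiscontinuousRef(const float * input, float * output,
                               int stride, int strideNum, int blockNum)
{
    for (int block = 0; block < blockNum; ++block) {
        for (int offset = 0; offset < stride; ++offset) {
            float sum = 0;
            /* walk one "column" of the block with a step of stride, as the kernel does */
            for (int i = 0; i < strideNum; ++i)
                sum += input[block * stride * strideNum + i * stride + offset];
            output[block * stride + offset] = sum;
        }
    }
}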
...@@ -551,8 +524,7 @@ void KernelReduceSumOp(DTYPE * input, DTYPE * output, ...@@ -551,8 +524,7 @@ void KernelReduceSumOp(DTYPE * input, DTYPE * output,
DTYPE * data = iData + threadIdx.x * blockDim.y; DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize; DTYPE * inputData = input + k * blockSize;
for (int it = j; it < strideNum; it += blockDim.y) for (int it = j; it < strideNum; it += blockDim.y){
{
DTYPE value = inputData[it * stride + iOffset] - bias[threadIdx.x]; DTYPE value = inputData[it * stride + iOffset] - bias[threadIdx.x];
if (power != (DTYPE)1.0) { if (power != (DTYPE)1.0) {
if (power == (DTYPE)2.0) { if (power == (DTYPE)2.0) {
...@@ -569,34 +541,18 @@ void KernelReduceSumOp(DTYPE * input, DTYPE * output, ...@@ -569,34 +541,18 @@ void KernelReduceSumOp(DTYPE * input, DTYPE * output,
threadSum += value; threadSum += value;
} }
__syncthreads(); __syncthreads();
    threadSum = shflDownReduceSum(threadSum);
    if ((tid & 0x1f) == 0) { data[tid / 32] = threadSum; }
    __syncthreads();
    if (tid < 32){
        if (tid < blockDim.y / 32)
            threadSum = data[tid];
        else threadSum = 0;
        threadSum = shflDownReduceSum(threadSum);
        if (tid == 0 && blockIdx.y < reducedStrideNum)
            output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadSum;
    }
/*if (blockDim.y / 32 >= 32) { if (tid < 16) { data[tid] += data[tid + 16]; } __syncthreads(); }
if (blockDim.y / 32 >= 16) { if (tid < 8) { data[tid] += data[tid + 8]; } __syncthreads(); }
if (blockDim.y / 32 >= 8) { if (tid < 4) { data[tid] += data[tid + 4]; } __syncthreads(); }
if (blockDim.y / 32 >= 4) { if (tid < 2) { data[tid] += data[tid + 2]; } __syncthreads(); }
if (blockDim.y / 32 >= 2) { if (tid < 1) { data[tid] += data[tid + 1]; } __syncthreads(); }
// write result for this block to the output array
if (threadIdx.y == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = data[0];*/
} }
__global__ __global__
...@@ -612,8 +568,7 @@ void KernelReduceSumOpLessBlocks(DTYPE * input, DTYPE * output, ...@@ -612,8 +568,7 @@ void KernelReduceSumOpLessBlocks(DTYPE * input, DTYPE * output,
bias[threadIdx.x / 32] = shift != NULL ? shift[idy] : 0; bias[threadIdx.x / 32] = shift != NULL ? shift[idy] : 0;
int startIndex = idy * strideNum; int startIndex = idy * strideNum;
DTYPE threadSum = 0; DTYPE threadSum = 0;
for (int i = idx; i < strideNum; i += 32) for (int i = idx; i < strideNum; i += 32) {
{
DTYPE value = input[startIndex + i] - bias[threadIdx.x / 32]; DTYPE value = input[startIndex + i] - bias[threadIdx.x / 32];
if (power != (DTYPE)1.0) { if (power != (DTYPE)1.0) {
if (power == (DTYPE)2.0) { if (power == (DTYPE)2.0) {
...@@ -629,30 +584,20 @@ void KernelReduceSumOpLessBlocks(DTYPE * input, DTYPE * output, ...@@ -629,30 +584,20 @@ void KernelReduceSumOpLessBlocks(DTYPE * input, DTYPE * output,
if (isExp) value = exp(value); if (isExp) value = exp(value);
threadSum += value; threadSum += value;
} }
    threadSum = shflDownReduceSum(threadSum);
if (idx == 0) if (idx == 0)
output[idy] = threadSum; output[idy] = threadSum;
/*__shared__ DTYPE idata[128];
idata[threadIdx.x] = threadSum;
__syncthreads();
if (idx < 16) { idata[threadIdx.x] += idata[threadIdx.x + 16]; }__syncthreads();
if (idx < 8) { idata[threadIdx.x ] += idata[threadIdx.x + 8]; }__syncthreads();
if (idx < 4) { idata[threadIdx.x ] += idata[threadIdx.x + 4]; }__syncthreads();
if (idx < 2) { idata[threadIdx.x ] += idata[threadIdx.x + 2]; }__syncthreads();
if (idx < 1) { idata[threadIdx.x ] += idata[threadIdx.x + 1]; }__syncthreads();
if (idx == 0)
output[idy] = idata[threadIdx.x];*/
} }
/*
allocate the warp number according to the GPU's SM count
(this follows the thread-allocation strategy used by PyTorch, which hard-codes the SM count --
20 for a GTX 1080 / 1080 Ti -- and it does give better performance)
*/
inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long vectorNum, int vectorSize)
{
    int warpNum = 4;
    if (vectorNum < 20 * 8) {
        warpNum = 8;
        if (vectorNum < 20 * 4) {
            warpNum = 16;
            if (vectorNum < 20 * 2)
                warpNum = 32;
...@@ -661,6 +606,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -661,6 +606,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
int minWarpNum = vectorSize / 32; int minWarpNum = vectorSize / 32;
if (vectorSize % 32 != 0) minWarpNum++; if (vectorSize % 32 != 0) minWarpNum++;
warpNum = min(warpNum, minWarpNum); warpNum = min(warpNum, minWarpNum);
grid.x = vectorNum; grid.x = vectorNum;
grid.y = 1; grid.y = 1;
grid.z = 1; grid.z = 1;
...@@ -669,7 +615,9 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -669,7 +615,9 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
block.z = 1; block.z = 1;
} }
/*
in this situation, block.x * grid.x threads together process one vector, so that the reads are contiguous
*/
inline void discontinuousStorageNoShareMemThreadAllocation(dim3& grid, dim3& block, int stride, int blockNum) inline void discontinuousStorageNoShareMemThreadAllocation(dim3& grid, dim3& block, int stride, int blockNum)
{ {
block.x = 512; block.x = 512;
...@@ -681,10 +629,12 @@ inline void discontinuousStorageNoShareMemThreadAllocation(dim3& grid, dim3& blo ...@@ -681,10 +629,12 @@ inline void discontinuousStorageNoShareMemThreadAllocation(dim3& grid, dim3& blo
grid.y = 1; grid.y = 1;
} }
/*
fold threads.x into the grid so that the warp-level optimization can be used
*/
inline void adjustThreadForUseWarpOptimization(dim3& blocks, dim3& threads) inline void adjustThreadForUseWarpOptimization(dim3& blocks, dim3& threads)
{ {
if (threads.x > 1) if (threads.x > 1){
{
blocks.x *= threads.x; blocks.x *= threads.x;
threads.x = 1; threads.x = 1;
} }
...@@ -757,33 +707,24 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -757,33 +707,24 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
int devIDBackup; int devIDBackup;
ProtectCudaDev(input->devID, devIDBackup); ProtectCudaDev(input->devID, devIDBackup);
    if (stride == 1 && blockNum >= 10) {
        dim3 grids;
        dim3 blocks;
        continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum);
        if (blocks.y > 128)
            KernelReduceSumOp <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum, sp, power, isExp);
        else
            KernelReduceSumOpLessBlocks <<<blockNum / 4, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum, sp, power, isExp);
    }
    else if (stride != 1 && stride * blockNum > 4096){
        //GDevs->GetGridAndBlockSize2D(devID, stride * blockNum, strideNum,MAX_INT, cudaGridSize, cudaBlockSize);
        //unsigned int* goutput = (unsigned int *)input->data;
        //convert2uintV2 << <dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >> > ((float*)input->data, goutput, stride, strideNum, blockNum, strideNum*blockNum*stride);
        dim3 grid, block;
        discontinuousStorageNoShareMemThreadAllocation(grid, block, stride, blockNum);
        KernelReduceSumDiscontinuousStorage <<<grid, block >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, sp, power, isExp);
    }
    else {
do { do {
if (input->dataType == DEFAULT_DTYPE) { if (input->dataType == DEFAULT_DTYPE) {
DTYPE * iData = NULL; DTYPE * iData = NULL;
...@@ -806,7 +747,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -806,7 +747,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1) if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
KernelReduceSum << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp); KernelReduceSum <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp);
} }
else if (strideNum < 128) { else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -815,7 +756,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -815,7 +756,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 64), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 64), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceSumFast<64> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp); KernelReduceSumFast<64> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp);
} }
else if (strideNum < 256) { else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -824,7 +765,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -824,7 +765,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 128), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 128), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceSumFast<128> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp); KernelReduceSumFast<128> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp);
} }
else if (strideNum < 512) { else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -833,7 +774,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -833,7 +774,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 256), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 256), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceSumFast<256> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp); KernelReduceSumFast<256> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp);
} }
else { else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
...@@ -842,7 +783,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -842,7 +783,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (DTYPE*)output->data; oData = (DTYPE*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 512), "Incorrect thread number when calling the cuda kernel!"); CheckNTErrors((cudaBlockSize[0] >= 512), "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads); adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceSumFast<512> << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp); KernelReduceSumFast<512> <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, sp, power, isExp);
} }
} }
else if (input->dataType == X_FLOAT16) { else if (input->dataType == X_FLOAT16) {
...@@ -872,7 +813,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen ...@@ -872,7 +813,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1) if (cudaGridSize[0] == 1)
oData = (__half*)output->data; oData = (__half*)output->data;
KernelReduceSum << <blocks, threads >> > (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, spft16, *powerft16p, isExp); KernelReduceSum <<<blocks, threads >>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum, spft16, *powerft16p, isExp);
} }
else if (strideNum < 128) { else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
......
...@@ -433,53 +433,42 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -433,53 +433,42 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
T minData = minValue; T minData = minValue;
int heapLimit = heap.count / 2; int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1; if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
for (int counter = heap.count - 1; counter >= heapLimit; --counter) for (int counter = heap.count - 1; counter >= heapLimit; --counter) {
{
if (minData < heap.items[counter].value) if (minData < heap.items[counter].value)
minData = heap.items[counter].value; minData = heap.items[counter].value;
} }
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData; eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
    //needs more optimization
if (i == 0) if (i == 0) {
{
int threadLimit = (threadIdx.y + 1) * blockDim.x; int threadLimit = (threadIdx.y + 1) * blockDim.x;
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y)); CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x; int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x + k; ++counter) for (; counter < threadIdx.y * blockDim.x + k; ++counter) {
{
chooseHeap.Push(counter, eachHeapMaxValue[counter]); chooseHeap.Push(counter, eachHeapMaxValue[counter]);
} }
for (; counter < threadLimit; ++counter) for (; counter < threadLimit; ++counter) {
{ if (eachHeapMaxValue[counter]>chooseHeap.items[0].value) {
if (eachHeapMaxValue[counter]>chooseHeap.items[0].value)
{
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]); chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
} }
} }
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index); CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter; int miss = parameter;
for (counter = 1; counter < k; ++counter) for (counter = 1; counter < k; ++counter) {
{
//printf("%f %d\n",chooseHeap.items[0].value,chooseHeap.items[0].index);
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1]; chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--; chooseHeap.count--;
chooseHeap.Down(0); chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index); CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0; int cmpHeapLimit = 0;
            if (counter + heapLimit <= k - parameter){
                cmpHeapLimit = heapLimit;
            }
            /* take the max data from the min heap, so start the search from the leaf nodes */
            for (int iterator = k - 1 - parameter; iterator >= cmpHeapLimit; --iterator){
                if (miss > 0){
                    ansHeapData.Push(cmpHeapData[iterator].index, cmpHeapData[iterator].value);
                    miss--;
                }
                else if (ansHeapData.items[0].value < cmpHeapData[iterator].value){
                    ansHeapData.ReplaceTop(cmpHeapData[iterator].index, cmpHeapData[iterator].value);
                }
            }
...@@ -487,8 +476,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -487,8 +476,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
int offset = stride * k * blockIndex + offsetInBlock; int offset = stride * k * blockIndex + offsetInBlock;
T * dOutput = output + offset; T * dOutput = output + offset;
int * indexOutput = index + offset; int * indexOutput = index + offset;
for (int q = 0; q < k; ++q) for (int q = 0; q < k; ++q){
{
dOutput[stride * q] = ansHeapData.items[q].value; dOutput[stride * q] = ansHeapData.items[q].value;
indexOutput[stride * q] = ansHeapData.items[q].index; indexOutput[stride * q] = ansHeapData.items[q].index;
} }
...@@ -496,52 +484,61 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -496,52 +484,61 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
} }
__device__ __forceinline__
unsigned getLaneMaskLe()
{
unsigned mask; unsigned mask;
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
return mask; return mask;
} }
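/* Editor's note (added comment, not in the original commit): %%lanemask_le is a 32-bit mask
   with bits 0..laneId set, so __popc(getLaneMaskLe() & ballot) counts how many lanes at or
   below the current lane voted true -- the 1-based rank used by gpuCheckWarp below. */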
__device__ __forceinline__
int getLaneId()
{
int laneId; int laneId;
asm("mov.s32 %0, %laneid;" : "=r"(laneId)); asm("mov.s32 %0, %laneid;" : "=r"(laneId));
return laneId; return laneId;
} }
__device__
unsigned convert(float v)
{ {
unsigned x = __float_as_int(v); unsigned x = __float_as_int(v);
unsigned mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; unsigned mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask); return (x ^ mask);
} }
__device__
float convert(unsigned int v)
{ {
float x = __uint_as_float(v); float x = __uint_as_float(v);
return x; return x;
} }
__device__
float deconvert(unsigned int v)
{
    unsigned int mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask); return __int_as_float(v ^ mask);
} }
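/* Editor's sketch (assumption, not part of the original commit): host-side counterparts of
   convert/deconvert showing that the mapping is order preserving, i.e. a < b on floats iff
   convertHost(a) < convertHost(b) on unsigned ints, which is what lets the radix select below
   compare raw bit patterns. convertHost/deconvertHost/checkOrderPreserving are hypothetical
   names used only for this illustration. */
#include <cstring>
#include <cassert>

static unsigned int convertHost(float v)
{
    unsigned int x;
    std::memcpy(&x, &v, sizeof(x));                               /* portable __float_as_int */
    unsigned int mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
    return x ^ mask;
}

static float deconvertHost(unsigned int v)
{
    unsigned int mask = (v & 0x80000000u) ? 0x80000000u : 0xffffffffu;
    unsigned int x = v ^ mask;
    float f;
    std::memcpy(&f, &x, sizeof(f));
    return f;
}

static void checkOrderPreserving()
{
    float a = -3.5f, b = 0.0f, c = 2.25f;
    assert(convertHost(a) < convertHost(b) && convertHost(b) < convertHost(c));
    assert(deconvertHost(convertHost(a)) == a);
}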
__global__
void convert2uintV2(float* input, unsigned int *output, int stride, int strideNum, int blockNum, int size)
{ {
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
int idy = blockDim.y * blockIdx.y + threadIdx.y; int idy = blockDim.y * blockIdx.y + threadIdx.y;
//int strideNum = (int)strideNumSize;
//if (flag) strideNum = strideNumSize[idy];
int blockIndex = idy / stride; int blockIndex = idy / stride;
int offsetInBlock = idy% stride; int offsetInBlock = idy% stride;
#pragma unroll #pragma unroll
for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock; for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock;
i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum && i < size; i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum && i < size;
i += stride * blockDim.x) i += stride * blockDim.x){
{
output[i] = convert(input[i]); output[i] = convert(input[i]);
} }
} }
__global__
void deconvert2floatV2(unsigned int * input, float *output, int stride, int strideNum, int blockNum, int size)
{ {
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
int idy = blockDim.y * blockIdx.y + threadIdx.y; int idy = blockDim.y * blockIdx.y + threadIdx.y;
...@@ -552,13 +549,13 @@ __global__ void deconvert2floatV2(unsigned int * input, float *output, int strid ...@@ -552,13 +549,13 @@ __global__ void deconvert2floatV2(unsigned int * input, float *output, int strid
#pragma unroll #pragma unroll
for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock; for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock;
i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum && i < size; i < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum && i < size;
i += stride * blockDim.x) i += stride * blockDim.x){
{
output[i] = deconvert(input[i]); output[i] = deconvert(input[i]);
} }
} }
__device__
void radixCount(unsigned int *data, int limit, int *posCount, unsigned int mask, int maskDesire, unsigned int desire, int stride, int strideNum, int blockNum)
{ {
/*the idx th thread in one vector */ /*the idx th thread in one vector */
...@@ -569,149 +566,141 @@ __device__ void radixCount(unsigned int *data, int limit, int *pos_count, unsign ...@@ -569,149 +566,141 @@ __device__ void radixCount(unsigned int *data, int limit, int *pos_count, unsign
int offsetInBlock = idy% stride; int offsetInBlock = idy% stride;
    for (int j = idx * stride + stride * strideNum * blockIndex + offsetInBlock;
         j < stride * strideNum * blockIndex + offsetInBlock + stride * strideNum && j < limit;
         j += stride * WORKERSNUM) {
        if ((data[j] & maskDesire) == desire) {
            if (data[j] & mask) {
                posCount[(idy % (512 / WORKERSNUM)) * blockDim.x + idx]++;
            }
        }
    }
}
/* We can use this way to quickly check the thread status within a warp;
   note that the thread number needs to be a multiple of 32 */
__device__
void gpuCheckWarp(int *smem, bool in, int *carry, int *index)
{
    int vote = __ballot_sync(0xffffffff, in);
    *index = __popc(getLaneMaskLe() & vote);
    *carry = __popc(vote);
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int warp = idx / 32;
    int warpNum = blockDim.x / 32;   /* how many warps each vector uses */

    if (getLaneId() == 0) {
        /* save each warp's carry */
        smem[warp + warpNum * threadIdx.y] = *carry;
    }
    __syncthreads();
    /* use one thread to accumulate the carries over all the warps of this vector */
    if (idx == 0) {
        for (int i = 1 + warpNum * threadIdx.y; i < warpNum * (threadIdx.y + 1); ++i) {
            smem[i] += smem[i - 1];
        }
    }
    __syncthreads();
    if (warp % warpNum) {
        *index += smem[warpNum * threadIdx.y + warp - 1];
    }
    *carry = smem[warpNum * threadIdx.y + warpNum - 1];
}
/*
collect the data larger than pattern into ans and return it
*/
__device__
void collectNumber(unsigned int *data, int stride, int strideNum, int limit,
                   unsigned int pattern, float *ans, int *ansIndex, int k)
{
    int idy = blockDim.y * blockIdx.y + threadIdx.y;
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int blockIndex = idy / stride;
    int offsetInBlock = idy % stride;

    /* for counting each warp's temporary carry */
    __shared__ int smem[32];
    int carry;
    int index;
    int vectorLimit = stride * strideNum * blockIndex + offsetInBlock + stride * strideNum;
    int alibnStrideNum = strideNum;
    if (alibnStrideNum % blockDim.x) alibnStrideNum = alibnStrideNum + blockDim.x - (alibnStrideNum % blockDim.x);
    int vectorAlibnLimit = stride * strideNum * blockIndex + offsetInBlock + stride * alibnStrideNum;
    int ansArrayIndex = stride * k * blockIndex + offsetInBlock;
    int ansSize = 0;
    __syncthreads();

#pragma unroll
    for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock;
         i < vectorAlibnLimit; i += stride * WORKERSNUM){

        bool hasTopk = false;
        if (i < vectorLimit && data[i] > pattern){
            hasTopk = true;
        }
        gpuCheckWarp(smem, hasTopk, &carry, &index);
        if (carry > 0) {
            if (hasTopk) {
                ans[ansArrayIndex + (index - 1) * stride] = deconvert(data[i]);
                ansIndex[ansArrayIndex + (index - 1) * stride] = i - stride * strideNum * blockIndex;
            }
            ansArrayIndex += carry * stride;
            ansSize += carry;
        }
        __syncthreads();
    }

    if (ansSize < k){
        int ramindNum = k - ansSize;
#pragma unroll
        for (int i = idx * stride + stride * strideNum * blockIndex + offsetInBlock; i < vectorAlibnLimit; i += stride * WORKERSNUM) {
            bool hasTopk = false;
            if (i < vectorLimit && data[i] == pattern) {
                hasTopk = true;
            }

            gpuCheckWarp(smem, hasTopk, &carry, &index);

            if (carry > 0) {
                int checkTmpIndex = ansArrayIndex + (index - 1) * stride;
                /* guard against running past the answer array: if only one slot is left
                   but two indices fit, we should filter out the larger index */
                if (hasTopk && checkTmpIndex < stride * k * blockIndex + offsetInBlock + stride * k) {
                    ans[checkTmpIndex] = deconvert(pattern);
                    ansIndex[checkTmpIndex] = i - stride * strideNum * blockIndex;
                }
                ramindNum -= carry;
                ansArrayIndex += carry * stride;
                if (ramindNum <= 0) break;
            }
            __syncthreads();
        }
    }
}
/*
This is the old way: one thread collects the numbers. It is very slow, so we no longer use it.
*/
__device__
void collectNumberOld(unsigned int *data, int n, int k, unsigned int pattern, unsigned int *ans, int *indexNum, int stride, int strideNum)
{
    int idy = blockDim.y * blockIdx.y + threadIdx.y;
    int blockIndex = idy / stride;
    int offsetInBlock = idy % stride;
    int cot = 0;
    for (int i = stride * strideNum * blockIndex + offsetInBlock, j = 0; j < strideNum; j++, i += stride) {
        if (data[i] > pattern) {
            ans[cot] = data[i];
            indexNum[cot++] = j;
        }
    }

    /* if cot < k, all the remaining values must equal pattern */
    if (cot < k) {
        for (int i = cot; i < k; ++i) {
            ans[i] = pattern;
        }
        /* collect the remaining indices whose data value equals pattern */
        for (int i = stride * strideNum * blockIndex + offsetInBlock, j = 0; j < strideNum; j++, i += stride) {
            if (data[i] == pattern) {
                indexNum[cot++] = j;
                if (cot == k) break;
            }
...@@ -719,8 +708,12 @@ __device__ void collect_number_old(unsigned int *data, int n, int k, unsigned in ...@@ -719,8 +708,12 @@ __device__ void collect_number_old(unsigned int *data, int n, int k, unsigned in
} }
} }
/*
When k is very large, shared memory cannot hold the candidates, so we use the radix select algorithm instead
*/
template<class T> __global__
void KernelTopKRadixSelect(unsigned int * input, int stride, int strideNum,
                           int blockNum, int k, T minValue, T * output, int* index, int limit)
{ {
/* the idx th thread in one vector */ /* the idx th thread in one vector */
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
...@@ -733,73 +726,71 @@ void KernelTopKRadixSelect(unsigned int * input, int stride, int strideNum, int ...@@ -733,73 +726,71 @@ void KernelTopKRadixSelect(unsigned int * input, int stride, int strideNum, int
if (idy >= stride *blockNum) return; if (idy >= stride *blockNum) return;
int mask_desire = 0; int maskDesire = 0;
unsigned int mask = 0x80000000; unsigned int mask = 0x80000000;
unsigned int desire = 0; unsigned int desire = 0;
__shared__ int pos_count[32 * 32]; __shared__ int posCount[32 * 32];
int tmp_k = k; int tmpK = k;
//if (idx == 0)
//printf("%d %d blockSize: <%d ,%d>\n", idx + blockDim.x*idy,idy, blockDim.x, blockDim.y);
int flag = 1; int flag = 1;
#pragma unroll #pragma unroll
for (int i = 0; i < 32; i++) for (int i = 0; i < 32; i++){
{ /* we need to clean the shared memory every loop */
//we need to clearn the shared memory every loop
pos_count[idx + blockDim.x*(idy % (512 / WORKERSNUM))] = 0; posCount[idx + blockDim.x*(idy % (512 / WORKERSNUM))] = 0;
if (flag) if (flag)
radixCount(input, stride*strideNum*blockNum, pos_count, mask, mask_desire, desire, stride, strideNum, blockNum); radixCount(input, stride*strideNum*blockNum, posCount, mask, maskDesire, desire, stride, strideNum, blockNum);
__syncthreads(); __syncthreads();
int sumCount = 0; int sumCount = 0;
#pragma unroll #pragma unroll
for (int j = 0; j < WORKERSNUM; j++) for (int j = 0; j < WORKERSNUM; j++) {
{ sumCount += posCount[(idy % (512 / WORKERSNUM))*blockDim.x + j];
sumCount += pos_count[(idy % (512 / WORKERSNUM))*blockDim.x + j];
} }
__syncthreads(); __syncthreads();

        if (tmpK < sumCount) {
            /* this bit of the result should be 1 */
            desire = mask ^ desire;
        }
        else {
            /* this bit should be 0: shrink k by the number of larger values */
            tmpK = tmpK - sumCount;
            if (tmpK == 0) {
                desire = (~(maskDesire >> 1)) | desire;
                /* we cannot break here without risking a __syncthreads()
                   deadlock, so we only clear the flag */
                flag = 0;
            }
        }
        maskDesire = mask ^ maskDesire;
        mask = mask >> 1;
    }
    __syncthreads();

    /* the old way to collect the numbers */
    /*
    if (idx == 0)
    {
        unsigned int* uintOutput = new unsigned int;
        int* tmpIndex = new int;
        //*******************something wrong***************************
        cudaMalloc((void **)&uintOutput, sizeof(unsigned int) * k);
        cudaMalloc((void **)&tmpIndex, sizeof(unsigned int) * k);
        //*************************************************************
        collectNumberOld(input, limit, k, desire, uintOutput, tmpIndex, stride, strideNum);
        int blockIndex = idy / stride;
        int offsetInBlock = idy % stride;

        for (int i = stride * k * blockIndex + offsetInBlock, j = 0; j < k; j++, i += stride)
        {
            output[i] = deconvert(uintOutput[j]);
            index[i] = tmpIndex[j];
        }
    }
    __syncthreads();
    */
collectNumber(input, stride, strideNum, limit, desire, output, index, k);
} }
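
/*
What follows is a small illustrative sketch, not part of the original file,
shown only to make the algorithm above easier to follow. The names
floatToSortableUint, sortableUintToFloat and radixSelectKthLargest are made up
for the sketch; the kernel itself relies on radixCount / collectNumber and on
convert2uintV2 / deconvert2floatV2 below, whose exact implementations are not
shown in this diff.
*/

/* order-preserving float <-> unsigned int mapping commonly used for radix
   methods: after the mapping, unsigned comparison gives the same order as
   comparing the original floats */
__device__ __forceinline__
unsigned int floatToSortableUint(float f)
{
    unsigned int u = __float_as_uint(f);
    /* negative floats: flip all bits; non-negative floats: set the sign bit */
    return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

__device__ __forceinline__
float sortableUintToFloat(unsigned int u)
{
    u = (u & 0x80000000u) ? (u ^ 0x80000000u) : ~u;
    return __uint_as_float(u);
}

/* serial bit-wise radix select: returns the k-th largest (1-based) of n
   mapped values by deciding one bit of the answer per iteration; this is the
   strategy that KernelTopKRadixSelect parallelizes with radixCount */
static unsigned int radixSelectKthLargest(const unsigned int * data, int n, int k)
{
    unsigned int desire = 0;        /* bits of the answer decided so far */
    unsigned int maskDesire = 0;    /* which bit positions are decided   */
    for (unsigned int mask = 0x80000000u; mask != 0; mask >>= 1) {
        int count = 0;
        for (int i = 0; i < n; i++) {
            /* value matches the decided prefix and has the current bit set */
            if ((data[i] & maskDesire) == desire && (data[i] & mask))
                count++;
        }
        if (k <= count)
            desire |= mask;         /* the answer lies in the "bit = 1" group */
        else
            k -= count;             /* skip the larger group; bit stays 0     */
        maskDesire |= mask;
    }
    return desire;
}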

/*
@@ -828,13 +819,14 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
    for (int i = dimRDI + 1; i < a->order; i++)
        blockNum *= a->dimSizeRDI[i];

    int workerNum = blockNum < 16 ? 64 : 32;
    /* adjust the thread count to the size of k so that the shared memory is large enough */
    if (k < 6) workerNum = 512;
    else if (k < 11) workerNum = 256;
    else if (k < 22) workerNum = 128;
    else if (k < 44) workerNum = 64;
    else workerNum = 32;
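
    /*
    rough sizing behind the table above: a block needs about
    (threadsPerBlock + 1) * k * (a->unitSize + sizeof(int)) bytes of shared
    memory (see the SHARED_MEMORY_SIZE check below). Assuming 4-byte values and
    a 48KB budget, 512 threads with k = 5 need roughly 513 * 5 * 8 ≈ 20KB,
    while 512 threads with k = 32 would need ≈ 128KB; that is why workerNum
    shrinks as k grows
    */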

    int cudaGrids[3];
    int cudaBlocks[3];
@@ -842,22 +834,6 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
                          workerNum, stride * blockNum, MAX_INT,
                          cudaGrids, cudaBlocks);
    int devIDBackup = 0;
    ProtectCudaDev(a->devID, devIDBackup);
@@ -866,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
        cudaBlocks[1] = 1;
    if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) {
        if (a->dataType == DEFAULT_DTYPE) {
            KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
                                ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
                                 (DTYPE*)b->data, (int*)index->data);
        }
@@ -896,17 +872,14 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
        GDevs.GetCudaThread2D(a->mem->devID,
                              workerNum, stride * blockNum, MAX_INT,
                              cudaGrids, cudaBlocks);
        if (a->dataType == DEFAULT_DTYPE) {
            unsigned int* goutput = (unsigned int *)a->data;
            /* the two ways of converting the data cost almost the same time */
            convert2uintV2 <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>> ((float*)a->data, goutput, stride, strideNumA, blockNum, strideNumA * blockNum * stride);
            //convert2uintV2 <<<dim3(1, stride * blockNum), dim3(512, 1)>>> ((float*)a->data, goutput, stride, strideNumA, blockNum, strideNumA * blockNum * stride);
            KernelTopKRadixSelect<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>> (goutput, stride, strideNumA, blockNum, k, DTYPE_MIN, (DTYPE *)b->data, (int *)index->data, stride * strideNumA * blockNum);
            deconvert2floatV2 <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>> ((unsigned int *)a->data, (float *)goutput, stride, strideNumA, blockNum, strideNumA * blockNum * stride);
        }
    }
...
@@ -155,7 +155,11 @@ void KernelSoftmaxComputeTensor(__half * x, __half * max, __half * sum, __half *
    }
}

/*
use PTX code to broadcast float data
*/
__device__ __forceinline__
float broadcast(float input)
{
    float output;
    asm(
@@ -167,28 +171,28 @@ __device__ __forceinline__ float broadCast(float input)
    return output;
}
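
/*
A minimal illustrative alternative (not part of the original file): on CUDA 9
and later the same lane-0 broadcast can be written with the warp shuffle
intrinsic instead of inline PTX. The name broadcastWithShfl is made up for this
sketch and assumes the full 32-lane warp is active.
*/
__device__ __forceinline__
float broadcastWithShfl(float input)
{
    /* copy the value held by lane 0 to every lane of the warp */
    return __shfl_sync(0xffffffff, input, 0);
}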

/*
use warp-level broadcast to optimize the softmax computation
*/
__global__
void KernelSoftmaxComputeTensorUseBroadcast(DTYPE * input, DTYPE * max, DTYPE * sum, DTYPE * output,
                                            int stride, int strideNum, int blockNum)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    int j = blockDim.y * blockIdx.y + threadIdx.y;
    int i2 = j % stride;
    int blockSize = stride * strideNum;
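
    /*
    each thread computes one output element as exp(x - max) / sum; only the
    first lane of each warp loads max[j] and sum[j] from global memory, and the
    values are then broadcast to the other lanes of the warp
    */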
    if (j < stride * blockNum) {
        DTYPE sumData, maxData;
        if (i % 32 == 0) {
            sumData = sum[j];
            maxData = max[j];
        }
        sumData = broadcast(sumData);
        maxData = broadcast(maxData);
        if (i < strideNum) {
            int offset = int(j / stride) * blockSize + i * stride + i2;
            output[offset] = exp(input[offset] - maxData) / sumData;
        }
@@ -223,20 +227,18 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s
    int cudaGridSize[3];
    int cudaBlockSize[3];

    if (leadDim != 0 || dimensionSize <= 10) {
        /* allocate the thread layout for the old kernel */
        GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
    }
    else {
        /* allocate the thread layout for the broadcast kernel */
        GDevs.GetCudaThread2D(x->devID, dimensionSize, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
        if (cudaBlockSize[0] < 32) {
            /* use at least one warp in the x dimension, since the kernel broadcasts within a warp */
            cudaBlockSize[0] = 32;
            if (cudaBlockSize[1] > 32) {
                cudaGridSize[1] = int(ceil(float(stride * blockNum) / 32));
                cudaBlockSize[1] = 32;
            }
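
            /*
            for example, if GetCudaThread2D proposes a 16 x 64 thread block,
            it becomes a 32 x 32 block here, and cudaGridSize[1] is recomputed
            so that the y dimension still covers all stride * blockNum rows
            */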
@@ -246,23 +248,21 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s
    ProtectCudaDev(x->devID, devIDBackup);

    if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
        if (leadDim != 0 || dimensionSize <= 10) {
            KernelSoftmaxComputeTensor <<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
                                        ((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
                                         stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
        }
        else {
            KernelSoftmaxComputeTensorUseBroadcast <<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
                                                    ((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
                                                     stride, dimensionSize, blockNum);
        }
    }
    else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
        KernelSoftmaxComputeTensor <<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
                                    ((__half*)x->data, (__half*)max->data, (__half*)sum->data, (__half*)y->data,
                                     stride, dimensionSize, blockNum);
    }
    else {
        ShowNTErrors("TODO!");
...