Commit 467c2ed7 by 张裕浩

Add reduceMin operation using #define

parent c9ef15f8
@@ -28,6 +28,8 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
get the max value of the items along a dimension of the tensor
@@ -35,129 +37,147 @@ get the max value of the items along a dimension of the tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
*/
void _ReduceMax(const XTensor * input, XTensor * output, int dim)
{
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)),
"This code must be run on the same device!");
CheckNTErrors((input && output), "Empty input or output tensors!");
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!");
CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!");
CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!");
CheckNTErrors(dim < input->order, "Wrong dimension!");
for(int i = 0; i < input->order; i++){
if(i < dim){
CheckNTErrors((input->dimSize[i] == output->dimSize[i]),
"Unmatched tensors!");
}
else if(i > dim){
CheckNTErrors((input->dimSize[i] == output->dimSize[i - 1]),
"Unmatched tensors!");
}
}
if(input->devID >= 0){
#ifdef USE_CUDA
_CudaReduceMax(input, output, dim);
#endif
}
else{
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");
int stride = 1;
int strideNum = input->dimSize[dim];
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
if (i > dim)
stride *= input->dimSize[i];
else if (i < dim)
blockNum *= input->dimSize[i];
}
blockSize = stride * strideNum;
#define _REDUCE_CPU_FUNCTION(_funcCPUName, _vectorOp, _reduceOp) \
void _funcCPUName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)), \
"This code must be run on the same device!"); \
CheckNTErrors((input && output), "Empty input or output tensors!"); \
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!"); \
CheckNTErrors((input->order > dim && dim >= 0), "Illegal dimension to reduce!"); \
CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!"); \
\
CheckNTErrors(dim < input->order, "Wrong dimension!"); \
\
for (int i = 0; i < input->order; i++) { \
\
if (i < dim) { \
\
CheckNTErrors((input->dimSize[i] == output->dimSize[i]), \
"Unmatched tensors!"); \
} \
else if (i > dim) { \
CheckNTErrors((input->dimSize[i] == output->dimSize[i - 1]), \
"Unmatched tensors!"); \
} \
} \
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!"); \
int stride = 1; \
int strideNum = input->dimSize[dim]; \
int blockSize = 1; \
int blockNum = 1; \
for (int i = 0; i < input->order; i++) { \
if (i > dim) \
stride *= input->dimSize[i]; \
else if (i < dim) \
blockNum *= input->dimSize[i]; \
} \
blockSize = stride * strideNum; \
\
\
if(input->dimSize[input->order - 1] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSize[input->order - 1] >= 32){ \
int vecBufLength = 32 / sizeof(DTYPE); \
\
if (dim == input->order - 1) { \
/*data is contiguous in dim 0 */ \
for (int i = 0; i < blockNum; i++) { \
DTYPE * ip = (DTYPE*)input->data + blockSize * i; \
DTYPE * op = (DTYPE*)output->data + i; \
VectorBuffer vecBuf[4]; \
for (int j = 0; j < 4; j++) { \
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip)+j * vecBufLength); \
} \
for (int j = 1; j < strideNum / 32; j++) { \
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength); \
vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \
vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \
vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \
vecBuf[3] = vecBuf[3]._vectorOp(VectorBuffer::loadu(ptr + 3 * vecBufLength)); \
} \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[1]); \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[2]); \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[3]); \
DTYPE maxN = vecBuf[0][0]; \
for (int k = 1; k < vecBufLength; k++) { \
maxN = _reduceOp(maxN, vecBuf[0][k]); \
} \
*op = maxN; \
} \
\
} \
else { \
/* data is separated */ \
for(int i = 0; i < blockNum; i++){ \
for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){ \
DTYPE * ip = (DTYPE*)input->data + blockSize * i; \
DTYPE * op = (DTYPE*)output->data + stride * i; \
VectorBuffer vecBuf[4]; \
for(int k = 0; k < 4; k++){ \
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE)); \
\
} \
for(int k = 1; k < strideNum; k++){ \
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength; \
vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \
vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \
vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \
vecBuf[3] = vecBuf[3]._vectorOp(VectorBuffer::loadu(ptr + 3 * vecBufLength)); \
} \
for(int k = 0; k < 4; k++){ \
for(int l = 0; l < vecBufLength; l++) \
*(op + j * 32 + 8 * k + l) = vecBuf[k][l]; \
} \
} \
} \
} \
}/* run vector buffer */ \
else{ \
for(int k = 0; k < blockNum; k++){ \
DTYPE * ip = (DTYPE*)input->data + blockSize * k; \
DTYPE * op = (DTYPE*)output->data + stride * k; \
for(int i = 0; i < stride; i++){ \
DTYPE * ipe = ip + blockSize; \
DTYPE tmpData = *(ip + i); \
for(DTYPE * ipb = ip + i + stride; ipb < ipe; ipb += stride){ \
DTYPE v = *ipb; \
tmpData = _reduceOp(tmpData, v); \
} \
*(op + i) = tmpData; \
} \
} \
} \
}
if(input->dimSize[input->order - 1] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSize[input->order - 1] >= 32){
int vecBufLength = 32 / sizeof(DTYPE);
if (dim == input->order - 1) {
//data is contiguous in dim 0
for (int i = 0; i < blockNum; i++) {
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + i;
VectorBuffer vecBuf[4];
for (int j = 0; j < 4; j++) {
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip)+j * vecBufLength);
}
for (int j = 1; j < strideNum / 32; j++) {
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
vecBuf[0] = vecBuf[0].maxData(vecBuf[1]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[2]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[3]);
DTYPE maxN = DTYPE_MIN;
for (int k = 0; k < vecBufLength; k++) {
maxN = MAX(maxN, vecBuf[0][k]);
}
*op = maxN;
}
}
else {
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE));
}
for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
for(int k = 0; k < 4; k++){
for(int l = 0; l < vecBufLength; l++)
*(op + j * 32 + 8 * k + l) = vecBuf[k][l];
}
}
}
}
}//run vector buffer
else{
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){
//#if defined(USE_BLAS)
// *(op + i) = *(ip + i + (int)(stride * IAMAX(strideNum, ip + i, stride)));
//#else
DTYPE max = DTYPE_MIN;
DTYPE * ipe = ip + blockSize;
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE v = *ipb;
if(max < v)
max = v;
}
*(op + i) = max;
//#endif
}
}
}
}
_REDUCE_CPU_FUNCTION(reduceMaxCPU, maxData, MAX)
_REDUCE_CPU_FUNCTION(reduceMinCPU, minData, MIN)
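To see the code-generation idea in isolation: the macro stamps out one complete CPU reduction function per (vector op, scalar op) pair, so reduceMaxCPU and reduceMinCPU share every line except the two operations. A minimal, self-contained sketch of the same technique follows; the names GEN_REDUCE, reduce_max and reduce_min are illustrative and not part of NiuTrans.Tensor.

#include <algorithm>
#include <cstdio>

/* generate a full reduction function from a scalar combine operation,
   mirroring how _REDUCE_CPU_FUNCTION pairs a VectorBuffer op with MAX/MIN */
#define GEN_REDUCE(name, combine)                    \
float name(const float * data, int n)                \
{                                                    \
    float acc = data[0];                             \
    for (int i = 1; i < n; i++)                      \
        acc = combine(acc, data[i]);                 \
    return acc;                                      \
}

GEN_REDUCE(reduce_max, std::max)
GEN_REDUCE(reduce_min, std::min)

int main()
{
    float a[5] = {3.0F, -1.0F, 7.0F, 0.5F, 2.0F};
    printf("max = %g, min = %g\n", reduce_max(a, 5), reduce_min(a, 5));
    return 0;
}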
#ifdef USE_CUDA
#define _REDUCE_FUNCTION(_funcName, _cudaFuncName, _cpuFuncName)                   \
void _funcName(const XTensor * input, XTensor * output, int dim)                   \
{                                                                                  \
    if(input->devID >= 0){                                                         \
        _cudaFuncName(input, output, dim);                                         \
    }                                                                              \
    else{                                                                          \
        /* fall back to the matching CPU implementation */                         \
        _cpuFuncName(input, output, dim);                                          \
    }                                                                              \
}
_REDUCE_FUNCTION(_ReduceMax, _CudaReduceMax, reduceMaxCPU)
_REDUCE_FUNCTION(_ReduceMin, _CudaReduceMin, reduceMinCPU)
#else
#define _REDUCE_FUNCTION(_funcName, reduceNameCPU) \
void _funcName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors((input->devID < 0), "This code must be run on the CPU!"); \
reduceNameCPU(input, output, dim); \
}
_REDUCE_FUNCTION(_ReduceMax, reduceMaxCPU)
_REDUCE_FUNCTION(_ReduceMin, reduceMinCPU)
#endif
/*
get the max value of the items along a dimension of the tensor (return an XTensor structure).
make a new tensor to keep the result and return it
@@ -165,34 +185,38 @@ make a new tensor to keep the result and return it
>> dim - the dimension where the reduction is performed on
<< return - the max value of the items along a dimension of the tensor
*/
XTensor ReduceMax(const XTensor &input, int dim)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1;
int * dimSize = new int[order];
for(int i = 0; i < order; i++){
if(i < dim)
dimSize[i] = input.dimSize[i];
else if(i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMPFlag();
/* call _ReduceMax function */
_ReduceMax(&input, &output, dim);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
/* destroy variables */
delete[] dimSize;
return output;
#define REDUCE_FUNCTION(funcName, funcOp) \
XTensor funcName(const XTensor & input, int dim) \
{ \
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); \
\
int order = input.order - 1; \
int * dimSize = new int[order]; \
for(int i = 0; i < order; i++){ \
if(i < dim) \
dimSize[i] = input.dimSize[i]; \
else if(i >= dim) \
dimSize[i] = input.dimSize[i + 1]; \
} \
\
float dr = (!input.isSparse) ? 1.0F : input.denseRatio; \
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem); \
output.SetTMPFlag(); \
\
/* call the underlying reduction function */                               \
funcOp(&input, &output, dim); \
\
/* tensor connection */ \
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX); \
XLink::AddParamToHeadInt(&output, dim); \
\
/* destroy variables */ \
delete[] dimSize; \
\
return output; \
}
REDUCE_FUNCTION(ReduceMax, _ReduceMax)
REDUCE_FUNCTION(ReduceMin, _ReduceMin)
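As a usage sketch (not part of the commit): assuming a float XTensor x of order 2 has already been created and filled elsewhere, the two generated entry points are called exactly like the existing ReduceMax was:

/* reduce along dimension 1: keep the largest / smallest item of each row */
XTensor rowMax = ReduceMax(x, 1);
XTensor rowMin = ReduceMin(x, 1);

Each call builds the order-reduced output tensor, dispatches to the GPU or CPU implementation, and links the result for backward computation via XLink.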
} // namespace nts(NiuTrans.Tensor)
@@ -33,67 +33,75 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
use PTX code to reduce float data
*/
__device__ __forceinline__
float shflDownReduceMax(float input)
{
float output;
asm volatile(
"{"
".reg .f32 r0;"
".reg .pred p;"
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
"setp.lt.f32 p, %1, r0; "
"@p mov.f32 %1,r0;"
"mov.f32 %0,%1;"
"}"
: "=f"(output) : "f"(input));
return output;
#define SHLFUNCFLOAT(funcName, reducePTXOp) \
__device__ __forceinline__ \
float funcName(float input) \
{ \
float output; \
asm volatile( \
"{" \
".reg .f32 r0;" \
".reg .pred p;" \
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;" \
"setp."#reducePTXOp".f32 p, %1, r0; " \
"@p mov.f32 %1,r0;" \
"mov.f32 %0,%1;" \
"}" \
: "=f"(output) : "f"(input)); \
return output; \
}
SHLFUNCFLOAT(shflDownReduceMax, lt)
SHLFUNCFLOAT(shflDownReduceMin, gt)
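For readers unfamiliar with inline PTX: the generated warp reduction halves the shuffle distance each step until lane 0 holds the warp-wide extremum. A functionally similar sketch using the CUDA intrinsic instead of PTX (illustrative only, not the library's code):

__device__ __forceinline__
float warpReduce(float value, bool isMax)
{
    /* shuffle down by 16, 8, 4, 2, 1 lanes and keep the better value each time */
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other = __shfl_down_sync(0xffffffff, value, offset);
        value = isMax ? fmaxf(value, other) : fminf(value, other);
    }
    return value;   /* only lane 0 is guaranteed to hold the final result */
}

The PTX version does the same work, but uses setp.lt / setp.gt plus a predicated mov instead of fmaxf / fminf, which is why a single #reducePTXOp token is enough to flip the whole routine from max to min.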
/*
use PTX code to reduce int data
*/
__device__ __forceinline__
int shflDownReduceMax(int input)
{
int output;
asm volatile(
"{"
".reg .s32 r0;"
".reg .pred p;"
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
"setp.lt.s32 p, %1, r0; "
"@p mov.s32 %1,r0;"
"mov.s32 %0,%1;"
"}"
: "=r"(output) : "r"(input));
return output;
#define SHLFUNCINT(funcName, reducePTXOp) \
__device__ __forceinline__ \
int funcName(int input) \
{ \
int output; \
asm volatile( \
"{" \
".reg .s32 r0;" \
".reg .pred p;" \
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;" \
"setp."#reducePTXOp".s32 p, %1, r0; " \
"@p mov.s32 %1,r0;" \
"mov.s32 %0,%1;" \
"}" \
: "=r"(output) : "r"(input)); \
return output; \
}
SHLFUNCINT(shflDownReduceMax, lt)
SHLFUNCINT(shflDownReduceMin, gt)
/*
reduce a tensor to another that keeps the max value along a dimension - slow version
Given a block of data, we go over each dimension i in the stride and we have
@@ -108,48 +116,52 @@ crossing of the i-th column and the j-th row.
>> blockSize - size of the block (i.e., stride * strideNum)
>> blockNum - how many blocks
*/
__global__
void KernelReduceMax(DTYPE * input, DTYPE * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK * MIN_CUDA_SHARED_MEM_COL_SIZE/2];
int idx = threadIdx.x * blockDim.y + threadIdx.y;
unsigned int i = blockIdx.x*blockDim.x + threadIdx.x;
unsigned int j = blockIdx.y*blockDim.y + threadIdx.y;
if(i >= stride * blockNum)
return;
__syncthreads();
int k = i / stride;
int iOffset = i % stride;
DTYPE value = (i < stride * blockNum && j < strideNum) ?
input[blockSize * k + stride * j + iOffset] : FLOAT_MIN;
/* load data into the shared mem */
iData[threadIdx.x * blockDim.y + threadIdx.y] = value;
__syncthreads();
/* do reduction in shared mem */
for (unsigned int s = blockDim.y/2; s > 0; s >>= 1){
if(threadIdx.y < s && iData[idx] < iData[idx + s]){
iData[idx] = iData[idx + s];
}
__syncthreads();
}
/* write result for this block to the output array */
if (threadIdx.y == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = iData[threadIdx.x * blockDim.y];
#define KERNELREDUCEFUN3(funName, opName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output, \
int stride, int strideNum, int reducedStrideNum, \
int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK * MIN_CUDA_SHARED_MEM_COL_SIZE/2]; \
\
int idx = threadIdx.x * blockDim.y + threadIdx.y; \
unsigned int i = blockIdx.x*blockDim.x + threadIdx.x; \
unsigned int j = blockIdx.y*blockDim.y + threadIdx.y; \
\
if(i >= stride * blockNum) \
return; \
\
__syncthreads(); \
\
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE value = (i < stride * blockNum && j < strideNum) ? \
input[blockSize * k + stride * j + iOffset] : initData; \
\
/* load data into the shared mem */ \
iData[threadIdx.x * blockDim.y + threadIdx.y] = value; \
\
__syncthreads(); \
\
/* do reduction in shared mem */ \
for (unsigned int s = blockDim.y/2; s > 0; s >>= 1){ \
if(threadIdx.y < s){ \
iData[idx] = opName(iData[idx + s], iData[idx]); \
} \
\
__syncthreads(); \
} \
\
/* write result for this block to the output array */ \
if (threadIdx.y == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = iData[threadIdx.x * blockDim.y]; \
\
}
KERNELREDUCEFUN3(KernelReduceMax, MAX, FLOAT_MIN)
KERNELREDUCEFUN3(KernelReduceMin, MIN, MAX_FLOAT)
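The kernel that KERNELREDUCEFUN3 expands to is the classic shared-memory tree reduction. Stripped of the stride / block bookkeeping, the pattern looks roughly like the sketch below (illustrative only; one block reduces one contiguous row of n floats, launched with a power-of-two block size >= n and blockDim.x * sizeof(float) bytes of shared memory):

#include <cfloat>

__global__ void rowMaxKernel(const float * input, float * output, int n)
{
    extern __shared__ float sdata[];
    int tid = threadIdx.x;

    /* each thread loads one element, or the identity value if out of range */
    sdata[tid] = (tid < n) ? input[blockIdx.x * n + tid] : -FLT_MAX;
    __syncthreads();

    /* halve the number of active threads each step */
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s)
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }

    if (tid == 0)
        output[blockIdx.x] = sdata[0];
}

Swapping fmaxf for fminf and -FLT_MAX for FLT_MAX gives the min version, which is exactly the substitution the opName / initData macro parameters perform.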
/*
reduce a tensor to another that keeps the max value along a dimension - slow version
Given a block of data, we go over each dimension i in the stride and we have
@@ -231,48 +243,52 @@ reduce a tensor to another that keeps the max value along a dimension - fast version
>> blockSize - size of the block (i.e., stride * strideNum)
>> blockNum - how many blocks
*/
template <unsigned int goodSize> __global__
void KernelReduceMaxFast(DTYPE * input, DTYPE * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK];
unsigned int tid = threadIdx.y;
unsigned int j = blockIdx.y * (blockDim.y * 2) + threadIdx.y;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if(i >= stride * blockNum)
return;
__syncthreads();
/* first level reduction */
int k = i / stride;
int iOffset = i % stride;
DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize;
DTYPE value = j < strideNum ? inputData[j * stride + iOffset] : FLOAT_MIN;
DTYPE value2 = j + blockDim.y < strideNum ? inputData[(j + blockDim.y) * stride + iOffset]: FLOAT_MIN;
value = MAX(value, value2);
value = shflDownReduceMax(value);
if ((tid & 0x1f) == 0)
data[tid / 32] = value;
__syncthreads();
if (tid < 32) {
if (tid < blockDim.y / 32)
value = data[tid];
else
value = FLOAT_MIN;
value = shflDownReduceMax(value);
if (tid == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value;
}
#define KERNELREDUCEFUN4(funName, opName, opFuncName, initData) \
template <unsigned int goodSize> __global__ \
void funName(DTYPE * input, DTYPE * output, \
int stride, int strideNum, int reducedStrideNum, \
int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK]; \
\
unsigned int tid = threadIdx.y; \
unsigned int j = blockIdx.y * (blockDim.y * 2) + threadIdx.y; \
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
\
if(i >= stride * blockNum) \
return; \
\
__syncthreads(); \
\
/* first level reduction */ \
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE * data = iData + threadIdx.x * blockDim.y; \
DTYPE * inputData = input + k * blockSize; \
DTYPE value = j < strideNum ? inputData[j * stride + iOffset] : initData; \
DTYPE value2 = j + blockDim.y < strideNum ? inputData[(j + blockDim.y) * stride + iOffset]: initData; \
\
value = opName(value, value2); \
value = opFuncName(value); \
if ((tid & 0x1f) == 0) \
data[tid / 32] = value; \
__syncthreads(); \
\
if (tid < 32) { \
if (tid < blockDim.y / 32) \
value = data[tid]; \
else \
value = initData; \
value = opFuncName(value); \
if (tid == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value; \
} \
}
KERNELREDUCEFUN4(KernelReduceMaxFast, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN4(KernelReduceMinFast, MIN, shflDownReduceMin, MAX_FLOAT)
/*
reduce a tensor to another that keeps the max value along a dimension - fast version
>> input - the input array (representing a tensor)
@@ -372,14 +388,12 @@ void KernelReduceMaxSimpleFast(DTYPE * input, DTYPE * output,
int stride4 = stride3 + stride;
for(int k = 0; k < blockSize; k += stride4){
DTYPE m = MAX(MAX(ip[k], ip[k + stride]), MAX(ip[k + stride2], ip[k + stride3]));
if(max < m)
max = m;
max = MAX(max, m);
}
}
else{
for(int k = 0; k < blockSize; k += stride)
if(max < ip[k])
max = ip[k];
for (int k = 0; k < blockSize; k += stride)
max = MAX(max, ip[k]);
}
__syncthreads();
@@ -429,66 +443,75 @@ inline void adjustThreadForUseWarpOptimization(dim3& blocks, dim3& threads)
/*
In some cases, we use fewer blocks to improve efficiency
*/
__global__
void KernelReduceMaxOpLessBlocks(DTYPE * input, DTYPE * output, int strideNum, int blockNum)
{
int idx = threadIdx.x % 32;
int idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int startIndex = idy * strideNum;
DTYPE threadMax = FLOAT_MIN;
for (int i = idx; i < strideNum; i += 32) {
threadMax = max(input[startIndex + i], threadMax);
}
threadMax = shflDownReduceMax(threadMax);
if (idx == 0)
output[idy] = threadMax;
#define KERNELREDUCEFUN2(funName, opName, opFuncName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output, int strideNum, int blockNum) \
{ \
int idx = threadIdx.x % 32; \
int idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32; \
\
int startIndex = idy * strideNum; \
DTYPE threadMax = initData; \
for (int i = idx; i < strideNum; i += 32) { \
threadMax = opName(input[startIndex + i], threadMax); \
} \
threadMax = opFuncName(threadMax); \
if (idx == 0) \
output[idy] = threadMax; \
}
KERNELREDUCEFUN2(KernelReduceMaxOpLessBlocks, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN2(KernelReduceMinOpLessBlocks, MIN, shflDownReduceMin, MAX_FLOAT)
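The "less blocks" variant gives each reduced row to a single 32-thread warp: every lane strides over the row, then one warp shuffle finishes the reduction, so no shared memory or __syncthreads is needed. A condensed sketch of that assignment (illustrative; assumes the launch supplies exactly 32 threads per row and <cfloat> for FLT_MAX):

__global__ void warpPerRowMax(const float * input, float * output, int strideNum)
{
    int lane = threadIdx.x % 32;                              /* lane inside the warp */
    int row  = (blockIdx.x * blockDim.x + threadIdx.x) / 32;  /* one warp per row     */

    float v = -FLT_MAX;
    for (int i = lane; i < strideNum; i += 32)                /* strided pass over the row */
        v = fmaxf(v, input[row * strideNum + i]);

    for (int offset = 16; offset > 0; offset >>= 1)           /* warp-wide max */
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));

    if (lane == 0)
        output[row] = v;
}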
/*
use PTX code to perform the reduction
*/
__global__
void KernelReduceMaxOp(DTYPE * input, DTYPE * output,int stride, int strideNum,
int reducedStrideNum,int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32];
unsigned int tid = threadIdx.y;
unsigned int j = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= stride * blockNum)
return;
/* first level reduction */
int k = i / stride;
int iOffset = i % stride;
DTYPE threadMax = FLOAT_MIN;
DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize;
for (int it = j; it < strideNum; it += blockDim.y){
threadMax = max(inputData[it * stride + iOffset], threadMax);
}
__syncthreads();
threadMax = shflDownReduceMax(threadMax);
if ((tid & 0x1f) == 0)
data[tid / 32] = threadMax;
__syncthreads();
/* use one warp to reduce remaining data */
if (tid < 32){
if (tid < blockDim.y / 32)
threadMax = data[tid];
else threadMax = FLOAT_MIN;
threadMax = shflDownReduceMax(threadMax);
if (tid == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadMax;
}
#define KERNELREDUCEFUN1(funName, opName, opFuncName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output,int stride, int strideNum, \
int reducedStrideNum,int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32]; \
\
unsigned int tid = threadIdx.y; \
unsigned int j = blockIdx.y * blockDim.y + threadIdx.y; \
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
if (i >= stride * blockNum) \
return; \
\
/* first level reduction */ \
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE threadMax = initData; \
\
DTYPE * data = iData + threadIdx.x * blockDim.y; \
DTYPE * inputData = input + k * blockSize; \
for (int it = j; it < strideNum; it += blockDim.y){ \
threadMax = opName(inputData[it * stride + iOffset], threadMax); \
} \
\
__syncthreads(); \
threadMax = opFuncName(threadMax); \
if ((tid & 0x1f) == 0) \
data[tid / 32] = threadMax; \
\
__syncthreads(); \
/* use one warp to reduce remaining data */ \
if (tid < 32){ \
if (tid < blockDim.y / 32) \
threadMax = data[tid]; \
else threadMax = initData; \
threadMax = opFuncName(threadMax); \
if (tid == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadMax; \
} \
}
KERNELREDUCEFUN1(KernelReduceMaxOp, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN1(KernelReduceMinOp, MIN, shflDownReduceMin, MAX_FLOAT)
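KERNELREDUCEFUN1 combines both ideas: each warp reduces its share with shuffles, per-warp results are parked in a small shared-memory array, and the first warp then shuffles over those to produce the block result. In outline (illustrative; one block reduces one row of n floats, block size a multiple of 32 up to 1024):

#include <cfloat>

__global__ void blockMaxKernel(const float * input, float * output, int n)
{
    __shared__ float warpResult[32];                 /* one slot per warp */
    int tid  = threadIdx.x;
    int lane = tid & 0x1f;

    float v = -FLT_MAX;
    for (int i = tid; i < n; i += blockDim.x)        /* block-strided pass over the row */
        v = fmaxf(v, input[blockIdx.x * n + i]);

    for (int offset = 16; offset > 0; offset >>= 1)  /* reduce inside each warp */
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
    if (lane == 0)
        warpResult[tid >> 5] = v;
    __syncthreads();

    if (tid < 32) {                                  /* first warp reduces the per-warp results */
        v = (tid < (blockDim.x + 31) / 32) ? warpResult[tid] : -FLT_MAX;
        for (int offset = 16; offset > 0; offset >>= 1)
            v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
        if (tid == 0)
            output[blockIdx.x] = v;
    }
}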
/*
get the max-valued items along a dimension of the tensor (cuda version).
For a 1-dimensional data array a,
@@ -497,202 +520,207 @@ sum_i = max_{0<=j<strideNum} input_{i,j}
>> output - the output tensor
>> dim - which dimension to reduce
*/
void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
{
CheckNTErrors(input && output, "Empty input or output tensors!");
CheckNTErrors(input->order == output->order + 1, "Incorrect tensor sizes!");
CheckNTErrors(input->order > dim && dim >=0, "Illegal dimension to reduce!");
CheckNTErrors(input->dataType == output->dataType, "Unmatched data types!");
for(int i = 0; i < input->order; i++){
if(i < dim){
CheckNTErrors(input->dimSize[i] == output->dimSize[i], "Unmatched tensors!");
}
else if(i > dim){
CheckNTErrors(input->dimSize[i] == output->dimSize[i - 1], "Unmatched tensors!");
}
}
int cudaGridSize[3];
int cudaBlockSize[3];
int iter = 0;
int stride = 1;
int strideNum = input->dimSize[dim];
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
if (i < dim)
blockNum *= input->dimSize[i];
else if (i > dim)
stride *= input->dimSize[i];
}
blockSize = stride * strideNum;
int devID = input->devID;
XMem * mem = input->mem;
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
int bufSize = sizeof(DTYPE) * cudaGridSize[0] * stride * blockNum * 2;
DTYPE * buf = mem != NULL ? (DTYPE*)mem->AllocBuf(mem->devID, bufSize) : (DTYPE*)XMemAlloc(input->devID, bufSize);
DTYPE * buf1 = buf;
DTYPE * buf2 = buf + cudaGridSize[0] * stride * blockNum;
int devIDBackup;
ProtectCudaDev(input->devID, devIDBackup);
if (stride == 1 && blockNum >= 10) {
dim3 grids;
dim3 blocks;
continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum);
if (blocks.y >= 128) {
KernelReduceMaxOp <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum);
}
else {
if (blockNum % 4 != 0) blockNum = (int)(blockNum / 4) + 1;
else blockNum = blockNum / 4;
KernelReduceMaxOpLessBlocks <<<blockNum, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum);
}
}
else {
do {
if (input->dataType == DEFAULT_DTYPE) {
DTYPE * iData = NULL;
DTYPE * oData = NULL;
if (iter == 0) {
iData = (DTYPE*)input->data;
oData = buf1;
}
else if (iter % 2 == 1) {
iData = buf1;
oData = buf2;
}
else {
iData = buf2;
oData = buf1;
}
/* unroll the reduction procedure. The code is messy but it is faster. */
if (strideNum < 32) {
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
}
else if (input->dataType == X_FLOAT16) {
__half * buf1ft16 = (__half *)buf1;
__half * buf2ft16 = (__half *)buf2;
__half * iData = NULL;
__half * oData = NULL;
if (iter == 0) {
iData = (__half*)input->data;
oData = buf1ft16;
}
else if (iter % 2 == 1) {
iData = buf1ft16;
oData = buf2ft16;
}
else {
iData = buf2ft16;
oData = buf1ft16;
}
/* unroll the reduction procedure. The code is messy but it is faster. */
if (strideNum < 32) {
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
}
strideNum = cudaGridSize[0];
blockSize = cudaGridSize[0];
iter++;
} while (strideNum > 1);
}
#define _CUDAREDUCE(_funcName, _reduceFunc1, _reduceFunc2, _reduceFunc3, _reduceFun4) \
void _funcName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors(input && output, "Empty input or output tensors!"); \
CheckNTErrors(input->order == output->order + 1, "Incorrect tensor sizes!"); \
CheckNTErrors(input->order > dim && dim >=0, "Illegal dimension to reduce!"); \
CheckNTErrors(input->dataType == output->dataType, "Unmatched data types!"); \
\
for(int i = 0; i < input->order; i++){ \
if(i < dim){ \
CheckNTErrors(input->dimSize[i] == output->dimSize[i], "Unmatched tensors!"); \
} \
else if(i > dim){ \
CheckNTErrors(input->dimSize[i] == output->dimSize[i - 1], "Unmatched tensors!"); \
} \
} \
\
int cudaGridSize[3]; \
int cudaBlockSize[3]; \
int iter = 0; \
int stride = 1; \
int strideNum = input->dimSize[dim]; \
int blockSize = 1; \
int blockNum = 1; \
\
for (int i = 0; i < input->order; i++) { \
if (i < dim) \
blockNum *= input->dimSize[i]; \
else if (i > dim) \
stride *= input->dimSize[i]; \
} \
blockSize = stride * strideNum; \
\
int devID = input->devID; \
XMem * mem = input->mem; \
\
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
\
int bufSize = sizeof(DTYPE) * cudaGridSize[0] * stride * blockNum * 2; \
DTYPE * buf = mem != NULL ? (DTYPE*)mem->AllocBuf(mem->devID, bufSize) : (DTYPE*)XMemAlloc(input->devID, bufSize); \
DTYPE * buf1 = buf; \
DTYPE * buf2 = buf + cudaGridSize[0] * stride * blockNum; \
\
int devIDBackup; \
ProtectCudaDev(input->devID, devIDBackup); \
\
if (stride == 1 && blockNum >= 10) { \
dim3 grids; \
dim3 blocks; \
continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum); \
if (blocks.y >= 128) { \
_reduceFunc1 <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum); \
} \
else { \
if (blockNum % 4 != 0) blockNum = (int)(blockNum / 4) + 1; \
else blockNum = blockNum / 4; \
_reduceFunc2 <<<blockNum, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum); \
} \
} \
else { \
do { \
if (input->dataType == DEFAULT_DTYPE) { \
DTYPE * iData = NULL; \
DTYPE * oData = NULL; \
if (iter == 0) { \
iData = (DTYPE*)input->data; \
oData = buf1; \
} \
else if (iter % 2 == 1) { \
iData = buf1; \
oData = buf2; \
} \
else { \
iData = buf2; \
oData = buf1; \
} \
\
/* unroll the reduction procedure. The code is messy but it is faster. */ \
if (strideNum < 32) { \
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
_reduceFunc3 <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 128) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 256) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 512) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
} \
else if (input->dataType == X_FLOAT16) { \
__half * buf1ft16 = (__half *)buf1; \
__half * buf2ft16 = (__half *)buf2; \
__half * iData = NULL; \
__half * oData = NULL; \
if (iter == 0) { \
iData = (__half*)input->data; \
oData = buf1ft16; \
} \
else if (iter % 2 == 1) { \
iData = buf1ft16; \
oData = buf2ft16; \
} \
else { \
iData = buf2ft16; \
oData = buf1ft16; \
} \
\
/* unroll the reduction procedure. The code is messy but it is faster. */ \
if (strideNum < 32) { \
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 128) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 256) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 512) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
} \
\
strideNum = cudaGridSize[0]; \
blockSize = cudaGridSize[0]; \
\
iter++; \
\
} while (strideNum > 1); \
} \
\
BacktoCudaDev(input->devID, devIDBackup); \
\
if (mem != NULL) \
mem->ReleaseBuf(mem->devID, bufSize); \
else \
XMemFree(input->devID, buf); \
}
    BacktoCudaDev(input->devID, devIDBackup);

    if (mem != NULL)
        mem->ReleaseBuf(mem->devID, bufSize);
    else
        XMemFree(input->devID, buf);
}
_CUDAREDUCE(_CudaReduceMax, KernelReduceMaxOp, KernelReduceMaxOpLessBlocks, KernelReduceMax, KernelReduceMaxFast)
_CUDAREDUCE(_CudaReduceMin, KernelReduceMinOp, KernelReduceMinOpLessBlocks, KernelReduceMin, KernelReduceMinFast)
#endif // USE_CUDA
......
@@ -31,6 +31,9 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* get the max-valued items along a dimension of the tensor (cuda version) */
void _CudaReduceMax(const XTensor * input, XTensor * output, int dim);
/* get the min-valued items along a dimension of the tensor (cuda version) */
void _CudaReduceMin(const XTensor * input, XTensor * output, int dim);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
......
@@ -29,12 +29,21 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* get the max value of the items along a dimension of the tensor. */
void _ReduceMax(const XTensor * input, XTensor * output, int dim);
/* get the min value of the items along a dimension of the tensor. */
void _ReduceMin(const XTensor * input, XTensor * output, int dim);
/*
get the max value of the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor ReduceMax(const XTensor &input, int dim);
/*
get the min value of the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor ReduceMin(const XTensor &input, int dim);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCEMAX_H__
@@ -168,4 +168,13 @@ VectorBuffer VectorBuffer::maxData(const VectorBuffer &a) {
return *this;
}
/* compute the min of two buffers */
VectorBuffer VectorBuffer::minData(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = MIN(a[i], this->values[i]);
}
return *this;
}
}/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
@@ -48,5 +48,8 @@ public:
/* compute the max of two buffers */
VectorBuffer maxData(const VectorBuffer &a);
/* compute the min of two buffers */
VectorBuffer minData(const VectorBuffer &a);
};
}
\ No newline at end of file