NiuTrans / NiuTrans.Tensor

Commit 467c2ed7, authored Oct 17, 2019 by 张裕浩

Add reduceMin operation using #define

parent c9ef15f8
Showing 6 changed files with 611 additions and 533 deletions.
source/tensor/core/reduce/ReduceMax.cpp      +171  -147
source/tensor/core/reduce/ReduceMax.cu       +414  -386
source/tensor/core/reduce/ReduceMax.cuh        +3    -0
source/tensor/core/reduce/ReduceMax.h          +9    -0
source/tensor/core/reduce/VectorBuffer.cpp    +10    -0
source/tensor/core/reduce/VectorBuffer.h       +4    -0
source/tensor/core/reduce/ReduceMax.cpp
...
...
@@ -28,6 +28,8 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
get the max value of the items along a dimension of the tensor
...
...
@@ -35,129 +37,147 @@ get the max value of the items along a dimension of the tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
*/
void _ReduceMax(const XTensor * input, XTensor * output, int dim)
{
    CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)),
                  "This code must be run on the same device!");
    CheckNTErrors((input && output), "Empty input or output tensors!");
    CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!");
    CheckNTErrors((input->order > dim && dim >= 0), "Illegal dimension to reduce!");
    CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!");

    CheckNTErrors(dim < input->order, "Wrong dimension!");

    for(int i = 0; i < input->order; i++){
        if(i < dim){
            CheckNTErrors((input->dimSize[i] == output->dimSize[i]), "Unmatched tensors!");
        }
        else if(i > dim){
            CheckNTErrors((input->dimSize[i] == output->dimSize[i - 1]), "Unmatched tensors!");
        }
    }

    if(input->devID >= 0){
#ifdef USE_CUDA
        _CudaReduceMax(input, output, dim);
#endif
    }
    else{
        CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");

        int stride = 1;
        int strideNum = input->dimSize[dim];
        int blockSize = 1;
        int blockNum = 1;

        for(int i = 0; i < input->order; i++){
            if(i > dim)
                stride *= input->dimSize[i];
            else if(i < dim)
                blockNum *= input->dimSize[i];
        }
        blockSize = stride * strideNum;
#define _REDUCE_CPU_FUNCTION(_funcCPUName, _vectorOp, _reduceOp) \
void _funcCPUName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)), \
"This code must be run on the same device!"); \
CheckNTErrors((input && output), "Empty input or output tensors!"); \
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!"); \
CheckNTErrors((input->order > dim && dim >= 0), "Illegal dimension to reduce!"); \
CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!"); \
\
CheckNTErrors(dim < input->order, "Wrong dimension!"); \
\
for (int i = 0; i < input->order; i++) { \
\
if (i < dim) { \
\
CheckNTErrors((input->dimSize[i] == output->dimSize[i]), \
"Unmatched tensors!"); \
} \
else if (i > dim) { \
CheckNTErrors((input->dimSize[i] == output->dimSize[i - 1]), \
"Unmatched tensors!"); \
} \
} \
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!"); \
int stride = 1; \
int strideNum = input->dimSize[dim]; \
int blockSize = 1; \
int blockNum = 1; \
for (int i = 0; i < input->order; i++) { \
if (i > dim) \
stride *= input->dimSize[i]; \
else if (i < dim) \
blockNum *= input->dimSize[i]; \
} \
blockSize = stride * strideNum; \
\
\
if(input->dimSize[input->order - 1] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSize[input->order - 1] >= 32){ \
int vecBufLength = 32 / sizeof(DTYPE); \
\
if (dim == input->order - 1) { \
/*data is contiguous in dim 0 */ \
for (int i = 0; i < blockNum; i++) { \
DTYPE * ip = (DTYPE*)input->data + blockSize * i; \
DTYPE * op = (DTYPE*)output->data + i; \
VectorBuffer vecBuf[4]; \
for (int j = 0; j < 4; j++) { \
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip)+j * vecBufLength); \
} \
for (int j = 1; j < strideNum / 32; j++) { \
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength); \
vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \
vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \
vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \
vecBuf[3] = vecBuf[3]._vectorOp(VectorBuffer::loadu(ptr + 3 * vecBufLength)); \
} \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[1]); \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[2]); \
vecBuf[0] = vecBuf[0]._vectorOp(vecBuf[3]); \
DTYPE maxN = vecBuf[0][0]; \
for (int k = 1; k < vecBufLength; k++) { \
maxN = _reduceOp(maxN, vecBuf[0][k]); \
} \
*op = maxN; \
} \
\
} \
else { \
/* data is separated */ \
for(int i = 0; i < blockNum; i++){ \
for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){ \
DTYPE * ip = (DTYPE*)input->data + blockSize * i; \
DTYPE * op = (DTYPE*)output->data + stride * i; \
VectorBuffer vecBuf[4]; \
for(int k = 0; k < 4; k++){ \
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE)); \
\
} \
for(int k = 1; k < strideNum; k++){ \
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength; \
vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \
vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \
vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \
vecBuf[3] = vecBuf[3]._vectorOp(VectorBuffer::loadu(ptr + 3 * vecBufLength)); \
} \
for(int k = 0; k < 4; k++){ \
for(int l = 0; l < vecBufLength; l++) \
*(op + j * 32 + 8 * k + l) = vecBuf[k][l]; \
} \
} \
} \
} \
} /* run vector buffer */ \
else{ \
for(int k = 0; k < blockNum; k++){ \
DTYPE * ip = (DTYPE*)input->data + blockSize * k; \
DTYPE * op = (DTYPE*)output->data + stride * k; \
for(int i = 0; i < stride; i++){ \
DTYPE * ipe = ip + blockSize; \
DTYPE tmpData = *(ip + i); \
for(DTYPE * ipb = ip + i + stride; ipb < ipe; ipb += stride){ \
DTYPE v = *ipb; \
tmpData = _reduceOp(tmpData, v); \
} \
*(op + i) = tmpData; \
} \
} \
} \
}
        if(input->dimSize[input->order - 1] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSize[input->order - 1] >= 32){
            int vecBufLength = 32 / sizeof(DTYPE);

            if(dim == input->order - 1){
                //data is contiguous in dim 0
                for(int i = 0; i < blockNum; i++){
                    DTYPE * ip = (DTYPE*)input->data + blockSize * i;
                    DTYPE * op = (DTYPE*)output->data + i;
                    VectorBuffer vecBuf[4];
                    for(int j = 0; j < 4; j++){
                        vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength);
                    }
                    for(int j = 1; j < strideNum / 32; j++){
                        const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
                        vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
                        vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
                        vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
                        vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
                    }
                    vecBuf[0] = vecBuf[0].maxData(vecBuf[1]);
                    vecBuf[0] = vecBuf[0].maxData(vecBuf[2]);
                    vecBuf[0] = vecBuf[0].maxData(vecBuf[3]);
                    DTYPE maxN = DTYPE_MIN;
                    for(int k = 0; k < vecBufLength; k++){
                        maxN = MAX(maxN, vecBuf[0][k]);
                    }
                    *op = maxN;
                }
            }
            else{
                //data is separated
                for(int i = 0; i < blockNum; i++){
                    for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){
                        DTYPE * ip = (DTYPE*)input->data + blockSize * i;
                        DTYPE * op = (DTYPE*)output->data + stride * i;
                        VectorBuffer vecBuf[4];
                        for(int k = 0; k < 4; k++){
                            vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE));
                        }
                        for(int k = 1; k < strideNum; k++){
                            DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
                            vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
                            vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
                            vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
                            vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
                        }
                        for(int k = 0; k < 4; k++){
                            for(int l = 0; l < vecBufLength; l++)
                                *(op + j * 32 + 8 * k + l) = vecBuf[k][l];
                        }
                    }
                }
            }
        }
        //run vector buffer
        else{
            for(int k = 0; k < blockNum; k++){
                DTYPE * ip = (DTYPE*)input->data + blockSize * k;
                DTYPE * op = (DTYPE*)output->data + stride * k;
                for(int i = 0; i < stride; i++){
//#if defined(USE_BLAS)
//                  *(op + i) = *(ip + i + (int)(stride * IAMAX(strideNum, ip + i, stride)));
//#else
                    DTYPE max = DTYPE_MIN;
                    DTYPE * ipe = ip + blockSize;
                    for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
                        DTYPE v = *ipb;
                        if(max < v)
                            max = v;
                    }
                    *(op + i) = max;
//#endif
                }
            }
        }
    }
_REDUCE_CPU_FUNCTION(reduceMaxCPU, maxData, MAX)
_REDUCE_CPU_FUNCTION(reduceMinCPU, minData, MIN)
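A reading aid, not part of the commit: each instantiation stamps out a complete CPU reduction from the macro above by plain argument substitution. Hand-expanding the contiguous-dimension SIMD loop of the generated reduceMinCPU, for example, gives:

/* hand-expanded fragment of reduceMinCPU (illustration only):
   _vectorOp -> minData, _reduceOp -> MIN in the macro body above */
for (int j = 1; j < strideNum / 32; j++) {
    const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
    vecBuf[0] = vecBuf[0].minData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
    vecBuf[1] = vecBuf[1].minData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
    vecBuf[2] = vecBuf[2].minData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
    vecBuf[3] = vecBuf[3].minData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}

and the scalar tail folds the buffer with MIN(...) instead of MAX(...).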
#ifdef USE_CUDA
#define _REDUCE_FUNCTION(_funcName, _cudaFuncName) \
void _funcName(const XTensor * input, XTensor * output, int dim) \
{ \
if(input->devID >= 0){ \
_cudaFuncName(input, output, dim); \
} \
else{ \
reduceMaxCPU(input, output, dim); \
} \
}
_REDUCE_FUNCTION(_ReduceMax, _CudaReduceMax)
_REDUCE_FUNCTION(_ReduceMin, _CudaReduceMin)
#else
#define _REDUCE_FUNCTION(_funcName, reduceNameCPU) \
void _funcName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors((input->devID < 0), "This code must be run on the CPU!"); \
reduceNameCPU(input, output, dim); \
}
_REDUCE_FUNCTION(_ReduceMax, reduceMaxCPU)
_REDUCE_FUNCTION(_ReduceMin, reduceMinCPU)
#endif
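For reference, not part of the diff: in the non-CUDA branch the second instantiation expands to a thin wrapper around the generated CPU routine:

/* expansion of _REDUCE_FUNCTION(_ReduceMin, reduceMinCPU) from the #else branch (illustration only) */
void _ReduceMin(const XTensor * input, XTensor * output, int dim)
{
    CheckNTErrors((input->devID < 0), "This code must be run on the CPU!");
    reduceMinCPU(input, output, dim);
}

In the USE_CUDA branch the macro's fallback line is written as reduceMaxCPU(input, output, dim) for both instantiations, so on CPU tensors both _ReduceMax and _ReduceMin dispatch to reduceMaxCPU in that build.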
/*
get the max value of the items along a dimension of the tensor (return an XTensor structure).
make a new tensor to keep the result and return it
...
...
@@ -165,34 +185,38 @@ make a new tensor to keep the result and return it
>> dim - the dimension where the reduction is performed on
<< return - the max value of the items along a dimension of the tensor
*/
XTensor ReduceMax(const XTensor &input, int dim)
{
    CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");

    int order = input.order - 1;
    int * dimSize = new int[order];
    for(int i = 0; i < order; i++){
        if(i < dim)
            dimSize[i] = input.dimSize[i];
        else if(i >= dim)
            dimSize[i] = input.dimSize[i + 1];
    }

    float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
    XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
    output.SetTMPFlag();

    /* call _ReduceMax function */
    _ReduceMax(&input, &output, dim);

    /* tensor connection */
    XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
    XLink::AddParamToHeadInt(&output, dim);

    /* destroy variables */
    delete[] dimSize;

    return output;
#define REDUCE_FUNCTION(funcName, funcOp) \
XTensor funcName(const XTensor & input, int dim) \
{ \
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); \
\
int order = input.order - 1; \
int * dimSize = new int[order]; \
for(int i = 0; i < order; i++){ \
if(i < dim) \
dimSize[i] = input.dimSize[i]; \
else if(i >= dim) \
dimSize[i] = input.dimSize[i + 1]; \
} \
\
float dr = (!input.isSparse) ? 1.0F : input.denseRatio; \
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem); \
output.SetTMPFlag(); \
\
/* call _ReduceMax function */ \
funcOp(&input, &output, dim); \
\
/* tensor connection */ \
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX); \
XLink::AddParamToHeadInt(&output, dim); \
\
/* destroy variables */ \
delete[] dimSize; \
\
return output; \
}
REDUCE_FUNCTION(ReduceMax, _ReduceMax)
REDUCE_FUNCTION(ReduceMin, _ReduceMin)
} // namespace nts(NiuTrans.Tensor)
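A minimal usage sketch, not in the commit, assuming the usual NiuTrans.Tensor helpers NewTensor2D, SetDataRand and DelTensor are available with their customary signatures:

/* hypothetical driver code, for illustration only */
XTensor * a = NewTensor2D(4, 8, X_FLOAT);   /* assumed helper: a 4 x 8 dense tensor */
a->SetDataRand(-1.0F, 1.0F);                /* assumed helper: fill with random values */

XTensor rowMin = ReduceMin(*a, 1);          /* new call: minimum over dim 1 -> order-1 tensor of size 4 */
XTensor rowMax = ReduceMax(*a, 1);          /* existing call, same shape as rowMin */

DelTensor(a);                               /* assumed helper */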
source/tensor/core/reduce/ReduceMax.cu
...
...
@@ -33,67 +33,75 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
use PTX code to reduce float data
*/
__device__ __forceinline__
float shflDownReduceMax(float input)
{
float output;
asm volatile(
"{"
".reg .f32 r0;"
".reg .pred p;"
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
"setp.lt.f32 p, %1, r0; "
"@p mov.f32 %1,r0;"
"mov.f32 %0,%1;"
"}"
: "=f"(output) : "f"(input));
return output;
#define SHLFUNCFLOAT(funcName, reducePTXOp) \
__device__ __forceinline__ \
float funcName(float input) \
{ \
float output; \
asm volatile( \
"{" \
".reg .f32 r0;" \
".reg .pred p;" \
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;" \
"setp."#reducePTXOp".f32 p,%1,r0;" \
"@p mov.f32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;" \
"setp."#reducePTXOp".f32 p, %1, r0; " \
"@p mov.f32 %1,r0;" \
"mov.f32 %0,%1;" \
"}" \
: "=f"(output) : "f"(input)); \
return output; \
}
SHLFUNCFLOAT(shflDownReduceMax, lt)
SHLFUNCFLOAT(shflDownReduceMin, gt)
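A short note, not part of the diff: the stringized macro argument selects the PTX comparison, which is the only difference between the generated max and min shuffles:

/* illustration only: "setp."#reducePTXOp".f32 p,%1,r0;" stringizes the argument, so
   SHLFUNCFLOAT(shflDownReduceMax, lt)  ->  "setp.lt.f32 p,%1,r0;"   (replace %1 when %1 < r0, i.e. keep the max)
   SHLFUNCFLOAT(shflDownReduceMin, gt)  ->  "setp.gt.f32 p,%1,r0;"   (replace %1 when %1 > r0, i.e. keep the min) */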
/*
use PTX code to reduce int data
*/
__device__ __forceinline__
int shflDownReduceMax(int input)
{
int output;
asm volatile(
"{"
".reg .s32 r0;"
".reg .pred p;"
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
"setp.lt.s32 p, %1, r0; "
"@p mov.s32 %1,r0;"
"mov.s32 %0,%1;"
"}"
: "=r"(output) : "r"(input));
return output;
#define SHLFUNCINT(funcName, reducePTXOp) \
__device__ __forceinline__ \
int funcName(int input) \
{ \
int output; \
asm volatile( \
"{" \
".reg .s32 r0;" \
".reg .pred p;" \
"shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;" \
"setp."#reducePTXOp".s32 p,%1,r0;" \
"@p mov.s32 %1,r0;" \
"shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;" \
"setp."#reducePTXOp".s32 p, %1, r0; " \
"@p mov.s32 %1,r0;" \
"mov.s32 %0,%1;" \
"}" \
: "=r"(output) : "r"(input)); \
return output; \
}
SHLFUNCINT(shflDownReduceMax, lt)
SHLFUNCINT(shflDownReduceMin, gt)
/*
reduce a tensor to another that keeps the max value along a dimension - slow version
Given a block of data, we go over each dimension i in the stride and we have
...
...
@@ -108,48 +116,52 @@ crossing of the i-th columne and the j-th row.
>> blockSize - size of the block (i.e., stride * strideNum)
>> blockNum - how many blocks
*/
__global__
void KernelReduceMax(DTYPE * input, DTYPE * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK * MIN_CUDA_SHARED_MEM_COL_SIZE/2];
int idx = threadIdx.x * blockDim.y + threadIdx.y;
unsigned int i = blockIdx.x*blockDim.x + threadIdx.x;
unsigned int j = blockIdx.y*blockDim.y + threadIdx.y;
if(i >= stride * blockNum)
return;
__syncthreads();
int k = i / stride;
int iOffset = i % stride;
DTYPE value = (i < stride * blockNum && j < strideNum) ?
input[blockSize * k + stride * j + iOffset] : FLOAT_MIN;
/* load data into the shared mem */
iData[threadIdx.x * blockDim.y + threadIdx.y] = value;
__syncthreads();
/* do reduction in shared mem */
for (unsigned int s = blockDim.y/2; s > 0; s >>= 1){
if(threadIdx.y < s && iData[idx] < iData[idx + s]){
iData[idx] = iData[idx + s];
}
__syncthreads();
}
/* write result for this block to the output array */
if (threadIdx.y == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = iData[threadIdx.x * blockDim.y];
#define KERNELREDUCEFUN3(funName, opName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output, \
int stride, int strideNum, int reducedStrideNum, \
int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK * MIN_CUDA_SHARED_MEM_COL_SIZE/2]; \
\
int idx = threadIdx.x * blockDim.y + threadIdx.y; \
unsigned int i = blockIdx.x*blockDim.x + threadIdx.x; \
unsigned int j = blockIdx.y*blockDim.y + threadIdx.y; \
\
if(i >= stride * blockNum) \
return; \
\
__syncthreads(); \
\
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE value = (i < stride * blockNum && j < strideNum) ? \
input[blockSize * k + stride * j + iOffset] : initData; \
\
/* load data into the shared mem */ \
iData[threadIdx.x * blockDim.y + threadIdx.y] = value; \
\
__syncthreads(); \
\
/* do reduction in shared mem */ \
for (unsigned int s = blockDim.y/2; s > 0; s >>= 1){ \
if(threadIdx.y < s){ \
iData[idx] = opName(iData[idx + s], iData[idx]); \
} \
\
__syncthreads(); \
} \
\
/* write result for this block to the output array */ \
if (threadIdx.y == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = iData[threadIdx.x * blockDim.y]; \
\
}
KERNELREDUCEFUN3(KernelReduceMax, MAX, FLOAT_MIN)
KERNELREDUCEFUN3(KernelReduceMin, MIN, MAX_FLOAT)
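Another reading aid, not in the diff: the initData argument is the identity element of the reduction, used to pad threads that fall outside the valid range so the padding can never win the comparison:

/* illustration only:
   KERNELREDUCEFUN3(KernelReduceMax, MAX, FLOAT_MIN)  ->  out-of-range lanes load FLOAT_MIN, and MAX(x, FLOAT_MIN) == x
   KERNELREDUCEFUN3(KernelReduceMin, MIN, MAX_FLOAT)  ->  out-of-range lanes load MAX_FLOAT, and MIN(x, MAX_FLOAT) == x */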
/*
reduce a tensor to another that keeps the max value along a dimension - slow version
Given a block of data, we go over each dimension i in the stride and we have
...
...
@@ -231,48 +243,52 @@ reduce a tensor to another that keeps the max value along a dimension - fast ve
>> blockSize - size of the block (i.e., stride * strideNum)
>> blockNum - how many blocks
*/
template <unsigned int goodSize> __global__
void KernelReduceMaxFast(DTYPE * input, DTYPE * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK];
unsigned int tid = threadIdx.y;
unsigned int j = blockIdx.y * (blockDim.y * 2) + threadIdx.y;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if(i >= stride * blockNum)
return;
__syncthreads();
/* first level reduction */
int k = i / stride;
int iOffset = i % stride;
DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize;
DTYPE value = j < strideNum ? inputData[j * stride + iOffset] : FLOAT_MIN;
DTYPE value2 = j + blockDim.y < strideNum ? inputData[(j + blockDim.y) * stride + iOffset]: FLOAT_MIN;
value = MAX(value, value2);
value = shflDownReduceMax(value);
if ((tid & 0x1f) == 0)
data[tid / 32] = value;
__syncthreads();
if (tid < 32) {
if (tid < blockDim.y / 32)
value = data[tid];
else
value = FLOAT_MIN;
value = shflDownReduceMax(value);
if (tid == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value;
}
#define KERNELREDUCEFUN4(funName, opName, opFuncName, initData) \
template <unsigned int goodSize> __global__ \
void funName(DTYPE * input, DTYPE * output, \
int stride, int strideNum, int reducedStrideNum, \
int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK]; \
\
unsigned int tid = threadIdx.y; \
unsigned int j = blockIdx.y * (blockDim.y * 2) + threadIdx.y; \
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
\
if(i >= stride * blockNum) \
return; \
\
__syncthreads(); \
\
/* first level reduction */ \
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE * data = iData + threadIdx.x * blockDim.y; \
DTYPE * inputData = input + k * blockSize; \
DTYPE value = j < strideNum ? inputData[j * stride + iOffset] : initData; \
DTYPE value2 = j + blockDim.y < strideNum ? inputData[(j + blockDim.y) * stride + iOffset]: initData; \
\
value = opName(value, value2); \
value = opFuncName(value); \
if ((tid & 0x1f) == 0) \
data[tid / 32] = value; \
__syncthreads(); \
\
if (tid < 32) { \
if (tid < blockDim.y / 32) \
value = data[tid]; \
else \
value = initData; \
value = opFuncName(value); \
if (tid == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = value; \
} \
}
KERNELREDUCEFUN4(KernelReduceMaxFast, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN4(KernelReduceMinFast, MIN, shflDownReduceMin, MAX_FLOAT)
/*
reduce a tensor to another that keeps the max value along a dimension - fast version
>> input - the input array (representing a tensor)
...
...
@@ -372,14 +388,12 @@ void KernelReduceMaxSimpleFast(DTYPE * input, DTYPE * output,
int stride4 = stride3 + stride;
for(int k = 0; k < blockSize; k += stride4){
DTYPE m = MAX(MAX(ip[k], ip[k + stride]), MAX(ip[k + stride2], ip[k + stride3]));
if(max < m)
max = m;
max = MAX(max, m);
}
}
else{
for(int k = 0; k < blockSize; k += stride)
if(max < ip[k])
max = ip[k];
for (int k = 0; k < blockSize; k += stride)
max = MAX(max, ip[k]);
}
__syncthreads();
...
...
@@ -429,66 +443,75 @@ inline void adjustThreadForUseWarpOptimization(dim3& blocks, dim3& threads)
/*
In some cases, we use fewer blocks to improve efficiency
*/
__global__
void KernelReduceMaxOpLessBlocks(DTYPE * input, DTYPE * output, int strideNum, int blockNum)
{
int idx = threadIdx.x % 32;
int idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int startIndex = idy * strideNum;
DTYPE threadMax = FLOAT_MIN;
for (int i = idx; i < strideNum; i += 32) {
threadMax = max(input[startIndex + i], threadMax);
}
threadMax = shflDownReduceMax(threadMax);
if (idx == 0)
output[idy] = threadMax;
#define KERNELREDUCEFUN2(funName, opName, opFuncName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output, int strideNum, int blockNum) \
{ \
int idx = threadIdx.x % 32; \
int idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32; \
\
int startIndex = idy * strideNum; \
DTYPE threadMax = initData; \
for (int i = idx; i < strideNum; i += 32) { \
threadMax = opName(input[startIndex + i], threadMax); \
} \
threadMax = opFuncName(threadMax); \
if (idx == 0) \
output[idy] = threadMax; \
}
KERNELREDUCEFUN2(KernelReduceMaxOpLessBlocks, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN2(KernelReduceMinOpLessBlocks, MIN, shflDownReduceMin, MAX_FLOAT)
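Orientation note, not part of the diff: in the generated "less blocks" kernels each reduction row is handled by one warp; the lane index walks the row in steps of 32 and the warp shuffle finishes the reduction:

/* thread mapping in KernelReduceMaxOpLessBlocks / KernelReduceMinOpLessBlocks (illustration only):
   idx = threadIdx.x % 32                               -> lane inside the warp
   idy = (blockIdx.x * blockDim.x + threadIdx.x) / 32   -> which of the blockNum rows this warp reduces
   each lane folds elements i = idx, idx + 32, idx + 64, ... with opName,
   then opFuncName (the shfl.down routine above) combines the 32 lane results. */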
/*
we use PTX code to reduce
*/
__global__
void KernelReduceMaxOp(DTYPE * input, DTYPE * output,int stride, int strideNum,
int reducedStrideNum,int blockSize, int blockNum)
{
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32];
unsigned int tid = threadIdx.y;
unsigned int j = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= stride * blockNum)
return;
/* first level reduction */
int k = i / stride;
int iOffset = i % stride;
DTYPE threadMax = FLOAT_MIN;
DTYPE * data = iData + threadIdx.x * blockDim.y;
DTYPE * inputData = input + k * blockSize;
for (int it = j; it < strideNum; it += blockDim.y){
threadMax = max(inputData[it * stride + iOffset], threadMax);
}
__syncthreads();
threadMax = shflDownReduceMax(threadMax);
if ((tid & 0x1f) == 0)
data[tid / 32] = threadMax;
__syncthreads();
/* use one warp to reduce remaining data */
if (tid < 32){
if (tid < blockDim.y / 32)
threadMax = data[tid];
else threadMax = FLOAT_MIN;
threadMax = shflDownReduceMax(threadMax);
if (tid == 0 && blockIdx.y < reducedStrideNum)
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadMax;
}
#define KERNELREDUCEFUN1(funName, opName, opFuncName, initData) \
__global__ \
void funName(DTYPE * input, DTYPE * output,int stride, int strideNum, \
int reducedStrideNum,int blockSize, int blockNum) \
{ \
__shared__ DTYPE iData[MAX_CUDA_THREAD_NUM_PER_BLOCK / 32]; \
\
unsigned int tid = threadIdx.y; \
unsigned int j = blockIdx.y * blockDim.y + threadIdx.y; \
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
if (i >= stride * blockNum) \
return; \
\
/* first level reduction */ \
int k = i / stride; \
int iOffset = i % stride; \
\
DTYPE threadMax = initData; \
\
DTYPE * data = iData + threadIdx.x * blockDim.y; \
DTYPE * inputData = input + k * blockSize; \
for (int it = j; it < strideNum; it += blockDim.y){ \
threadMax = opName(inputData[it * stride + iOffset], threadMax); \
} \
\
__syncthreads(); \
threadMax = opFuncName(threadMax); \
if ((tid & 0x1f) == 0) \
data[tid / 32] = threadMax; \
\
__syncthreads(); \
/* use one warp to reduce remaining data */ \
if (tid < 32){ \
if (tid < blockDim.y / 32) \
threadMax = data[tid]; \
else threadMax = initData; \
threadMax = opFuncName(threadMax); \
if (tid == 0 && blockIdx.y < reducedStrideNum) \
output[(k * reducedStrideNum + blockIdx.y) * stride + iOffset] = threadMax; \
} \
}
KERNELREDUCEFUN1(KernelReduceMaxOp, MAX, shflDownReduceMax, FLOAT_MIN)
KERNELREDUCEFUN1(KernelReduceMinOp, MIN, shflDownReduceMin, MAX_FLOAT)
/*
get the max-valued items along a dimension of the tensor (cuda version).
For a 1-dimensional data array a,
...
...
@@ -497,202 +520,207 @@ sum_i = max_{0<=j<strideNum} input_{i,j}
>> output - the output tensor
>> dim - which dimension to reduce
*/
void _CudaReduceMax(const XTensor * input, XTensor * output, int dim)
{
CheckNTErrors(input && output, "Empty input or output tensors!");
CheckNTErrors(input->order == output->order + 1, "Incorrect tensor sizes!");
CheckNTErrors(input->order > dim && dim >=0, "Illegal dimension to reduce!");
CheckNTErrors(input->dataType == output->dataType, "Unmatched data types!");
for(int i = 0; i < input->order; i++){
if(i < dim){
CheckNTErrors(input->dimSize[i] == output->dimSize[i], "Unmatched tensors!");
}
else if(i > dim){
CheckNTErrors(input->dimSize[i] == output->dimSize[i - 1], "Unmatched tensors!");
}
}
int cudaGridSize[3];
int cudaBlockSize[3];
int iter = 0;
int stride = 1;
int strideNum = input->dimSize[dim];
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
if (i < dim)
blockNum *= input->dimSize[i];
else if (i > dim)
stride *= input->dimSize[i];
}
blockSize = stride * strideNum;
int devID = input->devID;
XMem * mem = input->mem;
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
int bufSize = sizeof(DTYPE) * cudaGridSize[0] * stride * blockNum * 2;
DTYPE * buf = mem != NULL ? (DTYPE*)mem->AllocBuf(mem->devID, bufSize) : (DTYPE*)XMemAlloc(input->devID, bufSize);
DTYPE * buf1 = buf;
DTYPE * buf2 = buf + cudaGridSize[0] * stride * blockNum;
int devIDBackup;
ProtectCudaDev(input->devID, devIDBackup);
if (stride == 1 && blockNum >= 10) {
dim3 grids;
dim3 blocks;
continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum);
if (blocks.y >= 128) {
KernelReduceMaxOp <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum);
}
else {
if (blockNum % 4 != 0) blockNum = (int)(blockNum / 4) + 1;
else blockNum = blockNum / 4;
KernelReduceMaxOpLessBlocks <<<blockNum, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum);
}
}
else {
do {
if (input->dataType == DEFAULT_DTYPE) {
DTYPE * iData = NULL;
DTYPE * oData = NULL;
if (iter == 0) {
iData = (DTYPE*)input->data;
oData = buf1;
}
else if (iter % 2 == 1) {
iData = buf1;
oData = buf2;
}
else {
iData = buf2;
oData = buf1;
}
/* unroll the reduction procedure. The code is messy but it is faster. */
if (strideNum < 32) {
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (DTYPE*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!");
adjustThreadForUseWarpOptimization(blocks, threads);
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
}
else if (input->dataType == X_FLOAT16) {
__half * buf1ft16 = (__half *)buf1;
__half * buf2ft16 = (__half *)buf2;
__half * iData = NULL;
__half * oData = NULL;
if (iter == 0) {
iData = (__half*)input->data;
oData = buf1ft16;
}
else if (iter % 2 == 1) {
iData = buf1ft16;
oData = buf2ft16;
}
else {
iData = buf2ft16;
oData = buf1ft16;
}
/* unroll the reduction procedure. The code is messy but it is faster. */
if (strideNum < 32) {
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]);
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!");
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum);
}
}
strideNum = cudaGridSize[0];
blockSize = cudaGridSize[0];
iter++;
} while (strideNum > 1);
}
#define _CUDAREDUCE(_funcName, _reduceFunc1, _reduceFunc2, _reduceFunc3, _reduceFun4) \
void _funcName(const XTensor * input, XTensor * output, int dim) \
{ \
CheckNTErrors(input && output, "Empty input or output tensors!"); \
CheckNTErrors(input->order == output->order + 1, "Incorrect tensor sizes!"); \
CheckNTErrors(input->order > dim && dim >=0, "Illegal dimension to reduce!"); \
CheckNTErrors(input->dataType == output->dataType, "Unmatched data types!"); \
\
for(int i = 0; i < input->order; i++){ \
if(i < dim){ \
CheckNTErrors(input->dimSize[i] == output->dimSize[i], "Unmatched tensors!"); \
} \
else if(i > dim){ \
CheckNTErrors(input->dimSize[i] == output->dimSize[i - 1], "Unmatched tensors!"); \
} \
} \
\
int cudaGridSize[3]; \
int cudaBlockSize[3]; \
int iter = 0; \
int stride = 1; \
int strideNum = input->dimSize[dim]; \
int blockSize = 1; \
int blockNum = 1; \
\
for (int i = 0; i < input->order; i++) { \
if (i < dim) \
blockNum *= input->dimSize[i]; \
else if (i > dim) \
stride *= input->dimSize[i]; \
} \
blockSize = stride * strideNum; \
\
int devID = input->devID; \
XMem * mem = input->mem; \
\
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
\
int bufSize = sizeof(DTYPE) * cudaGridSize[0] * stride * blockNum * 2; \
DTYPE * buf = mem != NULL ? (DTYPE*)mem->AllocBuf(mem->devID, bufSize) : (DTYPE*)XMemAlloc(input->devID, bufSize); \
DTYPE * buf1 = buf; \
DTYPE * buf2 = buf + cudaGridSize[0] * stride * blockNum; \
\
int devIDBackup; \
ProtectCudaDev(input->devID, devIDBackup); \
\
if (stride == 1 && blockNum >= 10) { \
dim3 grids; \
dim3 blocks; \
continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum); \
if (blocks.y >= 128) { \
_reduceFunc1 <<<grids, blocks >>> ((DTYPE *)input->data, (DTYPE*)output->data, stride, strideNum, grids.y, blockSize, blockNum); \
} \
else { \
if (blockNum % 4 != 0) blockNum = (int)(blockNum / 4) + 1; \
else blockNum = blockNum / 4; \
_reduceFunc2 <<<blockNum, 128 >>> ((DTYPE *)input->data, (DTYPE*)output->data, strideNum, blockNum); \
} \
} \
else { \
do { \
if (input->dataType == DEFAULT_DTYPE) { \
DTYPE * iData = NULL; \
DTYPE * oData = NULL; \
if (iter == 0) { \
iData = (DTYPE*)input->data; \
oData = buf1; \
} \
else if (iter % 2 == 1) { \
iData = buf1; \
oData = buf2; \
} \
else { \
iData = buf2; \
oData = buf1; \
} \
\
/* unroll the reduction procedure. The code is messy but it is faster. */ \
if (strideNum < 32) { \
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
_reduceFunc3 <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 128) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 256) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 512) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (DTYPE*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!"); \
adjustThreadForUseWarpOptimization(blocks, threads); \
_reduceFun4<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
} \
else if (input->dataType == X_FLOAT16) { \
__half * buf1ft16 = (__half *)buf1; \
__half * buf2ft16 = (__half *)buf2; \
__half * iData = NULL; \
__half * oData = NULL; \
if (iter == 0) { \
iData = (__half*)input->data; \
oData = buf1ft16; \
} \
else if (iter % 2 == 1) { \
iData = buf1ft16; \
oData = buf2ft16; \
} \
else { \
iData = buf2ft16; \
oData = buf1ft16; \
} \
\
/* unroll the reduction procedure. The code is messy but it is faster. */ \
if (strideNum < 32) { \
GDevs.GetCudaThread2D(devID, strideNum, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
KernelReduceMax <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 128) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 64, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 256) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 128, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else if (strideNum < 512) { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 256, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
else { \
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); \
dim3 blocks(cudaGridSize[1], cudaGridSize[0]), threads(cudaBlockSize[1], cudaBlockSize[0]); \
if (cudaGridSize[0] == 1) \
oData = (__half*)output->data; \
CheckNTErrors(cudaBlockSize[0] >= 512, "Incorrect thread number when calling the cuda kernel!"); \
KernelReduceMaxFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y, blockSize, blockNum); \
} \
} \
\
strideNum = cudaGridSize[0]; \
blockSize = cudaGridSize[0]; \
\
iter++; \
\
} while (strideNum > 1); \
} \
\
BacktoCudaDev(input->devID, devIDBackup); \
\
if (mem != NULL) \
mem->ReleaseBuf(mem->devID, bufSize); \
else \
XMemFree(input->devID, buf); \
}
BacktoCudaDev(input->devID, devIDBackup);
_CUDAREDUCE(_CudaReduceMax, KernelReduceMaxOp, KernelReduceMaxOpLessBlocks, KernelReduceMax, KernelReduceMaxFast)
_CUDAREDUCE(_CudaReduceMin, KernelReduceMinOp, KernelReduceMinOpLessBlocks, KernelReduceMin, KernelReduceMinFast)
if (mem != NULL)
mem->ReleaseBuf(mem->devID, bufSize);
else
XMemFree(input->devID, buf);
}
#endif // USE_CUDA
...
...
source/tensor/core/reduce/ReduceMax.cuh
...
...
@@ -31,6 +31,9 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* get the max-valued items along a dimension of the tensor (cuda version) */
void _CudaReduceMax(const XTensor * input, XTensor * output, int dim);
/* get the min-valued items along a dimension of the tensor (cuda version) */
void _CudaReduceMin(const XTensor * input, XTensor * output, int dim);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/reduce/ReduceMax.h
...
...
@@ -29,12 +29,21 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* get the max value of the items along a dimension of the tensor. */
void _ReduceMax(const XTensor * input, XTensor * output, int dim);
/* get the min value of the items along a dimension of the tensor. */
void _ReduceMin(const XTensor * input, XTensor * output, int dim);
/*
get the max value of the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor ReduceMax(const XTensor &input, int dim);
/*
get the min value of the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor ReduceMin(const XTensor &input, int dim);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCEMAX_H__
source/tensor/core/reduce/VectorBuffer.cpp
...
...
@@ -168,4 +168,13 @@ VectorBuffer VectorBuffer::maxData(const VectorBuffer &a) {
    return *this;
}
/* compute the min of two buffers */
VectorBuffer VectorBuffer::minData(const VectorBuffer &a) {
    for (int i = 0; i != a.size(); i++) {
        this->values[i] = MIN(a[i], this->values[i]);
        printf("runhere");
    }
    return *this;
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
source/tensor/core/reduce/VectorBuffer.h
...
...
@@ -48,5 +48,8 @@ public:
/* compute the max of two buffers */
VectorBuffer maxData(const VectorBuffer &a);
/* compute the min of two buffers */
VectorBuffer minData(const VectorBuffer &a);
};
}
\ No newline at end of file