Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
T
Tensor.LowPrecision
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
linye
Tensor.LowPrecision
Commits
930b837e
Commit
930b837e
authored
Aug 07, 2019
by
linye
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update float16 datatype of Normalize
parent
253f2527
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
204 行增加
和
9 行删除
+204
-9
source/tensor/core/math/Normalize.cu
+25
-7
source/tensor/core/math/Normalize.cuh
+3
-2
source/tensor/test/TNormalize.cpp
+176
-0
没有找到文件。
source/tensor/core/math/Normalize.cu
查看文件 @
930b837e
...
...
@@ -23,6 +23,7 @@
#include "../../XTensor.h"
#include "Normalize.h"
#include "Normalize.cuh"
#include "cuda_fp16.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
...
...
@@ -42,13 +43,14 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> strideNum - how many strides we need to go over for next block
>> blockNum - how many blocks we have
*/
template<class T, TENSOR_DATA_TYPE datatype>
__global__
void KernelNormalize(
DTYPE * input, DTYPE * output, DTYPE * mean, DTYPE
* var,
DTYPE * a, DTYPE * b, DTYPE
epsilon,
void KernelNormalize(
T * input, T * output, T * mean, T
* var,
T * a, T * b, T
epsilon,
int stride, int strideNum, int blockNum)
{
__shared__
DTYPE
iMean[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__
DTYPE
iVar[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__
T
iMean[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__
T
iVar[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ int iBlock[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ int iOffset[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ int blockSize;
...
...
@@ -72,7 +74,15 @@ void KernelNormalize(DTYPE * input, DTYPE * output, DTYPE * mean, DTYPE * var,
int inBlockOffset = j * stride + iOffset[threadIdx.x];
int offset = iBlock[threadIdx.x] * blockSize + inBlockOffset;
output[offset] = a[inBlockOffset] * (input[offset] - iMean[threadIdx.x]) / sqrt(iVar[threadIdx.x] + epsilon) + b[inBlockOffset];
if (datatype == X_FLOAT) {
output[offset] = (DTYPE)(a[inBlockOffset] * (input[offset] - iMean[threadIdx.x])) /
sqrt((DTYPE)(iVar[threadIdx.x] + epsilon)) + (DTYPE)b[inBlockOffset];
}
else if (datatype == X_FLOAT16) {
output[offset] = __hadd(__hdiv(__hmul(a[inBlockOffset], __hsub(input[offset], iMean[threadIdx.x])),
hsqrt(iVar[threadIdx.x] + epsilon)), __float2half(b[inBlockOffset]));
}
}
/*
...
...
@@ -93,7 +103,6 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
const XTensor * a, const XTensor * b,
DTYPE epsilon)
{
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");
int dimRDI = input->order - dim - 1;
int stride = 1;
...
...
@@ -118,10 +127,19 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
KernelNormalize << <blocks, threads >> >((DTYPE*)input->data, (DTYPE*)output->data,
if (input->dataType == DEFAULT_DTYPE) {
KernelNormalize <DTYPE, X_FLOAT><< <blocks, threads >> >((DTYPE*)input->data, (DTYPE*)output->data,
(DTYPE*)mean->data, (DTYPE*)var->data,
(DTYPE*)a->data, (DTYPE*)b->data, epsilon,
stride, strideNum, blockNum);
}
else if (input->dataType == X_FLOAT16) {
__half epsilon1 = __float2half(epsilon);
KernelNormalize <__half, X_FLOAT16><< <blocks, threads >> > ((__half*)input->data, (__half*)output->data,
(__half*)mean->data, (__half*)var->data,
(__half*)a->data, (__half*)b->data, epsilon1,
stride, strideNum, blockNum);
}
BacktoCudaDev(a->devID, devIDBackup);
}
...
...
source/tensor/core/math/Normalize.cuh
查看文件 @
930b837e
...
...
@@ -33,9 +33,10 @@ normalized the data with normal distribution (Kernel code). For an input x,
y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter
*/
template<class T, TENSOR_DATA_TYPE datatype>
__global__
void KernelNormalize(
DTYPE * input, DTYPE * output, DTYPE * mean, DTYPE
* var,
DTYPE * a, DTYPE * b, DTYPE
epsilon,
void KernelNormalize(
T * input, T * output, T * mean, T
* var,
T * a, T * b, T
epsilon,
int stride, int strideNum, int blockNum);
/*
...
...
source/tensor/test/TNormalize.cpp
查看文件 @
930b837e
...
...
@@ -20,6 +20,7 @@
*/
#include "TNormalize.h"
#include "../core/getandset/ConvertDataType.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
...
...
@@ -204,6 +205,171 @@ bool TestNormalize1()
#endif // USE_CUDA
}
/*
case 2: float16 normalized the data with normal distribution
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b.
where a and b are the scalar and bias respectively,
and \epsilon is the adjustment parameter.
*/
bool
TestNormalize2
()
{
/* a source tensor of size (2, 3) */
int
sOrder
=
2
;
int
*
sDimSize
=
new
int
[
sOrder
];
sDimSize
[
0
]
=
2
;
sDimSize
[
1
]
=
3
;
int
sUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
sOrder
;
i
++
)
sUnitNum
*=
sDimSize
[
i
];
/* a target tensor of size (2, 3) */
int
tOrder
=
2
;
int
*
tDimSize
=
new
int
[
tOrder
];
tDimSize
[
0
]
=
2
;
tDimSize
[
1
]
=
3
;
int
tUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
tOrder
;
i
++
)
tUnitNum
*=
tDimSize
[
i
];
/* a mean tensor of size (3) */
int
meanOrder
=
1
;
int
*
meanDimSize
=
new
int
[
meanOrder
];
meanDimSize
[
0
]
=
3
;
int
meanUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
meanOrder
;
i
++
)
meanUnitNum
*=
meanDimSize
[
i
];
/* a variance tensor of size (3) */
int
varOrder
=
1
;
int
*
varDimSize
=
new
int
[
varOrder
];
varDimSize
[
0
]
=
3
;
int
varUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
varOrder
;
i
++
)
varUnitNum
*=
varDimSize
[
i
];
/* a scalar tensor of size (2, 3) */
int
aOrder
=
2
;
int
*
aDimSize
=
new
int
[
aOrder
];
aDimSize
[
0
]
=
2
;
aDimSize
[
1
]
=
3
;
int
aUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
aOrder
;
i
++
)
aUnitNum
*=
aDimSize
[
i
];
/* a bias tensor of size (2, 3) */
int
bOrder
=
2
;
int
*
bDimSize
=
new
int
[
bOrder
];
bDimSize
[
0
]
=
2
;
bDimSize
[
1
]
=
3
;
int
bUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
bOrder
;
i
++
)
bUnitNum
*=
bDimSize
[
i
];
DTYPE
sData
[
2
][
3
]
=
{
{
1.0
F
,
2.0
F
,
3.0
F
},
{
1.5
F
,
2.5
F
,
3.5
F
}
};
DTYPE
meanData
[
3
]
=
{
1.0
F
,
1.5
F
,
2.0
F
};
DTYPE
varData
[
3
]
=
{
1.0
F
,
1.0
F
,
4.0
F
};
DTYPE
aData
[
2
][
3
]
=
{
{
1.0
F
,
1.0
F
,
1.0
F
},
{
1.0
F
,
1.0
F
,
1.0
F
}
};
DTYPE
answer
[
2
][
3
]
=
{
{
0.0
F
,
0.5
F
,
0.5
F
},
{
0.5
F
,
1.0
F
,
0.75
F
}
};
/* CPU test */
bool
cpuTest
=
true
;
#ifdef USE_CUDA
/* GPU test */
bool
gpuTest
=
true
;
/* create tensors */
XTensor
*
sGPU
=
NewTensor
(
sOrder
,
sDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
meanGPU
=
NewTensor
(
meanOrder
,
meanDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
varGPU
=
NewTensor
(
varOrder
,
varDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
aGPU
=
NewTensor
(
aOrder
,
aDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
bGPU
=
NewTensor
(
bOrder
,
bDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
tGPU
=
NewTensor
(
tOrder
,
tDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
tMeGPU
=
NewTensor
(
sOrder
,
sDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
tUserGPU
;
/* create float16 tensors */
XTensor
sHalfGPU
;
XTensor
meanHalfGPU
;
XTensor
varHalfGPU
;
XTensor
aHalfGPU
;
XTensor
bHalfGPU
;
XTensor
tHalfGPU
;
XTensor
tMeHalfGPU
;
XTensor
tUserHalfGPU
;
/* initialize variables */
sGPU
->
SetData
(
sData
,
sUnitNum
);
tMeGPU
->
SetData
(
sData
,
sUnitNum
);
meanGPU
->
SetData
(
meanData
,
meanUnitNum
);
varGPU
->
SetData
(
varData
,
varUnitNum
);
aGPU
->
SetData
(
aData
,
aUnitNum
);
bGPU
->
SetZeroAll
();
tGPU
->
SetZeroAll
();
/* convert data type from float to float16 */
sHalfGPU
=
ConvertDataType
(
*
sGPU
,
X_FLOAT16
);
meanHalfGPU
=
ConvertDataType
(
*
meanGPU
,
X_FLOAT16
);
varHalfGPU
=
ConvertDataType
(
*
varGPU
,
X_FLOAT16
);
aHalfGPU
=
ConvertDataType
(
*
aGPU
,
X_FLOAT16
);
bHalfGPU
=
ConvertDataType
(
*
bGPU
,
X_FLOAT16
);
tHalfGPU
=
ConvertDataType
(
*
tGPU
,
X_FLOAT16
);
tMeHalfGPU
=
ConvertDataType
(
*
tMeGPU
,
X_FLOAT16
);
/* call Normalize function */
_Normalize
(
&
sHalfGPU
,
&
tHalfGPU
,
0
,
&
meanHalfGPU
,
&
varHalfGPU
,
&
aHalfGPU
,
&
bHalfGPU
,
0.0
F
);
_NormalizeMe
(
&
tMeHalfGPU
,
0
,
&
meanHalfGPU
,
&
varHalfGPU
,
&
aHalfGPU
,
&
bHalfGPU
,
0.0
F
);
tUserHalfGPU
=
Normalize
(
sHalfGPU
,
0
,
meanHalfGPU
,
varHalfGPU
,
aHalfGPU
,
bHalfGPU
,
0.0
F
);
/* convert data type from float16 to float */
_ConvertDataType
(
&
tHalfGPU
,
tGPU
);
_ConvertDataType
(
&
tMeHalfGPU
,
tMeGPU
);
tUserGPU
=
ConvertDataType
(
tUserHalfGPU
,
X_FLOAT
);
/* check results */
gpuTest
=
tGPU
->
CheckData
(
answer
,
tUnitNum
,
1e-4
F
)
&&
tMeGPU
->
CheckData
(
answer
,
tUnitNum
,
1e-4
F
)
&&
tUserGPU
.
CheckData
(
answer
,
tUnitNum
,
1e-4
F
);
/* destroy variables */
delete
sGPU
;
delete
tMeGPU
;
delete
tGPU
;
delete
meanGPU
;
delete
varGPU
;
delete
aGPU
;
delete
bGPU
;
delete
[]
sDimSize
;
delete
[]
tDimSize
;
delete
[]
meanDimSize
;
delete
[]
varDimSize
;
delete
[]
aDimSize
;
delete
[]
bDimSize
;
return
cpuTest
&&
gpuTest
;
#else
/* destroy variables */
delete
[]
sDimSize
;
delete
[]
tDimSize
;
delete
[]
meanDimSize
;
delete
[]
varDimSize
;
delete
[]
aDimSize
;
delete
[]
bDimSize
;
return
cpuTest
;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
...
...
@@ -225,6 +391,16 @@ bool TestNormalize()
else
XPRINT
(
0
,
stdout
,
">> case 1 passed!
\n
"
);
/* case 2 test */
caseFlag
=
TestNormalize2
();
if
(
!
caseFlag
)
{
returnFlag
=
false
;
XPRINT
(
0
,
stdout
,
">> case 2 failed!
\n
"
);
}
else
XPRINT
(
0
,
stdout
,
">> case 2 passed!
\n
"
);
/* other cases test */
/*
TODO!!
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论