Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
T
Tensor.LowPrecision
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
魏冰浩
Tensor.LowPrecision
Commits
30217de4
Commit
30217de4
authored
Jul 07, 2019
by
linye
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
int8/int/float16 updated
parent
340701d8
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
287 行增加
和
38 行删除
+287
-38
source/network/Main.cpp
+25
-3
source/tensor/core/math/ScaleAndShift.cu
+42
-29
source/tensor/test/TMatrixMul.cpp
+220
-6
没有找到文件。
source/network/Main.cpp
查看文件 @
30217de4
...
...
@@ -47,6 +47,7 @@ void ReduceMaxFP16Test();
void
ReduceSumFP16Test
();
void
LogSoftmaxFP16Test
();
void
ClipFP16Test
();
void
ScaleAndShiftFP16Test
();
using
namespace
nts
;
using
namespace
fnnlm
;
...
...
@@ -82,9 +83,10 @@ int main(int argc, const char ** argv )
//return 0;
//LogSoftmaxFP16Test();
//return 0;
ClipFP16Test
();
return
0
;
//ClipFP16Test();
//return 0;
//ScaleAndShiftFP16Test();
//return 0;
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
Test
();
...
...
@@ -104,6 +106,26 @@ int main(int argc, const char ** argv )
return
0
;
}
void
ScaleAndShiftFP16Test
()
{
XTensor
a
;
XTensor
intA
;
XTensor
b
;
XTensor
intB
;
InitTensor2D
(
&
a
,
1
,
10
,
X_FLOAT
,
0
);
a
.
SetDataRand
(
-
10.0
F
,
10.0
F
);
a
.
Dump
(
stderr
,
"a:"
);
intA
=
ConvertDataType
(
a
,
X_INT
);
intB
=
ScaleAndShift
(
intA
,
2
,
0
);
b
=
ConvertDataType
(
intB
,
X_FLOAT
);
b
.
Dump
(
stderr
,
"b:"
);
}
void
ClipFP16Test
()
{
XTensor
a
;
XTensor
intA
;
...
...
source/tensor/core/math/ScaleAndShift.cu
查看文件 @
30217de4
...
...
@@ -17,6 +17,7 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
* $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-06 float16/int added
*/
#include "ScaleAndShift.cuh"
...
...
@@ -34,9 +35,9 @@ scale and shift all tensor entires b = a * scale + shift (CUDA Kernel)
>> scale - how much we want to scale it
>> shift - how much we want to shift it
*/
template<bool isUnitScale, bool isZeroShift>
template<
class T,
bool isUnitScale, bool isZeroShift>
__global__
void KernelScaleAndShift(
DTYPE * a, DTYPE * b, int size, DTYPE scale, DTYPE
shift)
void KernelScaleAndShift(
T * a, T * b, int size, T scale, T
shift)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
...
@@ -56,28 +57,6 @@ void KernelScaleAndShift(DTYPE * a, DTYPE * b, int size, DTYPE scale, DTYPE shif
}
}
/*
scale and shift all tensor entires p = p * scale + shift (CUDA Kernel)
This is for float16 computation
>> a - the input data array
>> b - the output data array
>> size - the size of d
>> scale - how much we want to scale it
>> shift - how much we want to shift it
*/
__global__
void KernelScaleAndShift(__half * a, __half * b, int size, __half scale, __half shift)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if(i < size)
b[i] = __hadd(__hmul(a[i], scale), shift);
#else
if (i < size)
b[i] = __float2half(__half2float(a[i]) * __half2float(scale) + __half2float(shift));
#endif
}
/*
scale and shift all tensor entires
...
...
@@ -108,20 +87,54 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift
if(a->dataType == DEFAULT_DTYPE){
if(scale == 1.0F && shift == 0)
KernelScaleAndShift<true, true> <<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
true, true> <<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<true, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
true, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else if(scale != 1.0F && shift == 0)
KernelScaleAndShift<false, true> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
false, true> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else
KernelScaleAndShift<false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
}
else if(a->dataType == X_FLOAT16){
unsigned short scale2 = FloatToFloat16(scale);
unsigned short shift2 = FloatToFloat16(shift);
__half * scaleft16p = (__half*)&scale2;
__half * shiftft16p = (__half*)&shift2;
KernelScaleAndShift<<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
if (scale == 1.0F && shift == 0)
KernelScaleAndShift<__half, true, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<__half, true, false><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
else if (scale != 1.0F && shift == 0)
KernelScaleAndShift<__half, false, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
else
KernelScaleAndShift<__half, false, false> << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
}
else if (a->dataType == X_INT) {
int scale2 = int(scale);
int shift2 = int(shift);
if (scale == 1.0F && shift == 0)
KernelScaleAndShift<int, true, true><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<int, true, false><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else if (scale != 1.0F && shift == 0)
KernelScaleAndShift<int, false, true><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else
KernelScaleAndShift<int, false, false><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
}
else if (a->dataType == X_INT8) {
__int8 scale2 = __int8(scale);
__int8 shift2 = __int8(shift);
if (scale == 1.0F && shift == 0)
KernelScaleAndShift<__int8, true, true> << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, a->unitNum, scale2, shift2);
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<__int8, true, false> << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, a->unitNum, scale2, shift2);
else if (scale != 1.0F && shift == 0)
KernelScaleAndShift<__int8, false, true> << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, a->unitNum, scale2, shift2);
else
KernelScaleAndShift<__int8, false, false> << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, a->unitNum, scale2, shift2);
}
else{
ShowNTErrors("TODO!");
...
...
source/tensor/test/TMatrixMul.cpp
查看文件 @
30217de4
...
...
@@ -17,7 +17,7 @@
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-14
* $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-0
6 float16
added
* $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-0
7 float16/int8
added
*/
#include "TMatrixMul.h"
...
...
@@ -683,11 +683,6 @@ bool TestMatrixMul6()
halfSGPU1
=
ConvertDataType
(
*
sGPU1
,
X_FLOAT16
);
halfSGPU2
=
ConvertDataType
(
*
sGPU2
,
X_FLOAT16
);
sGPU1
->
Dump
(
stderr
,
"sGPU1:"
);
sGPU2
->
Dump
(
stderr
,
"sGPU2:"
);
halfSGPU1
.
Dump
(
&
halfSGPU1
,
stderr
,
"halfSGPU1:"
);
halfSGPU2
.
Dump
(
&
halfSGPU2
,
stderr
,
"halfSGPU2:"
);
/* call MatrixMul function */
_MatrixMul
(
&
halfSGPU1
,
X_NOTRANS
,
&
halfSGPU2
,
X_NOTRANS
,
tGPU
);
tUserGPU
=
MatrixMul
(
halfSGPU1
,
X_NOTRANS
,
halfSGPU2
,
X_NOTRANS
,
X_FLOAT
);
...
...
@@ -714,6 +709,207 @@ bool TestMatrixMul6()
#endif // USE_CUDA
}
/*
case 7: int8 matrix multiplication.
In this case, int8 a=(2, 3), int8 b=(3, 2) -> float32 c=(2, 2),
transposedA=X_NOTRANS, transposedB=X_NOTRANS.
*/
bool
TestMatrixMul7
()
{
/* a source tensor of size (2, 3) */
int
sOrder1
=
2
;
int
*
sDimSize1
=
new
int
[
sOrder1
];
sDimSize1
[
0
]
=
2
;
sDimSize1
[
1
]
=
3
;
int
sUnitNum1
=
1
;
for
(
int
i
=
0
;
i
<
sOrder1
;
i
++
)
sUnitNum1
*=
sDimSize1
[
i
];
/* a source tensor of size (3, 2) */
int
sOrder2
=
2
;
int
*
sDimSize2
=
new
int
[
sOrder2
];
sDimSize2
[
0
]
=
3
;
sDimSize2
[
1
]
=
2
;
int
sUnitNum2
=
1
;
for
(
int
i
=
0
;
i
<
sOrder2
;
i
++
)
sUnitNum2
*=
sDimSize2
[
i
];
/* a target tensor of size (2, 2) */
int
tOrder
=
2
;
int
*
tDimSize
=
new
int
[
tOrder
];
tDimSize
[
0
]
=
2
;
tDimSize
[
1
]
=
2
;
int
tUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
tOrder
;
i
++
)
tUnitNum
*=
tDimSize
[
i
];
DTYPE
sData1
[
2
][
3
]
=
{
{
1
,
2
,
3
},
{
-
4
,
5
,
6
}
};
DTYPE
sData2
[
3
][
2
]
=
{
{
0
,
-
1
},
{
1
,
2
},
{
2
,
1
}
};
DTYPE
answer
[
2
][
2
]
=
{
{
8
,
6
},
{
17
,
20
}
};
/* CPU test */
bool
cpuTest
=
true
;
#ifdef USE_CUDA
/* GPU test */
bool
gpuTest
=
true
;
/* create tensor */
XTensor
*
sGPU1
=
NewTensor
(
sOrder1
,
sDimSize1
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
sGPU2
=
NewTensor
(
sOrder2
,
sDimSize2
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
tGPU
=
NewTensor
(
tOrder
,
tDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
tUserGPU
;
/* create int8 tensors */
XTensor
int8SGPU1
;
XTensor
int8SGPU2
;
/* Initialize variables */
sGPU1
->
SetData
(
sData1
,
sUnitNum1
);
sGPU2
->
SetData
(
sData2
,
sUnitNum2
);
tGPU
->
SetZeroAll
();
/* convert data type from float to int8 */
int8SGPU1
=
ConvertDataType
(
*
sGPU1
,
X_INT8
);
int8SGPU2
=
ConvertDataType
(
*
sGPU2
,
X_INT8
);
/* call MatrixMul function */
_MatrixMul
(
&
int8SGPU1
,
X_NOTRANS
,
&
int8SGPU2
,
X_NOTRANS
,
tGPU
);
tUserGPU
=
MatrixMul
(
int8SGPU1
,
X_NOTRANS
,
int8SGPU2
,
X_NOTRANS
,
X_FLOAT
);
/* check results */
gpuTest
=
tGPU
->
CheckData
(
answer
,
tUnitNum
)
&&
tUserGPU
.
CheckData
(
answer
,
tUnitNum
);
/* destroy variables */
delete
sGPU1
;
delete
sGPU2
;
delete
tGPU
;
delete
[]
sDimSize1
;
delete
[]
sDimSize2
;
delete
[]
tDimSize
;
return
cpuTest
&&
gpuTest
;
#else
/* destroy variables */
delete
[]
sDimSize1
;
delete
[]
sDimSize2
;
delete
[]
tDimSize
;
return
cpuTest
;
#endif // USE_CUDA
}
/*
case 8: int8 matrix multiplication.
In this case, int8 a=(2, 3), int8 b=(3, 2) -> int32 c=(2, 2),
transposedA=X_NOTRANS, transposedB=X_NOTRANS.
*/
bool
TestMatrixMul8
()
{
/* a source tensor of size (2, 3) */
int
sOrder1
=
2
;
int
*
sDimSize1
=
new
int
[
sOrder1
];
sDimSize1
[
0
]
=
2
;
sDimSize1
[
1
]
=
3
;
int
sUnitNum1
=
1
;
for
(
int
i
=
0
;
i
<
sOrder1
;
i
++
)
sUnitNum1
*=
sDimSize1
[
i
];
/* a source tensor of size (3, 2) */
int
sOrder2
=
2
;
int
*
sDimSize2
=
new
int
[
sOrder2
];
sDimSize2
[
0
]
=
3
;
sDimSize2
[
1
]
=
2
;
int
sUnitNum2
=
1
;
for
(
int
i
=
0
;
i
<
sOrder2
;
i
++
)
sUnitNum2
*=
sDimSize2
[
i
];
/* a target tensor of size (2, 2) */
int
tOrder
=
2
;
int
*
tDimSize
=
new
int
[
tOrder
];
tDimSize
[
0
]
=
2
;
tDimSize
[
1
]
=
2
;
int
tUnitNum
=
1
;
for
(
int
i
=
0
;
i
<
tOrder
;
i
++
)
tUnitNum
*=
tDimSize
[
i
];
DTYPE
sData1
[
2
][
3
]
=
{
{
1
,
2
,
3
},
{
-
4
,
5
,
6
}
};
DTYPE
sData2
[
3
][
2
]
=
{
{
0
,
-
1
},
{
1
,
2
},
{
2
,
1
}
};
DTYPE
answer
[
2
][
2
]
=
{
{
8
,
6
},
{
17
,
20
}
};
/* CPU test */
bool
cpuTest
=
true
;
#ifdef USE_CUDA
/* GPU test */
bool
gpuTest
=
true
;
/* create tensor */
XTensor
*
sGPU1
=
NewTensor
(
sOrder1
,
sDimSize1
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
sGPU2
=
NewTensor
(
sOrder2
,
sDimSize2
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
tGPU
=
NewTensor
(
tOrder
,
tDimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
intTGPU
=
NewTensor
(
tOrder
,
tDimSize
,
X_INT
,
1.0
F
,
0
);
XTensor
tUserGPU
;
XTensor
intTUserGPU
;
/* create int8 tensors */
XTensor
int8SGPU1
;
XTensor
int8SGPU2
;
/* Initialize variables */
sGPU1
->
SetData
(
sData1
,
sUnitNum1
);
sGPU2
->
SetData
(
sData2
,
sUnitNum2
);
tGPU
->
SetZeroAll
();
/* convert data type from float to int8 */
int8SGPU1
=
ConvertDataType
(
*
sGPU1
,
X_INT8
);
int8SGPU2
=
ConvertDataType
(
*
sGPU2
,
X_INT8
);
/* call MatrixMul function */
_MatrixMul
(
&
int8SGPU1
,
X_NOTRANS
,
&
int8SGPU2
,
X_NOTRANS
,
intTGPU
);
intTUserGPU
=
MatrixMul
(
int8SGPU1
,
X_NOTRANS
,
int8SGPU2
,
X_NOTRANS
,
X_INT
);
/* convert data type from int to float32 */
_ConvertDataType
(
intTGPU
,
tGPU
);
tUserGPU
=
ConvertDataType
(
intTUserGPU
,
X_FLOAT
);
/* check results */
gpuTest
=
tGPU
->
CheckData
(
answer
,
tUnitNum
)
&&
tUserGPU
.
CheckData
(
answer
,
tUnitNum
);
/* destroy variables */
delete
sGPU1
;
delete
sGPU2
;
delete
tGPU
;
delete
intTGPU
;
delete
[]
sDimSize1
;
delete
[]
sDimSize2
;
delete
[]
tDimSize
;
return
cpuTest
&&
gpuTest
;
#else
/* destroy variables */
delete
[]
sDimSize1
;
delete
[]
sDimSize2
;
delete
[]
tDimSize
;
return
cpuTest
;
#endif // USE_CUDA
}
/* other cases */
...
...
@@ -782,6 +978,24 @@ bool TestMatrixMul()
else
XPRINT
(
0
,
stdout
,
">> case 6 passed!
\n
"
);
/* case 7 test */
caseFlag
=
TestMatrixMul7
();
if
(
!
caseFlag
)
{
returnFlag
=
false
;
XPRINT
(
0
,
stdout
,
">> case 7 failed!
\n
"
);
}
else
XPRINT
(
0
,
stdout
,
">> case 7 passed!
\n
"
);
/* case 8 test */
caseFlag
=
TestMatrixMul8
();
if
(
!
caseFlag
)
{
returnFlag
=
false
;
XPRINT
(
0
,
stdout
,
">> case 8 failed!
\n
"
);
}
else
XPRINT
(
0
,
stdout
,
">> case 8 passed!
\n
"
);
/* other cases test */
/*
TODO!!
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论