Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
a89ee126
Commit
a89ee126
authored
Sep 06, 2020
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
optimize the float16 CPU implementation
parent
89ad96e6
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
380 行增加
和
151 行删除
+380
-151
source/Main.cpp
+54
-0
source/tensor/XName.cpp
+8
-0
source/tensor/XName.h
+4
-1
source/tensor/XTensor.cpp
+15
-0
source/tensor/core/getandset/SetData.cu
+3
-0
source/tensor/core/math/Compare.cpp
+28
-0
source/tensor/core/math/Compare.cu
+3
-0
source/tensor/core/math/Compare.cuh
+6
-0
source/tensor/core/math/Compare.h
+34
-1
source/tensor/core/utilities/Float16.cpp
+153
-96
source/tensor/core/utilities/Float16.h
+72
-53
没有找到文件。
source/Main.cpp
查看文件 @
a89ee126
...
@@ -36,11 +36,65 @@ using namespace nts;
...
@@ -36,11 +36,65 @@ using namespace nts;
using
namespace
fnnlm
;
using
namespace
fnnlm
;
using
namespace
transformer
;
using
namespace
transformer
;
/*
manual smoke test of the CPU float16 type
It prints the float value of a default-constructed float16 and of the values
3.5, 0.0 and -3.5 after assignment, dumps the raw fields once, prints
sizeof(float16), and finally writes the last value to the file "test_fp16"
and reads it back to check that the raw bytes survive a file round trip.
<< return - 0 on success, 1 if "test_fp16" cannot be opened, written or read
*/
int MyTest()
{
    float16 x;
    printf("%f\n", x.Float());

    x = 3.5;
    printf("%f\n", x.Float());

    x = 0.0F;
    printf("%f\n", x.Float());
    x.Dump();

    x = -3.5;
    printf("%f\n", x.Float());

    /* sizeof yields a size_t, which must not be passed to "%d" as is */
    printf("%d\n", (int)sizeof(float16));

    /* the value is stored as raw bytes, so the stream must be binary:
       a text-mode stream may translate newline bytes on some platforms */
    FILE * f = fopen("test_fp16", "wb");
    if (f == NULL) {
        fprintf(stderr, "MyTest: cannot open test_fp16 for writing\n");
        return 1;
    }
    size_t written = fwrite(&x, sizeof(float16), 1, f);
    fclose(f);
    if (written != 1) {
        fprintf(stderr, "MyTest: cannot write test_fp16\n");
        return 1;
    }

    FILE * f2 = fopen("test_fp16", "rb");
    if (f2 == NULL) {
        fprintf(stderr, "MyTest: cannot open test_fp16 for reading\n");
        return 1;
    }
    size_t loaded = fread(&x, sizeof(float16), 1, f2);
    fclose(f2);
    if (loaded != 1) {
        fprintf(stderr, "MyTest: cannot read test_fp16\n");
        return 1;
    }

    /* value after the round trip through the file */
    printf("%f\n", x.Float());

    return 0;
}
int
MyTest2
()
{
GDevs
.
Init
();
GDevs
.
Clear
();
XTensor
a
;
InitTensor2D
(
&
a
,
2
,
3
,
X_FLOAT
,
0
);
a
.
SetZeroAll
();
ScaleAndShift
(
a
,
1
);
a
.
Dump
();
printf
(
"dump
\n
"
);
getchar
();
return
0
;
}
int
main
(
int
argc
,
const
char
**
argv
)
int
main
(
int
argc
,
const
char
**
argv
)
{
{
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetBreakAlloc(2708);
//_CrtSetBreakAlloc(2708);
//MyTest2();
//printf("release\n");
//getchar();
//GDevs.GPUs[0].Reset();
//printf("reset\n");
//getchar();
//printf("bye.\n");
MyTest
();
exit
(
1
);
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
Test
();
Test
();
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-fnnlm"
))
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-fnnlm"
))
...
...
source/tensor/XName.cpp
查看文件 @
a89ee126
...
@@ -55,6 +55,10 @@ const char * GetOPName(int type)
...
@@ -55,6 +55,10 @@ const char * GetOPName(int type)
return
"M_ROUND"
;
return
"M_ROUND"
;
else
if
(
type
==
MATH_RECIPROCAL
)
else
if
(
type
==
MATH_RECIPROCAL
)
return
"M_RECIPROCAL"
;
return
"M_RECIPROCAL"
;
else
if
(
type
==
MATH_EQUAL
)
return
"M_EQUAL"
;
else
if
(
type
==
MATH_NOTEQUAL
)
return
"M_NOTEQUAL"
;
else
if
(
type
==
MATH_CLIP
)
else
if
(
type
==
MATH_CLIP
)
return
"M_CLIP"
;
return
"M_CLIP"
;
else
if
(
type
==
MATH_DIV
)
else
if
(
type
==
MATH_DIV
)
...
@@ -67,6 +71,10 @@ const char * GetOPName(int type)
...
@@ -67,6 +71,10 @@ const char * GetOPName(int type)
return
"M_MATRIXMUL"
;
return
"M_MATRIXMUL"
;
else
if
(
type
==
MATH_MATRIXMULBATCHED
)
else
if
(
type
==
MATH_MATRIXMULBATCHED
)
return
"M_MATRIXMULBATCHED"
;
return
"M_MATRIXMULBATCHED"
;
else
if
(
type
==
MATH_MAX
)
return
"M_MAX"
;
else
if
(
type
==
MATH_MIN
)
return
"M_MIN"
;
else
if
(
type
==
MATH_MULTIPLY
)
else
if
(
type
==
MATH_MULTIPLY
)
return
"M_MULTIPLY"
;
return
"M_MULTIPLY"
;
else
if
(
type
==
MATH_MULTIPLYDIM
)
else
if
(
type
==
MATH_MULTIPLYDIM
)
...
...
source/tensor/XName.h
查看文件 @
a89ee126
...
@@ -46,7 +46,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -46,7 +46,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ROUND MATH_TAN + 1
#define MATH_ROUND MATH_TAN + 1
#define MATH_RECIPROCAL MATH_ROUND + 1
#define MATH_RECIPROCAL MATH_ROUND + 1
#define MATH_CLIP MATH_RECIPROCAL + 1
#define MATH_EQUAL MATH_RECIPROCAL + 1
#define MATH_NOTEQUAL MATH_EQUAL + 1
#define MATH_CLIP MATH_NOTEQUAL + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_DIVDIM MATH_DIV + 1
#define MATH_DIVDIM MATH_DIV + 1
#define MATH_MASK MATH_DIVDIM + 1
#define MATH_MASK MATH_DIVDIM + 1
...
...
source/tensor/XTensor.cpp
查看文件 @
a89ee126
...
@@ -1784,9 +1784,15 @@ void XTensor::BinaryDump(FILE* file)
...
@@ -1784,9 +1784,15 @@ void XTensor::BinaryDump(FILE* file)
switch
(
dataType
)
{
switch
(
dataType
)
{
case
X_INT
:
{
case
X_INT
:
{
fwrite
(
tmp
.
data
,
sizeof
(
int
),
unitNum
,
file
);
fwrite
(
tmp
.
data
,
sizeof
(
int
),
unitNum
,
file
);
break
;
}
case
X_FLOAT16
:
{
fwrite
(
tmp
.
data
,
sizeof
(
float16
),
unitNum
,
file
);
break
;
}
}
default
:
{
default
:
{
fwrite
(
tmp
.
data
,
sizeof
(
float
),
unitNum
,
file
);
fwrite
(
tmp
.
data
,
sizeof
(
float
),
unitNum
,
file
);
break
;
}
}
}
}
}
}
...
@@ -1917,12 +1923,21 @@ void XTensor::BinaryRead(FILE* file, size_t offset)
...
@@ -1917,12 +1923,21 @@ void XTensor::BinaryRead(FILE* file, size_t offset)
fread
(
d
,
sizeof
(
int
),
unitNum
,
file
);
fread
(
d
,
sizeof
(
int
),
unitNum
,
file
);
SetData
(
d
,
unitNum
);
SetData
(
d
,
unitNum
);
delete
[]
d
;
delete
[]
d
;
break
;
}
case
X_FLOAT16
:
{
int
*
d
=
new
int
[
unitNum
];
fread
(
d
,
sizeof
(
float16
),
unitNum
,
file
);
SetData
(
d
,
unitNum
);
delete
[]
d
;
break
;
}
}
default
:
{
default
:
{
float
*
d
=
new
float
[
unitNum
];
float
*
d
=
new
float
[
unitNum
];
fread
(
d
,
sizeof
(
float
),
unitNum
,
file
);
fread
(
d
,
sizeof
(
float
),
unitNum
,
file
);
SetData
(
d
,
unitNum
);
SetData
(
d
,
unitNum
);
delete
[]
d
;
delete
[]
d
;
break
;
}
}
}
}
}
}
...
...
source/tensor/core/getandset/SetData.cu
查看文件 @
a89ee126
...
@@ -51,6 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size)
...
@@ -51,6 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size)
template __global__ void KernelSetDataFixed<int>(int *, int, int);
template __global__ void KernelSetDataFixed<int>(int *, int, int);
template __global__ void KernelSetDataFixed<float>(float *, float, int);
template __global__ void KernelSetDataFixed<float>(float *, float, int);
template __global__ void KernelSetDataFixed<double>(double *, double, int);
template __global__ void KernelSetDataFixed<double>(double *, double, int);
template __global__ void KernelSetDataFixed<__half>(__half*, __half, int);
/*
/*
generate data items with a fixed value
generate data items with a fixed value
...
@@ -79,6 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value)
...
@@ -79,6 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value)
KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum);
KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum);
else if (tensor->dataType == X_DOUBLE)
else if (tensor->dataType == X_DOUBLE)
KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum);
KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum);
else if (tensor->dataType == X_FLOAT16)
KernelSetDataFixed << <blocks, threads >> > ((__half*)tensor->data, (__half)value, tensor->unitNum);
else
else
ShowNTErrors("TODO! Unsupported datatype!")
ShowNTErrors("TODO! Unsupported datatype!")
...
...
source/tensor/core/math/Compare.cpp
查看文件 @
a89ee126
...
@@ -92,6 +92,10 @@ XTensor funcName(const XTensor &a, DTYPE number)
...
@@ -92,6 +92,10 @@ XTensor funcName(const XTensor &a, DTYPE number)
XTensor b(&a); \
XTensor b(&a); \
b.SetTMPFlag(); \
b.SetTMPFlag(); \
_funcName(&a, &b, number); \
_funcName(&a, &b, number); \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, (DTYPE)number); \
} \
return b; \
return b; \
}
}
...
@@ -102,6 +106,10 @@ void funcName(const XTensor &a, XTensor &b, DTYPE number)
...
@@ -102,6 +106,10 @@ void funcName(const XTensor &a, XTensor &b, DTYPE number)
InitTensorV2(&b, &a); \
InitTensorV2(&b, &a); \
} \
} \
_funcName(&a, &b, number); \
_funcName(&a, &b, number); \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, (DTYPE)number); \
} \
}
}
// I think we needn't to make link.
// I think we needn't to make link.
...
@@ -186,6 +194,9 @@ XTensor funcName(const XTensor & a, const XTensor & b)
...
@@ -186,6 +194,9 @@ XTensor funcName(const XTensor & a, const XTensor & b)
XTensor c(&a); \
XTensor c(&a); \
c.SetTMPFlag(); \
c.SetTMPFlag(); \
_funcName(&a, &b, &c); \
_funcName(&a, &b, &c); \
if (a.enableGrad && b.enableGrad) { \
XLink::MakeLink(&a, &b, &c, operationId); \
} \
return c; \
return c; \
}
}
...
@@ -196,16 +207,33 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c)
...
@@ -196,16 +207,33 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c)
InitTensor(&c, &a); \
InitTensor(&c, &a); \
} \
} \
_funcName(&a, &b, &c); \
_funcName(&a, &b, &c); \
if (a.enableGrad && b.enableGrad) { \
XLink::MakeLink(&a, &b, &c, operationId); \
} \
}
}
#ifdef USE_CUDA
#ifdef USE_CUDA
_SIMPLE_MAX_MIN_FUNCTION
(
_Equal
,
_CudaEqual
,
myIsEqual
)
_SIMPLE_MAX_MIN_FUNCTION
(
_NotEqual
,
_CudaNotEqual
,
myIsNotEqual
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Max
,
_CudaMax
,
MAX
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Max
,
_CudaMax
,
MAX
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Min
,
_CudaMin
,
MIN
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Min
,
_CudaMin
,
MIN
)
#else
#else
_SIMPLE_MAX_MIN_FUNCTION
(
_Equal
,
myIsEqual
)
_SIMPLE_MAX_MIN_FUNCTION
(
_NotEqual
,
myIsNotEqual
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Max
,
MAX
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Max
,
MAX
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Min
,
MIN
)
_SIMPLE_MAX_MIN_FUNCTION
(
_Min
,
MIN
)
#endif
#endif
_SIMPLE_MAX_MIN_FUNCTION_ME
(
_EqualMe
,
_Equal
)
SIMPLE_MAX_MIN_FUNCTION_ME
(
EqualMe
,
_Equal
)
SIMPLE_MAX_MIN_FUNCTION
(
Equal
,
_Equal
,
MATH_EQUAL
)
SIMPLE_MAX_MIN_FUNCTION_VOID
(
Equal
,
_Equal
,
MATH_EQUAL
)
_SIMPLE_MAX_MIN_FUNCTION_ME
(
_NotEqualMe
,
_NotEqual
)
SIMPLE_MAX_MIN_FUNCTION_ME
(
NotEqualMe
,
_NotEqual
)
SIMPLE_MAX_MIN_FUNCTION
(
NotEqual
,
_NotEqual
,
MATH_NOTEQUAL
)
SIMPLE_MAX_MIN_FUNCTION_VOID
(
NotEqual
,
_NotEqual
,
MATH_NOTEQUAL
)
_SIMPLE_MAX_MIN_FUNCTION_ME
(
_MaxMe
,
_Max
)
_SIMPLE_MAX_MIN_FUNCTION_ME
(
_MaxMe
,
_Max
)
SIMPLE_MAX_MIN_FUNCTION_ME
(
MaxMe
,
_Max
)
SIMPLE_MAX_MIN_FUNCTION_ME
(
MaxMe
,
_Max
)
SIMPLE_MAX_MIN_FUNCTION
(
Max
,
_Max
,
MATH_MAX
)
SIMPLE_MAX_MIN_FUNCTION
(
Max
,
_Max
,
MATH_MAX
)
...
...
source/tensor/core/math/Compare.cu
查看文件 @
a89ee126
...
@@ -134,6 +134,9 @@ void _Cuda##funcName(const XTensor * a, const XTensor * b, XTensor * c) \
...
@@ -134,6 +134,9 @@ void _Cuda##funcName(const XTensor * a, const XTensor * b, XTensor * c) \
BacktoCudaDev(a->devID, devIDBackup); \
BacktoCudaDev(a->devID, devIDBackup); \
}
}
SIMPLE_MAX_MIN_FUNCTION_GPU(Equal, cudaIsEqual)
SIMPLE_MAX_MIN_FUNCTION_GPU(NotEqual, cudaIsNotEqual)
SIMPLE_MAX_MIN_FUNCTION_GPU(Max, max)
SIMPLE_MAX_MIN_FUNCTION_GPU(Max, max)
SIMPLE_MAX_MIN_FUNCTION_GPU(Min, min)
SIMPLE_MAX_MIN_FUNCTION_GPU(Min, min)
...
...
source/tensor/core/math/Compare.cuh
查看文件 @
a89ee126
...
@@ -31,9 +31,15 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
...
@@ -31,9 +31,15 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* check whether every entry is equal to the given value (cuda version) */
/* check whether every entry is equal to the given value (cuda version) */
void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value);
void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value);
/* check whether every entry of a is equal to the corresponding entry of b (cuda version) */
void _CudaEqual(const XTensor * a, const XTensor * b, XTensor * c);
/* check whether every entry is not equal to the given value (cuda version) */
/* check whether every entry is not equal to the given value (cuda version) */
void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value);
void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value);
/* check whether every entry of a is not equal to the corresponding entry of b (cuda version) */
void _CudaNotEqual(const XTensor * a, const XTensor * b, XTensor * c);
/* return maximum of two tensor for each items (cuda version) */
/* return maximum of two tensor for each items (cuda version) */
void _CudaMax(const XTensor * a, const XTensor * b, XTensor *c);
void _CudaMax(const XTensor * a, const XTensor * b, XTensor *c);
...
...
source/tensor/core/math/Compare.h
查看文件 @
a89ee126
...
@@ -39,7 +39,23 @@ void EqualMe(XTensor & a, DTYPE value);
...
@@ -39,7 +39,23 @@ void EqualMe(XTensor & a, DTYPE value);
XTensor
Equal
(
const
XTensor
&
a
,
DTYPE
value
);
XTensor
Equal
(
const
XTensor
&
a
,
DTYPE
value
);
/* check whether every entry is equal to the given value */
/* check whether every entry is equal to the given value */
void
Equal
(
const
XTensor
&
a
,
XTensor
&
b
,
DTYPE
value
);
void
Equal
(
const
XTensor
&
a
,
XTensor
&
b
,
XTensor
&
c
);
/* check whether every entry of a is equal to the corresponding entry of b */
void
_Equal
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
/* check whether every entry of a is equal to the corresponding entry of b (do it on site) */
void
_EqualMe
(
XTensor
*
a
,
XTensor
*
b
);
/* check whether every entry of a is equal to the corresponding entry of b (do it on site) */
void
EqualMe
(
XTensor
&
a
,
XTensor
&
b
);
/* check whether every entry of a is equal to the corresponding entry of b (return an XTensor structure) */
XTensor
Equal
(
const
XTensor
&
a
,
const
XTensor
&
b
);
/* check whether every entry of a is equal to the corresponding entry of b */
void
Equal
(
const
XTensor
&
a
,
const
XTensor
&
b
,
XTensor
&
c
);
/* check whether every entry is not equal to the given value */
/* check whether every entry is not equal to the given value */
void
_NotEqual
(
const
XTensor
*
a
,
XTensor
*
b
,
DTYPE
value
);
void
_NotEqual
(
const
XTensor
*
a
,
XTensor
*
b
,
DTYPE
value
);
...
@@ -56,6 +72,22 @@ XTensor NotEqual(const XTensor & a, DTYPE value);
...
@@ -56,6 +72,22 @@ XTensor NotEqual(const XTensor & a, DTYPE value);
/* check whether every entry is not equal to the given value */
/* check whether every entry is not equal to the given value */
void
NotEqual
(
const
XTensor
&
a
,
XTensor
&
b
,
DTYPE
value
);
void
NotEqual
(
const
XTensor
&
a
,
XTensor
&
b
,
DTYPE
value
);
/* check whether every entry of a is not equal to the corresponding entry of b */
void
_NotEqual
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
/* check whether every entry of a is not equal to the corresponding entry of b (do it on site) */
void
_NotEqualMe
(
XTensor
*
a
,
XTensor
*
b
);
/* check whether every entry of a is not equal to the corresponding entry of b (do it on site) */
void
NotEqualMe
(
XTensor
&
a
,
XTensor
*
b
);
/* check whether every entry of a is not equal to the corresponding entry of b (return an XTensor structure) */
XTensor
NotEqual
(
const
XTensor
&
a
,
const
XTensor
&
b
);
/* check whether every entry of a is not equal to the corresponding entry of b */
void
NotEqual
(
const
XTensor
&
a
,
const
XTensor
&
b
,
XTensor
&
c
);
/* return maximum of two tensor for each items */
/* return maximum of two tensor for each items */
void
_Max
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
void
_Max
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
...
@@ -71,6 +103,7 @@ XTensor Max(const XTensor & a, const XTensor & b);
...
@@ -71,6 +103,7 @@ XTensor Max(const XTensor & a, const XTensor & b);
/* return maximum of two tensor for each items */
/* return maximum of two tensor for each items */
void
Max
(
const
XTensor
&
a
,
const
XTensor
&
b
,
XTensor
&
c
);
void
Max
(
const
XTensor
&
a
,
const
XTensor
&
b
,
XTensor
&
c
);
/* return minimum of two tensor for each items */
/* return minimum of two tensor for each items */
void
_Min
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
void
_Min
(
const
XTensor
*
a
,
const
XTensor
*
b
,
XTensor
*
c
);
...
...
source/tensor/core/utilities/Float16.cpp
查看文件 @
a89ee126
//
/* NiuTrans.Tensor - an open-source tensor library
// float16.cpp
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
// 16bit
* All rights reserved.
//
*
// Created by 管胡昊 on 2020/2/5.
* Licensed under the Apache License, Version 2.0 (the "License");
// Copyright © 2020 管胡昊. All rights reserved.
* you may not use this file except in compliance with the License.
//
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Guan Huhao 2020-02-05
* $Updated by: Xu Chen (email: hello_master1954@163.com) 2020-05-01
*/
#include "../../XGlobal.h"
#include "../../XGlobal.h"
#include "
F
loat16.h"
#include "
f
loat16.h"
int
float16
::
IsOverlFlow
()
const
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
float16
float16
::
SetOverFlow
()
{
{
return
exp
==
31
;
exp
=
31
;
data
=
0
;
return
*
this
;
}
}
_XINLINE_
float16
float16
::
SetOverFlow
()
int
float16
::
IsOverlFlow
()
const
{
{
exp
=
31
;
return
exp
==
31
;
data
=
0
;
return
*
this
;
}
}
// mask for calculate the highest 1
// mask for calculate the highest 1
unsigned
int
float16
::
mask
[
32
]
=
{
unsigned
int
float16
::
mask
[
32
]
=
{
0xffffffff
,
0xfffffffe
,
0xfffffffc
,
0xfffffff8
,
0xfffffff0
,
0xffffffe0
,
0xffffffc0
,
0xffffff80
,
0xffffffff
,
0xfffffffe
,
0xfffffffc
,
0xfffffff8
,
0xfffffff0
,
0xffffffe0
,
0xffffffc0
,
0xffffff80
,
0xffffff00
,
0xfffffe00
,
0xfffffc00
,
0xfffff800
,
0xfffff000
,
0xffffe000
,
0xffffc000
,
0xffff8000
,
0xffffff00
,
0xfffffe00
,
0xfffffc00
,
0xfffff800
,
0xfffff000
,
0xffffe000
,
0xffffc000
,
0xffff8000
,
0xffff0000
,
0xfffe0000
,
0xfffc0000
,
0xfff80000
,
0xfff00000
,
0xffe00000
,
0xffc00000
,
0xff800000
,
0xffff0000
,
0xfffe0000
,
0xfffc0000
,
0xfff80000
,
0xfff00000
,
0xffe00000
,
0xffc00000
,
0xff800000
,
0xff000000
,
0xfe000000
,
0xfc000000
,
0xf8000000
,
0xf0000000
,
0xe0000000
,
0xc0000000
,
0x80000000
}
0xff000000
,
0xfe000000
,
0xfc000000
,
0xf8000000
,
0xf0000000
,
0xe0000000
,
0xc0000000
,
0x80000000
;
}
;
// to calculate the power of 2
// to calculate the power of 2
unsigned
int
float16
::
pow2
[
32
]
=
{
unsigned
int
float16
::
pow2
[
32
]
=
{
0x00000001
,
0x00000002
,
0x00000004
,
0x00000008
,
0x00000010
,
0x00000020
,
0x00000040
,
0x00000080
,
0x00000001
,
0x00000002
,
0x00000004
,
0x00000008
,
0x00000010
,
0x00000020
,
0x00000040
,
0x00000080
,
0x00000100
,
0x00000200
,
0x00000400
,
0x00000800
,
0x00001000
,
0x00002000
,
0x00004000
,
0x00008000
,
0x00000100
,
0x00000200
,
0x00000400
,
0x00000800
,
0x00001000
,
0x00002000
,
0x00004000
,
0x00008000
,
0x00010000
,
0x00020000
,
0x00040000
,
0x00080000
,
0x00100000
,
0x00200000
,
0x00400000
,
0x00800000
,
0x00010000
,
0x00020000
,
0x00040000
,
0x00080000
,
0x00100000
,
0x00200000
,
0x00400000
,
0x00800000
,
...
@@ -38,23 +56,26 @@ unsigned int float16::pow2[32] = {
...
@@ -38,23 +56,26 @@ unsigned int float16::pow2[32] = {
};
};
// compare the absolute value, if a < b return 1, else return 0
// compare the absolute value, if a < b return 1, else return 0
_XINLINE_
int
float16
::
AbsCompare
(
const
float16
&
a
,
const
float16
&
b
)
int
float16
::
AbsCompare
(
const
float16
&
a
,
const
float16
&
b
)
{
{
if
(
a
.
exp
<
b
.
exp
)
if
(
a
.
exp
<
b
.
exp
)
return
1
;
return
1
;
else
if
(
a
.
exp
>
b
.
exp
)
else
if
(
a
.
exp
>
b
.
exp
)
return
0
;
return
0
;
return
a
.
data
<
b
.
data
;
return
a
.
data
<
b
.
data
;
}
}
// get inverse that a
*inverse(a)==
1
// get inverse that a
* inverse(a) ==
1
_XINLINE_
float16
float16
::
GetInverse
()
const
float16
float16
::
GetInverse
()
const
{
{
float16
ans
;
float16
ans
;
ans
.
sign
=
sign
;
ans
.
sign
=
sign
;
ans
.
exp
=
29
-
exp
;
ans
.
exp
=
29
-
exp
;
int
rec
=
pow2
[
31
];
int
rec
=
pow2
[
31
];
rec
/=
(
this
->
data
|
pow2
[
10
]);
//let it div 0x80000000
//let it div 0x80000000
rec
/=
(
this
->
data
|
pow2
[
10
]);
if
(
!
(
rec
&
pow2
[
21
]))
{
if
(
!
(
rec
&
pow2
[
21
]))
{
rec
<<=
1
;
rec
<<=
1
;
ans
.
exp
++
;
ans
.
exp
++
;
...
@@ -64,20 +85,31 @@ _XINLINE_ float16 float16::GetInverse() const
...
@@ -64,20 +85,31 @@ _XINLINE_ float16 float16::GetInverse() const
return
ans
;
return
ans
;
}
}
// constructor by sign, exp, data
/* constructor by (sign, exp, data), similar to ieee 32 floating point
// sign:1bit exp:5bit data:10bit similar to ieee 32 floating point
>> s - sign: 1bit
_XINLINE_
float16
::
float16
(
const
int
&
s
,
const
int
&
e
,
const
int
&
d
)
{
>> e - exp: 5bit
>> d - data: 10bit
*/
float16
::
float16
(
const
int
&
s
,
const
int
&
e
,
const
int
&
d
)
{
sign
=
s
;
sign
=
s
;
exp
=
e
;
exp
=
e
;
data
=
d
;
data
=
d
;
}
}
// default constructor
/* initializes the 16bit floating point to 0
// This initializes the 16bit floating point to 0.
*/
float16
::
float16
(){
float16
::
float16
()
{
sign
=
0
;
exp
=
0
;
data
=
0
;
}
}
/* constructor by other datatype
We convert the data to float and convert float to float16.
>> data - num
*/
template
<
class
T
>
template
<
class
T
>
float16
::
float16
(
const
T
&
data
)
float16
::
float16
(
const
T
&
data
)
{
{
...
@@ -86,30 +118,37 @@ float16::float16(const T& data)
...
@@ -86,30 +118,37 @@ float16::float16(const T& data)
template
float16
::
float16
(
const
int
&
);
template
float16
::
float16
(
const
int
&
);
template
float16
::
float16
(
const
double
&
);
template
float16
::
float16
(
const
double
&
);
/* build a float16 from a 32-bit float
The conversion itself is delegated to the float assignment operator.
>> data - the 32-bit float to convert
*/
float16::float16(const float& data)
{
    this->operator=(data);
}
void
float16
::
Dump
()
{
printf
(
"sign: %d
\t
exp: %d
\t
data: %d
\n
"
,
sign
,
exp
,
data
);
}
/*
/*
c
hange float16 to flaot as you can see the result is a 32-bit floating point
c
onvert float16 to float and return
construct of 32-bit is
construct of 32-bit is
the 31th bit present the sign
the 31th bit present the sign
the 30th~23th bit present the exp, with 128 offset
the 30th~23th bit present the exp, with 128 offset
rest 23th~0th store the data
rest 23th~0th store the data
*/
*/
float
float16
::
Float
()
{
float
float16
::
Float
()
{
int
ret
=
0
;
int
ret
=
0
;
// cout<<this->IsOverlFlow()<<endl;
ret
=
IsOverlFlow
()
?
0x7f800000
:
ret
=
IsOverlFlow
()
?
0x7f800000
:
(
sign
?
0x80000000
:
0
)
|
((
exp
+
112
)
<<
23
)
|
(
data
<<
13
);
(
sign
?
0x80000000
:
0
)
|
((
exp
+
112
)
<<
23
)
|
(
data
<<
13
);
float
p
=
*
(
float
*
)
&
ret
;
float
p
=
*
(
float
*
)
&
ret
;
return
p
;
return
p
;
}
}
// constructor by a 32-bit floating point
// basic assignment function
_XINLINE_
float16
::
float16
(
const
float
&
data
)
float16
float16
::
operator
=
(
const
float16
&
a
)
{
*
this
=
data
;
}
//float assignment function is the basic function
_XINLINE_
float16
float16
::
operator
=
(
const
float16
&
a
)
{
{
sign
=
a
.
sign
;
sign
=
a
.
sign
;
exp
=
a
.
exp
;
exp
=
a
.
exp
;
...
@@ -117,24 +156,26 @@ _XINLINE_ float16 float16::operator = (const float16& a)
...
@@ -117,24 +156,26 @@ _XINLINE_ float16 float16::operator = (const float16& a)
return
*
this
;
return
*
this
;
}
}
//
float assignment function is the basic function
//
convert float to float16
_XINLINE_
float16
float16
::
operator
=
(
const
float
&
a
)
float16
float16
::
operator
=
(
const
float
&
a
)
{
{
unsigned
int
p
=
*
(
unsigned
int
*
)
&
a
;
unsigned
int
p
=
*
(
unsigned
int
*
)
&
a
;
sign
=
p
&
pow2
[
31
]
?
1
:
0
;
sign
=
p
&
pow2
[
31
]
?
1
:
0
;
if
(
a
>
65535
||
a
<
-
65535
)
return
SetOverFlow
();
if
(
a
>
65535
||
a
<
-
65535
)
return
SetOverFlow
();
exp
=
((
p
>>
23
)
&
(
0xf
))
|
((
p
>>
26
&
0x10
));
exp
=
((
p
>>
23
)
&
(
0xf
))
|
((
p
>>
26
&
0x10
));
data
=
(
p
>>
13
);
data
=
(
p
>>
13
);
return
*
this
;
return
*
this
;
}
}
/*
/* Template assignment function is force change other datetype to float,
template assignment function is force change other datetype to float
then call the float assignment function.
then call the float assignment function
Template assignment function now support int and double.
template assignment function now support int,double
*/
*/
template
<
class
T
>
template
<
class
T
>
_XINLINE_
float16
float16
::
operator
=
(
const
T
&
data
)
{
float16
float16
::
operator
=
(
const
T
&
data
)
{
*
this
=
(
float
)
data
;
*
this
=
(
float
)
data
;
return
*
this
;
return
*
this
;
}
}
...
@@ -142,24 +183,24 @@ template float16 float16:: operator = <int>(const int&);
...
@@ -142,24 +183,24 @@ template float16 float16:: operator = <int>(const int&);
template
float16
float16
::
operator
=
<
double
>
(
const
double
&
);
template
float16
float16
::
operator
=
<
double
>
(
const
double
&
);
/*
/*
template for mult
y-datatype overlao
d
template for mult
i-datatype overloa
d
operator is the overload operator. eg. <,
=
>> operator - the overload operator, e.g. <,
=
return_type is the datetype of thr function's return, like
int, float
>> return_type - the returned datetype of function, e.g,
int, float
expression is the expression of retur
n
>> expression - the returned expressio
n
*/
*/
#define _OVERLOAD_OPRATER_TEMPLATE(
O
peration, returnType, expression) \
#define _OVERLOAD_OPRATER_TEMPLATE(
o
peration, returnType, expression) \
template<class T> \
template<class T> \
_XINLINE_ returnType float16::operator Operation (const T & data)
\
returnType float16::operator operation (const T & data)
\
{ \
{ \
float16 rec=(float)data; \
float16 rec=(float)data; \
return expression; \
return expression; \
} \
} \
template returnType float16::operator
O
peration <int>(const int&); \
template returnType float16::operator
o
peration <int>(const int&); \
template returnType float16::operator
O
peration <float>(const float&); \
template returnType float16::operator
o
peration <float>(const float&); \
template returnType float16::operator
O
peration <double>(const double&);
template returnType float16::operator
o
peration <double>(const double&);
// overload operator (less than)
eg.
a<b
// overload operator (less than) a<b
_XINLINE_
int
float16
::
operator
<
(
const
float16
&
data
)
int
float16
::
operator
<
(
const
float16
&
data
)
{
{
if
(
sign
<
data
.
sign
)
if
(
sign
<
data
.
sign
)
return
1
;
return
1
;
...
@@ -175,8 +216,8 @@ _XINLINE_ int float16::operator < (const float16& data)
...
@@ -175,8 +216,8 @@ _XINLINE_ int float16::operator < (const float16& data)
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
<
,
int
,
*
this
<
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
<
,
int
,
*
this
<
rec
)
// overload opertator <= (less or equal than) a
<=
b
// overload opertator <= (less or equal than) a
<=
b
_XINLINE_
int
float16
::
operator
<=
(
const
float16
&
data
)
int
float16
::
operator
<=
(
const
float16
&
data
)
{
{
if
(
sign
<
data
.
sign
)
if
(
sign
<
data
.
sign
)
return
1
;
return
1
;
...
@@ -192,8 +233,8 @@ _XINLINE_ int float16::operator <= (const float16& data)
...
@@ -192,8 +233,8 @@ _XINLINE_ int float16::operator <= (const float16& data)
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
<=
,
int
,
*
this
<=
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
<=
,
int
,
*
this
<=
rec
)
// overload operator (greater than)
eg. a>
b
// overload operator (greater than)
a >
b
_XINLINE_
int
float16
::
operator
>
(
const
float16
&
data
)
int
float16
::
operator
>
(
const
float16
&
data
)
{
{
if
(
sign
>
data
.
sign
)
if
(
sign
>
data
.
sign
)
return
1
;
return
1
;
...
@@ -209,8 +250,8 @@ _XINLINE_ int float16::operator > (const float16& data)
...
@@ -209,8 +250,8 @@ _XINLINE_ int float16::operator > (const float16& data)
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
>
,
int
,
*
this
>
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
>
,
int
,
*
this
>
rec
)
//
overload opertator >= (greater or equal than) a>=
b
//
overload opertator >= (greater or equal than) a >=
b
_XINLINE_
int
float16
::
operator
>=
(
const
float16
&
data
)
int
float16
::
operator
>=
(
const
float16
&
data
)
{
{
if
(
sign
>
data
.
sign
)
if
(
sign
>
data
.
sign
)
return
1
;
return
1
;
...
@@ -226,8 +267,9 @@ _XINLINE_ int float16::operator >= (const float16& data)
...
@@ -226,8 +267,9 @@ _XINLINE_ int float16::operator >= (const float16& data)
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
>=
,
int
,
*
this
<
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
>=
,
int
,
*
this
<
rec
)
// overide operator +
// overload operator + (add) a + b
_XINLINE_
float16
float16
::
operator
+
(
const
float16
&
data
)
{
float16
float16
::
operator
+
(
const
float16
&
data
)
{
float16
ans
;
float16
ans
;
// avoid overflow inf + anything = inf
// avoid overflow inf + anything = inf
...
@@ -235,6 +277,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
...
@@ -235,6 +277,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
return
*
this
;
return
*
this
;
if
(
data
.
IsOverlFlow
())
if
(
data
.
IsOverlFlow
())
return
data
;
return
data
;
/* the greater number determine the sign and
/* the greater number determine the sign and
the smaller should be >> to aligment to the greater one */
the smaller should be >> to aligment to the greater one */
if
(
AbsCompare
(
*
this
,
data
))
{
if
(
AbsCompare
(
*
this
,
data
))
{
...
@@ -259,12 +302,12 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
...
@@ -259,12 +302,12 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
recp
--
;
recp
--
;
}
}
}
}
//if data==0, exp should be 0
//
if data==0, exp should be 0
else
else
recp
=
0
;
recp
=
0
;
ans
.
data
=
recd
;
ans
.
data
=
recd
;
//if overflow should set overflow
//
if overflow should set overflow
if
(
recp
>=
31
)
if
(
recp
>=
31
)
ans
.
SetOverFlow
();
ans
.
SetOverFlow
();
else
{
else
{
...
@@ -272,13 +315,13 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
...
@@ -272,13 +315,13 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
ans
.
data
=
recd
;
ans
.
data
=
recd
;
}
}
}
}
//
the
same as above. while divided into two part? reduce assignment to increase efficent
// same as above. while divided into two part? reduce assignment to increase efficent
else
{
else
{
ans
.
sign
=
sign
;
ans
.
sign
=
sign
;
int
recp
=
exp
;
int
recp
=
exp
;
int
recd
=
(
this
->
data
|
(
pow2
[
10
]))
+
int
recd
=
(
this
->
data
|
(
pow2
[
10
]))
+
((
sign
^
data
.
sign
)
?
-
1
:
1
)
*
((
sign
^
data
.
sign
)
?
-
1
:
1
)
*
(((
pow2
[
10
])
|
data
.
data
)
>>
(
exp
-
data
.
exp
));
(((
pow2
[
10
])
|
data
.
data
)
>>
(
exp
-
data
.
exp
));
if
(
recd
)
{
if
(
recd
)
{
while
(
mask
[
10
]
&
recd
)
{
while
(
mask
[
10
]
&
recd
)
{
recd
>>=
1
;
recd
>>=
1
;
...
@@ -289,8 +332,11 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
...
@@ -289,8 +332,11 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
recp
--
;
recp
--
;
}
}
}
}
else
recp
=
0
;
else
if
(
recp
>=
31
)
ans
.
SetOverFlow
();
recp
=
0
;
if
(
recp
>=
31
)
ans
.
SetOverFlow
();
else
{
else
{
ans
.
exp
=
recp
;
ans
.
exp
=
recp
;
ans
.
data
=
recd
;
ans
.
data
=
recd
;
...
@@ -301,21 +347,23 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
...
@@ -301,21 +347,23 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
_OVERLOAD_OPRATER_TEMPLATE
(
+
,
float16
,
*
this
=
*
this
+
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
+
,
float16
,
*
this
=
*
this
+
rec
)
//overide operator +=
//overide operator +=
_XINLINE_
float16
float16
::
operator
+=
(
const
float16
&
data
)
{
float16
float16
::
operator
+=
(
const
float16
&
data
)
{
return
*
this
=
*
this
+
data
;
return
*
this
=
*
this
+
data
;
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
+=
,
float16
,
*
this
=
*
this
+
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
+=
,
float16
,
*
this
=
*
this
+
rec
)
//overide operator -(negetive) eg. -a
//overide operator -(negetive) -a
_XINLINE_
float16
float16
::
operator
-
()
{
float16
float16
::
operator
-
()
{
sign
^=
1
;
sign
^=
1
;
float16
rec
=
*
this
;
float16
rec
=
*
this
;
sign
^=
1
;
sign
^=
1
;
return
rec
;
return
rec
;
}
}
//overide operator - (substraction) eg a-b
//overide operator - (substraction) a-b
_XINLINE_
float16
float16
::
operator
-
(
const
float16
&
data
)
{
float16
float16
::
operator
-
(
const
float16
&
data
)
{
float16
ans
;
float16
ans
;
if
(
this
->
IsOverlFlow
())
if
(
this
->
IsOverlFlow
())
return
*
this
;
return
*
this
;
...
@@ -377,49 +425,56 @@ _XINLINE_ float16 float16::operator - (const float16& data) {
...
@@ -377,49 +425,56 @@ _XINLINE_ float16 float16::operator - (const float16& data) {
_OVERLOAD_OPRATER_TEMPLATE
(
-
,
float16
,
*
this
=
*
this
-
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
-
,
float16
,
*
this
=
*
this
-
rec
)
// overide operator -=
// overide operator -=
_XINLINE_
float16
float16
::
operator
-=
(
const
float16
&
data
)
float16
float16
::
operator
-=
(
const
float16
&
data
)
{
{
return
*
this
=
*
this
-
data
;
return
*
this
=
*
this
-
data
;
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
-=
,
float16
,
*
this
=
*
this
-
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
-=
,
float16
,
*
this
=
*
this
-
rec
)
// overload operator * (multiple)
eg a*
b
// overload operator * (multiple)
a *
b
_XINLINE_
float16
float16
::
operator
*
(
const
float16
&
data
)
float16
float16
::
operator
*
(
const
float16
&
data
)
{
{
// if(IsOverlFlow()) return *this;
//if(IsOverlFlow())
// if(data.IsOverlFlow()) return data;
// return *this;
//if(data.IsOverlFlow())
// return data;
float16
ans
;
float16
ans
;
// ^ to get zhe result sign different will be 1(negtive),same will be 0 positive;
// ^ to get zhe result sign different will be 1(negtive), same will be 0 positive;
ans
.
sign
=
sign
^
data
.
sign
;
ans
.
sign
=
sign
^
data
.
sign
;
// mul to get answer
// mul to get answer
int
rec
=
(
data
.
data
|
pow2
[
10
])
*
(
this
->
data
|
pow2
[
10
]);
int
rec
=
(
data
.
data
|
pow2
[
10
])
*
(
this
->
data
|
pow2
[
10
]);
//calculat the new exp
// calculat the new exp
int
recp
=
exp
+
data
.
exp
-
15
>
0
?
exp
+
data
.
exp
-
15
:
0
;
int
recp
=
exp
+
data
.
exp
-
15
>
0
?
exp
+
data
.
exp
-
15
:
0
;
// if carryed, to fix the exp, and data
// if carryed, to fix the exp and data
rec
>>=
10
;
rec
>>=
10
;
while
(
rec
&
mask
[
11
])
{
while
(
rec
&
mask
[
11
])
{
++
recp
;
++
recp
;
rec
>>=
1
;
rec
>>=
1
;
}
}
if
(
recp
>=
31
)
if
(
recp
>=
31
)
ans
.
SetOverFlow
();
ans
.
SetOverFlow
();
else
{
else
{
ans
.
exp
=
recp
;
ans
.
exp
=
recp
;
ans
.
data
=
rec
;
//assign data
ans
.
data
=
rec
;
}
}
return
ans
;
return
ans
;
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
*
,
float16
,
(
*
this
)
*
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
*
,
float16
,
(
*
this
)
*
rec
)
// over
ide operator *=
// over
load operator *= (multiple) a *= b
_XINLINE_
float16
float16
::
operator
*=
(
const
float16
&
data
)
float16
float16
::
operator
*=
(
const
float16
&
data
)
{
{
return
*
this
=
*
this
*
data
;
return
*
this
=
*
this
*
data
;
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
*=
,
float16
,
*
this
=
*
this
*
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
*=
,
float16
,
*
this
=
*
this
*
rec
)
// overload operator / (division)
rg a/
b
// overload operator / (division)
a /
b
_XINLINE_
float16
float16
::
operator
/
(
const
float16
&
data
)
float16
float16
::
operator
/
(
const
float16
&
data
)
{
{
float16
ans
;
float16
ans
;
// ^ to get zhe result sign different will be 1(negtive),same will be 0 positive;
// ^ to get zhe result sign different will be 1(negtive),same will be 0 positive;
...
@@ -445,8 +500,10 @@ _XINLINE_ float16 float16::operator / (const float16& data)
...
@@ -445,8 +500,10 @@ _XINLINE_ float16 float16::operator / (const float16& data)
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
/
,
float16
,
(
*
this
)
/
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
/
,
float16
,
(
*
this
)
/
rec
)
//
overide operator /=
//
overload operator /= (division) a /= b
_XINLINE_
float16
float16
::
operator
/=
(
const
float16
&
data
)
{
float16
float16
::
operator
/=
(
const
float16
&
data
)
{
return
*
this
=
*
this
/
data
;
return
*
this
=
*
this
/
data
;
}
}
_OVERLOAD_OPRATER_TEMPLATE
(
/=
,
float16
,
*
this
=
*
this
/
rec
)
_OVERLOAD_OPRATER_TEMPLATE
(
/=
,
float16
,
*
this
=
*
this
/
rec
)
}
// namespace nts(NiuTrans.Tensor)
source/tensor/core/utilities/Float16.h
查看文件 @
a89ee126
//
/* NiuTrans.Tensor - an open-source tensor library
// float16.h
* Copyright (C) 2020, Natural Language Processing Lab, Northestern University.
// 16bit
* All rights reserved.
//
*
// Created by 管胡昊 on 2020/2/5.
* Licensed under the Apache License, Version 2.0 (the "License");
// Copyright © 2020 管胡昊. All rights reserved.
* you may not use this file except in compliance with the License.
//
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Creted by: Guan Huhao 2020-02-05
* $Updated by: Xu Chen (email: hello_master1954@163.com) 2020-05-01
*/
#ifndef FLOAT16_H
#ifndef FLOAT16_H
#define FLOAT16_H
#define FLOAT16_H
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
struct
float16
struct
float16
{
{
//private member variable
private
:
private
:
/*
/*
sign is the sign bit 1 means negative, 0 means positive
sign is the sign bit 1 means negative, 0 means positive
exp is the exponent with 16 offset
exp is the exponent with 16 offset
data is the data,
similar to ieee-754,
the highest is default 1 and ignored
data is the data,
similar to ieee-754,
the highest is default 1 and ignored
*/
*/
unsigned
short
data
:
10
;
unsigned
short
exp
:
5
;
unsigned
short
sign
:
1
;
// mask for calculate the highest 1
// mask for calculate the highest 1
static
unsigned
int
mask
[
32
];
static
unsigned
int
mask
[
32
];
static
unsigned
int
pow2
[
32
];
static
unsigned
int
pow2
[
32
];
// private function
//int FindHighOne(const int &num, int &l, int &r);
int
FindHighOne
(
const
int
&
num
,
int
&
l
,
int
&
r
);
int
AbsCompare
(
const
float16
&
a
,
const
float16
&
b
);
int
AbsCompare
(
const
float16
&
a
,
const
float16
&
b
);
public
:
public
:
unsigned
short
data
:
10
;
unsigned
short
exp
:
5
;
unsigned
short
sign
:
1
;
float16
SetOverFlow
();
float16
SetOverFlow
();
// judge whether overflow
// judge whether overflow
int
IsOverlFlow
()
const
;
int
IsOverlFlow
()
const
;
/* constructor by sign, exp, data
/* constructor by (sign, exp, data)
sign:1bit exp:5bit data:10bit similar to ieee 32 floating point */
similar to ieee 32 floating point
sign: 1bit
exp: 5bit
data: 10bit */
float16
(
const
int
&
s
,
const
int
&
e
,
const
int
&
d
);
float16
(
const
int
&
s
,
const
int
&
e
,
const
int
&
d
);
/* default constructor
/* default constructor
This initializes the 16bit floating point to 0. */
This initializes the 16bit floating point to 0. */
float16
();
float16
();
// constructor by a 32-bit float
ing point
// constructor by a 32-bit float
num
float16
(
const
float
&
data
);
float16
(
const
float
&
data
);
template
<
class
T
>
float16
(
const
T
&
data
);
// constructor by other datatype
// constructor by other datatype
//template<class T> float16(const T &data);
template
<
class
T
>
float16
(
const
T
&
data
);
void
Dump
();
// c
hange float16 to flaot as you can see the result is a 32-bit floating point
// c
onvert float16 to float and return
float
Float
();
float
Float
();
/* assignment function and tempalte function
/* assignment function and tempalte function
float assignment function is the basic function
Float assignment function is the basic function.
template assignment function is force change other datetype to float
Template assignment function is force change other datetype to float,
then call the float assignment function
then call the float assignment function.
template assignment function now support int, double */
Template assignment function now support int and double. */
float16
operator
=
(
const
float16
&
data
);
float16
operator
=
(
const
float
&
data
);
float16
operator
=
(
const
float
&
data
);
template
<
class
T
>
float16
operator
=
(
const
T
&
data
);
float16
operator
=
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
=
(
const
T
&
data
);
// overload operator (less than)
eg. a<
b
// overload operator (less than)
a <
b
int
operator
<
(
const
float16
&
data
);
int
operator
<
(
const
float16
&
data
);
template
<
class
T
>
int
operator
<
(
const
T
&
data
);
template
<
class
T
>
int
operator
<
(
const
T
&
data
);
// overload opertator <= (less or equal than) a
<=
b
// overload opertator <= (less or equal than) a
<=
b
int
operator
<=
(
const
float16
&
data
);
int
operator
<=
(
const
float16
&
data
);
template
<
class
T
>
int
operator
<=
(
const
T
&
data
);
template
<
class
T
>
int
operator
<=
(
const
T
&
data
);
// overload operator (greater than)
eg. a>
b
// overload operator (greater than)
a >
b
int
operator
>
(
const
float16
&
data
);
int
operator
>
(
const
float16
&
data
);
template
<
class
T
>
int
operator
>
(
const
T
&
data
);
template
<
class
T
>
int
operator
>
(
const
T
&
data
);
//
overload opertator <= (greater or equal than) a>=
b
//
overload opertator >= (greater or equal than) a >=
b
int
operator
>=
(
const
float16
&
data
);
int
operator
>=
(
const
float16
&
data
);
template
<
class
T
>
int
operator
>=
(
const
T
&
data
);
template
<
class
T
>
int
operator
>=
(
const
T
&
data
);
// overload operator + (add)
eg. a+
b
// overload operator + (add)
a +
b
float16
operator
+
(
const
float16
&
data
);
float16
operator
+
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
+
(
const
T
&
data
);
template
<
class
T
>
float16
operator
+
(
const
T
&
data
);
// overload operator += (add)
eg. a+=
b
// overload operator += (add)
a +=
b
float16
operator
+=
(
const
float16
&
data
);
float16
operator
+=
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
+=
(
const
T
&
data
);
template
<
class
T
>
float16
operator
+=
(
const
T
&
data
);
// overload operator -(negetive)
eg.
-a
// overload operator -(negetive) -a
float16
operator
-
();
float16
operator
-
();
// overload operator - (substraction)
eg. a-
b
// overload operator - (substraction)
a -
b
float16
operator
-
(
const
float16
&
data
);
float16
operator
-
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
-
(
const
T
&
data
);
template
<
class
T
>
float16
operator
-
(
const
T
&
data
);
// overload operator -= (substraction)
eg. a-=
b
// overload operator -= (substraction)
a -=
b
float16
operator
-=
(
const
float16
&
data
);
float16
operator
-=
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
-=
(
const
T
&
data
);
template
<
class
T
>
float16
operator
-=
(
const
T
&
data
);
// overload operator * (multiple)
eg. a*
b
// overload operator * (multiple)
a *
b
float16
operator
*
(
const
float16
&
data
);
float16
operator
*
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
*
(
const
T
&
data
);
template
<
class
T
>
float16
operator
*
(
const
T
&
data
);
// overload operator *= (multiple)
eg. a*=
b
// overload operator *= (multiple)
a *=
b
float16
operator
*=
(
const
float16
&
data
);
float16
operator
*=
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
*=
(
const
T
&
data
);
template
<
class
T
>
float16
operator
*=
(
const
T
&
data
);
// overload operator / (division)
eg. a/
b
// overload operator / (division)
a /
b
float16
GetInverse
()
const
;
float16
GetInverse
()
const
;
float16
operator
/
(
const
float16
&
data
);
float16
operator
/
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
/
(
const
T
&
data
);
template
<
class
T
>
float16
operator
/
(
const
T
&
data
);
// overload operator /= (division)
eg. a/=
b
// overload operator /= (division)
a /=
b
float16
operator
/=
(
const
float16
&
data
);
float16
operator
/=
(
const
float16
&
data
);
template
<
class
T
>
float16
operator
/=
(
const
T
&
data
);
template
<
class
T
>
float16
operator
/=
(
const
T
&
data
);
};
};
}
// namespace nts(NiuTrans.Tensor)
#endif
/* FLOAT16_H */
#endif
/* FLOAT16_H */
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论