Commit a89ee126 by xuchen

optimize the float16 CPU implementation

parent 89ad96e6
...@@ -36,11 +36,65 @@ using namespace nts; ...@@ -36,11 +36,65 @@ using namespace nts;
using namespace fnnlm; using namespace fnnlm;
using namespace transformer; using namespace transformer;
/*
smoke test of the CPU float16 type: prints a few converted values, dumps the
raw fields, and round-trips one value through the scratch file "test_fp16"
<< return - 0 on success, 1 if the scratch file cannot be opened, written or read
*/
int MyTest()
{
    /* default-constructed value */
    float16 x;
    printf("%f\n", x.Float());

    x = 3.5;
    printf("%f\n", x.Float());

    x = 0.0F;
    printf("%f\n", x.Float());
    x.Dump();

    x = -3.5;
    printf("%f\n", x.Float());

    /* sizeof yields a size_t; cast it so the argument really matches %d */
    printf("%d\n", (int)sizeof(float16));

    /* the object is stored as raw bytes, so the file must be opened in binary
       mode (text mode may translate bytes on some platforms) */
    FILE* f = fopen("test_fp16", "wb");
    if (f == NULL) {
        fprintf(stderr, "cannot open test_fp16 for writing\n");
        return 1;
    }
    size_t written = fwrite(&x, sizeof(float16), 1, f);
    fclose(f);
    if (written != 1) {
        fprintf(stderr, "cannot write test_fp16\n");
        return 1;
    }

    FILE* f2 = fopen("test_fp16", "rb");
    if (f2 == NULL) {
        fprintf(stderr, "cannot open test_fp16 for reading\n");
        return 1;
    }
    size_t loaded = fread(&x, sizeof(float16), 1, f2);
    fclose(f2);
    if (loaded != 1) {
        fprintf(stderr, "cannot read test_fp16\n");
        return 1;
    }

    /* value after the round trip through the file */
    printf("%f\n", x.Float());

    return 0;
}
/*
manual debugging helper: creates a 2 x 3 X_FLOAT tensor, zeroes it, calls
ScaleAndShift on it, dumps it and then waits for a character on stdin
<< return - always 0
*/
int MyTest2()
{
/* initialize and then clear the global device manager */
GDevs.Init();
GDevs.Clear();
XTensor a;
/* NOTE(review): the last argument (0) is presumably the device id - confirm */
InitTensor2D(&a, 2, 3, X_FLOAT, 0);
a.SetZeroAll();
/* NOTE(review): the result of ScaleAndShift is not used here; confirm
   whether this overload modifies a in place or returns a new tensor */
ScaleAndShift(a, 1);
a.Dump();
printf("dump\n");
/* block until a character arrives on stdin */
getchar();
return 0;
}
int main( int argc, const char ** argv ) int main( int argc, const char ** argv )
{ {
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF); //_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetBreakAlloc(2708); //_CrtSetBreakAlloc(2708);
//MyTest2();
//printf("release\n");
//getchar();
//GDevs.GPUs[0].Reset();
//printf("reset\n");
//getchar();
//printf("bye.\n");
MyTest();
exit(1);
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm")) else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
......
...@@ -55,6 +55,10 @@ const char * GetOPName(int type) ...@@ -55,6 +55,10 @@ const char * GetOPName(int type)
return "M_ROUND"; return "M_ROUND";
else if (type == MATH_RECIPROCAL) else if (type == MATH_RECIPROCAL)
return "M_RECIPROCAL"; return "M_RECIPROCAL";
else if (type == MATH_EQUAL)
return "M_EQUAL";
else if (type == MATH_NOTEQUAL)
return "M_NOTEQUAL";
else if (type == MATH_CLIP) else if (type == MATH_CLIP)
return "M_CLIP"; return "M_CLIP";
else if (type == MATH_DIV) else if (type == MATH_DIV)
...@@ -67,6 +71,10 @@ const char * GetOPName(int type) ...@@ -67,6 +71,10 @@ const char * GetOPName(int type)
return "M_MATRIXMUL"; return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED) else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED"; return "M_MATRIXMULBATCHED";
else if (type == MATH_MAX)
return "M_MAX";
else if (type == MATH_MIN)
return "M_MIN";
else if (type == MATH_MULTIPLY) else if (type == MATH_MULTIPLY)
return "M_MULTIPLY"; return "M_MULTIPLY";
else if (type == MATH_MULTIPLYDIM) else if (type == MATH_MULTIPLYDIM)
......
...@@ -46,7 +46,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -46,7 +46,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ROUND MATH_TAN + 1 #define MATH_ROUND MATH_TAN + 1
#define MATH_RECIPROCAL MATH_ROUND + 1 #define MATH_RECIPROCAL MATH_ROUND + 1
#define MATH_CLIP MATH_RECIPROCAL + 1 #define MATH_EQUAL MATH_RECIPROCAL + 1
#define MATH_NOTEQUAL MATH_EQUAL + 1
#define MATH_CLIP MATH_NOTEQUAL + 1
#define MATH_DIV MATH_CLIP + 1 #define MATH_DIV MATH_CLIP + 1
#define MATH_DIVDIM MATH_DIV + 1 #define MATH_DIVDIM MATH_DIV + 1
#define MATH_MASK MATH_DIVDIM + 1 #define MATH_MASK MATH_DIVDIM + 1
......
...@@ -1784,9 +1784,15 @@ void XTensor::BinaryDump(FILE* file) ...@@ -1784,9 +1784,15 @@ void XTensor::BinaryDump(FILE* file)
switch (dataType) { switch (dataType) {
case X_INT: { case X_INT: {
fwrite(tmp.data, sizeof(int), unitNum, file); fwrite(tmp.data, sizeof(int), unitNum, file);
break;
}
case X_FLOAT16: {
fwrite(tmp.data, sizeof(float16), unitNum, file);
break;
} }
default: { default: {
fwrite(tmp.data, sizeof(float), unitNum, file); fwrite(tmp.data, sizeof(float), unitNum, file);
break;
} }
} }
} }
...@@ -1917,12 +1923,21 @@ void XTensor::BinaryRead(FILE* file, size_t offset) ...@@ -1917,12 +1923,21 @@ void XTensor::BinaryRead(FILE* file, size_t offset)
fread(d, sizeof(int), unitNum, file); fread(d, sizeof(int), unitNum, file);
SetData(d, unitNum); SetData(d, unitNum);
delete[] d; delete[] d;
break;
}
case X_FLOAT16: {
    /* stage the items in a float16 buffer so that the element type matches
       the sizeof(float16) item size passed to fread; the previous int buffer
       was mis-typed and larger than the data read into it */
    float16* d = new float16[unitNum];
    fread(d, sizeof(float16), unitNum, file);
    SetData(d, unitNum);
    delete[] d;
    break;
}
default: { default: {
float* d = new float[unitNum]; float* d = new float[unitNum];
fread(d, sizeof(float), unitNum, file); fread(d, sizeof(float), unitNum, file);
SetData(d, unitNum); SetData(d, unitNum);
delete[] d; delete[] d;
break;
} }
} }
} }
......
...@@ -51,6 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size) ...@@ -51,6 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size)
template __global__ void KernelSetDataFixed<int>(int *, int, int); template __global__ void KernelSetDataFixed<int>(int *, int, int);
template __global__ void KernelSetDataFixed<float>(float *, float, int); template __global__ void KernelSetDataFixed<float>(float *, float, int);
template __global__ void KernelSetDataFixed<double>(double *, double, int); template __global__ void KernelSetDataFixed<double>(double *, double, int);
template __global__ void KernelSetDataFixed<__half>(__half*, __half, int);
/* /*
generate data items with a fixed value generate data items with a fixed value
...@@ -79,6 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value) ...@@ -79,6 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value)
KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum); KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum);
else if (tensor->dataType == X_DOUBLE) else if (tensor->dataType == X_DOUBLE)
KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum); KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum);
else if (tensor->dataType == X_FLOAT16)
KernelSetDataFixed << <blocks, threads >> > ((__half*)tensor->data, (__half)value, tensor->unitNum);
else else
ShowNTErrors("TODO! Unsupported datatype!") ShowNTErrors("TODO! Unsupported datatype!")
......
...@@ -92,6 +92,10 @@ XTensor funcName(const XTensor &a, DTYPE number) ...@@ -92,6 +92,10 @@ XTensor funcName(const XTensor &a, DTYPE number)
XTensor b(&a); \ XTensor b(&a); \
b.SetTMPFlag(); \ b.SetTMPFlag(); \
_funcName(&a, &b, number); \ _funcName(&a, &b, number); \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, (DTYPE)number); \
} \
return b; \ return b; \
} }
...@@ -102,6 +106,10 @@ void funcName(const XTensor &a, XTensor &b, DTYPE number) ...@@ -102,6 +106,10 @@ void funcName(const XTensor &a, XTensor &b, DTYPE number)
InitTensorV2(&b, &a); \ InitTensorV2(&b, &a); \
} \ } \
_funcName(&a, &b, number); \ _funcName(&a, &b, number); \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, (DTYPE)number); \
} \
} }
// I think we needn't to make link. // I think we needn't to make link.
...@@ -186,6 +194,9 @@ XTensor funcName(const XTensor & a, const XTensor & b) ...@@ -186,6 +194,9 @@ XTensor funcName(const XTensor & a, const XTensor & b)
XTensor c(&a); \ XTensor c(&a); \
c.SetTMPFlag(); \ c.SetTMPFlag(); \
_funcName(&a, &b, &c); \ _funcName(&a, &b, &c); \
if (a.enableGrad && b.enableGrad) { \
XLink::MakeLink(&a, &b, &c, operationId); \
} \
return c; \ return c; \
} }
...@@ -196,16 +207,33 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c) ...@@ -196,16 +207,33 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c)
InitTensor(&c, &a); \ InitTensor(&c, &a); \
} \ } \
_funcName(&a, &b, &c); \ _funcName(&a, &b, &c); \
if (a.enableGrad && b.enableGrad) { \
XLink::MakeLink(&a, &b, &c, operationId); \
} \
} }
#ifdef USE_CUDA #ifdef USE_CUDA
_SIMPLE_MAX_MIN_FUNCTION(_Equal, _CudaEqual, myIsEqual)
_SIMPLE_MAX_MIN_FUNCTION(_NotEqual, _CudaNotEqual, myIsNotEqual)
_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, MAX) _SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, MAX)
_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, MIN) _SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, MIN)
#else #else
_SIMPLE_MAX_MIN_FUNCTION(_Equal, myIsEqual)
_SIMPLE_MAX_MIN_FUNCTION(_NotEqual, myIsNotEqual)
_SIMPLE_MAX_MIN_FUNCTION(_Max, MAX) _SIMPLE_MAX_MIN_FUNCTION(_Max, MAX)
_SIMPLE_MAX_MIN_FUNCTION(_Min, MIN) _SIMPLE_MAX_MIN_FUNCTION(_Min, MIN)
#endif #endif
_SIMPLE_MAX_MIN_FUNCTION_ME(_EqualMe, _Equal)
SIMPLE_MAX_MIN_FUNCTION_ME(EqualMe, _Equal)
SIMPLE_MAX_MIN_FUNCTION(Equal, _Equal, MATH_EQUAL)
SIMPLE_MAX_MIN_FUNCTION_VOID(Equal, _Equal, MATH_EQUAL)
_SIMPLE_MAX_MIN_FUNCTION_ME(_NotEqualMe, _NotEqual)
SIMPLE_MAX_MIN_FUNCTION_ME(NotEqualMe, _NotEqual)
SIMPLE_MAX_MIN_FUNCTION(NotEqual, _NotEqual, MATH_NOTEQUAL)
SIMPLE_MAX_MIN_FUNCTION_VOID(NotEqual, _NotEqual, MATH_NOTEQUAL)
_SIMPLE_MAX_MIN_FUNCTION_ME(_MaxMe, _Max) _SIMPLE_MAX_MIN_FUNCTION_ME(_MaxMe, _Max)
SIMPLE_MAX_MIN_FUNCTION_ME(MaxMe, _Max) SIMPLE_MAX_MIN_FUNCTION_ME(MaxMe, _Max)
SIMPLE_MAX_MIN_FUNCTION(Max, _Max, MATH_MAX) SIMPLE_MAX_MIN_FUNCTION(Max, _Max, MATH_MAX)
......
...@@ -134,6 +134,9 @@ void _Cuda##funcName(const XTensor * a, const XTensor * b, XTensor * c) \ ...@@ -134,6 +134,9 @@ void _Cuda##funcName(const XTensor * a, const XTensor * b, XTensor * c) \
BacktoCudaDev(a->devID, devIDBackup); \ BacktoCudaDev(a->devID, devIDBackup); \
} }
SIMPLE_MAX_MIN_FUNCTION_GPU(Equal, cudaIsEqual)
SIMPLE_MAX_MIN_FUNCTION_GPU(NotEqual, cudaIsNotEqual)
SIMPLE_MAX_MIN_FUNCTION_GPU(Max, max) SIMPLE_MAX_MIN_FUNCTION_GPU(Max, max)
SIMPLE_MAX_MIN_FUNCTION_GPU(Min, min) SIMPLE_MAX_MIN_FUNCTION_GPU(Min, min)
......
...@@ -31,9 +31,15 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -31,9 +31,15 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* check whether every entry is equal to the given value (cuda version) */ /* check whether every entry is equal to the given value (cuda version) */
void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value); void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value);
/* element-wise equality test between two tensors a and b (cuda version)
   NOTE(review): the result is presumably written to c - confirm against the kernel */
void _CudaEqual(const XTensor * a, const XTensor * b, XTensor * c);
/* check whether every entry is not equal to the given value (cuda version) */ /* check whether every entry is not equal to the given value (cuda version) */
void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value); void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value);
/* element-wise inequality test between two tensors a and b (cuda version)
   NOTE(review): the result is presumably written to c - confirm against the kernel */
void _CudaNotEqual(const XTensor * a, const XTensor * b, XTensor * c);
/* return maximum of two tensor for each items (cuda version) */ /* return maximum of two tensor for each items (cuda version) */
void _CudaMax(const XTensor * a, const XTensor * b, XTensor *c); void _CudaMax(const XTensor * a, const XTensor * b, XTensor *c);
......
...@@ -39,7 +39,23 @@ void EqualMe(XTensor & a, DTYPE value); ...@@ -39,7 +39,23 @@ void EqualMe(XTensor & a, DTYPE value);
XTensor Equal(const XTensor & a, DTYPE value); XTensor Equal(const XTensor & a, DTYPE value);
/* check whether every entry is equal to the given value */ /* check whether every entry is equal to the given value */
void Equal(const XTensor & a, XTensor & b, DTYPE value); void Equal(const XTensor & a, XTensor & b, XTensor & c);
/* element-wise equality test between two tensors a and b
   NOTE(review): the result is presumably written to c - confirm */
void _Equal(const XTensor * a, const XTensor * b, XTensor * c);
/* element-wise equality test between two tensors (do it on site)
   NOTE(review): the result presumably overwrites a - confirm */
void _EqualMe(XTensor * a, XTensor * b);
/* element-wise equality test between two tensors (do it on site)
   NOTE(review): the result presumably overwrites a - confirm */
void EqualMe(XTensor & a, XTensor & b);
/* element-wise equality test between two tensors (return an XTensor structure) */
XTensor Equal(const XTensor & a, const XTensor & b);
/* element-wise equality test between two tensors a and b
   NOTE(review): the result is presumably written to c - confirm */
void Equal(const XTensor & a, const XTensor & b, XTensor & c);
/* check whether every entry is not equal to the given value */ /* check whether every entry is not equal to the given value */
void _NotEqual(const XTensor * a, XTensor * b, DTYPE value); void _NotEqual(const XTensor * a, XTensor * b, DTYPE value);
...@@ -56,6 +72,22 @@ XTensor NotEqual(const XTensor & a, DTYPE value); ...@@ -56,6 +72,22 @@ XTensor NotEqual(const XTensor & a, DTYPE value);
/* check whether every entry is not equal to the given value */ /* check whether every entry is not equal to the given value */
void NotEqual(const XTensor & a, XTensor & b, DTYPE value); void NotEqual(const XTensor & a, XTensor & b, DTYPE value);
/* element-wise inequality test between two tensors a and b
   NOTE(review): the result is presumably written to c - confirm */
void _NotEqual(const XTensor * a, const XTensor * b, XTensor * c);
/* element-wise inequality test between two tensors (do it on site)
   NOTE(review): the result presumably overwrites a - confirm */
void _NotEqualMe(XTensor * a, XTensor * b);
/* element-wise inequality test between two tensors (do it on site);
   b is taken by reference so that the declaration has the same shape as
   EqualMe(XTensor &, XTensor &), which is generated by the same macro */
void NotEqualMe(XTensor & a, XTensor & b);
/* element-wise inequality test between two tensors (return an XTensor structure) */
XTensor NotEqual(const XTensor & a, const XTensor & b);
/* element-wise inequality test between two tensors a and b
   NOTE(review): the result is presumably written to c - confirm */
void NotEqual(const XTensor & a, const XTensor & b, XTensor & c);
/* return maximum of two tensor for each items */ /* return maximum of two tensor for each items */
void _Max(const XTensor * a, const XTensor * b, XTensor * c); void _Max(const XTensor * a, const XTensor * b, XTensor * c);
...@@ -71,6 +103,7 @@ XTensor Max(const XTensor & a, const XTensor & b); ...@@ -71,6 +103,7 @@ XTensor Max(const XTensor & a, const XTensor & b);
/* return maximum of two tensor for each items */ /* return maximum of two tensor for each items */
void Max(const XTensor & a, const XTensor & b, XTensor & c); void Max(const XTensor & a, const XTensor & b, XTensor & c);
/* return minimum of two tensor for each items */ /* return minimum of two tensor for each items */
void _Min(const XTensor * a, const XTensor * b, XTensor * c); void _Min(const XTensor * a, const XTensor * b, XTensor * c);
......
// /* NiuTrans.Tensor - an open-source tensor library
// float16.cpp * Copyright (C) 2020, Natural Language Processing Lab, Northestern University.
// 16bit * All rights reserved.
// *
// Created by 管胡昊 on 2020/2/5. * Licensed under the Apache License, Version 2.0 (the "License");
// Copyright © 2020 管胡昊. All rights reserved. * you may not use this file except in compliance with the License.
// * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Creted by: Guan Huhao 2020-02-05
* $Updated by: Xu Chen (email: hello_master1954@163.com) 2020-05-01
*/
#include "../../XGlobal.h" #include "../../XGlobal.h"
#include "Float16.h" #include "float16.h"
int float16::IsOverlFlow() const namespace nts { // namespace nts(NiuTrans.Tensor)
float16 float16::SetOverFlow()
{ {
return exp==31; exp = 31;
data = 0;
return *this;
} }
_XINLINE_ float16 float16::SetOverFlow() int float16::IsOverlFlow() const
{ {
exp=31; return exp==31;
data=0;
return *this;
} }
// mask for calculate the highest 1 // mask for calculate the highest 1
unsigned int float16::mask[32] = { unsigned int float16::mask[32] =
{
0xffffffff,0xfffffffe,0xfffffffc,0xfffffff8,0xfffffff0,0xffffffe0,0xffffffc0,0xffffff80, 0xffffffff,0xfffffffe,0xfffffffc,0xfffffff8,0xfffffff0,0xffffffe0,0xffffffc0,0xffffff80,
0xffffff00,0xfffffe00,0xfffffc00,0xfffff800,0xfffff000,0xffffe000,0xffffc000,0xffff8000, 0xffffff00,0xfffffe00,0xfffffc00,0xfffff800,0xfffff000,0xffffe000,0xffffc000,0xffff8000,
0xffff0000,0xfffe0000,0xfffc0000,0xfff80000,0xfff00000,0xffe00000,0xffc00000,0xff800000, 0xffff0000,0xfffe0000,0xfffc0000,0xfff80000,0xfff00000,0xffe00000,0xffc00000,0xff800000,
0xff000000,0xfe000000,0xfc000000,0xf8000000,0xf0000000,0xe0000000,0xc0000000,0x80000000} 0xff000000,0xfe000000,0xfc000000,0xf8000000,0xf0000000,0xe0000000,0xc0000000,0x80000000
; };
// to calculate the power of 2 // to calculate the power of 2
unsigned int float16::pow2[32] = { unsigned int float16::pow2[32] =
{
0x00000001,0x00000002,0x00000004,0x00000008,0x00000010,0x00000020,0x00000040,0x00000080, 0x00000001,0x00000002,0x00000004,0x00000008,0x00000010,0x00000020,0x00000040,0x00000080,
0x00000100,0x00000200,0x00000400,0x00000800,0x00001000,0x00002000,0x00004000,0x00008000, 0x00000100,0x00000200,0x00000400,0x00000800,0x00001000,0x00002000,0x00004000,0x00008000,
0x00010000,0x00020000,0x00040000,0x00080000,0x00100000,0x00200000,0x00400000,0x00800000, 0x00010000,0x00020000,0x00040000,0x00080000,0x00100000,0x00200000,0x00400000,0x00800000,
...@@ -38,23 +56,26 @@ unsigned int float16::pow2[32] = { ...@@ -38,23 +56,26 @@ unsigned int float16::pow2[32] = {
}; };
// compare the absolute value, if a < b return 1, else return 0 // compare the absolute value, if a < b return 1, else return 0
_XINLINE_ int float16::AbsCompare(const float16 & a, const float16 & b) int float16::AbsCompare(const float16 & a, const float16 & b)
{ {
if (a.exp < b.exp) if (a.exp < b.exp)
return 1; return 1;
else if (a.exp > b.exp) else if (a.exp > b.exp)
return 0; return 0;
return a.data < b.data; return a.data < b.data;
} }
// get inverse that a*inverse(a)==1 // get inverse that a * inverse(a) == 1
_XINLINE_ float16 float16::GetInverse() const float16 float16::GetInverse() const
{ {
float16 ans; float16 ans;
ans.sign = sign; ans.sign = sign;
ans.exp = 29 - exp; ans.exp = 29 - exp;
int rec = pow2[31]; int rec = pow2[31];
rec /= (this->data | pow2[10]); //let it div 0x80000000 //let it div 0x80000000
rec /= (this->data | pow2[10]);
if (!(rec & pow2[21])) { if (!(rec & pow2[21])) {
rec <<= 1; rec <<= 1;
ans.exp++; ans.exp++;
...@@ -64,20 +85,31 @@ _XINLINE_ float16 float16::GetInverse() const ...@@ -64,20 +85,31 @@ _XINLINE_ float16 float16::GetInverse() const
return ans; return ans;
} }
// constructor by sign, exp, data /* constructor by (sign, exp, data), similar to ieee 32 floating point
// sign:1bit exp:5bit data:10bit similar to ieee 32 floating point >> s - sign: 1bit
_XINLINE_ float16::float16(const int& s, const int& e, const int& d) { >> e - exp: 5bit
>> d - data: 10bit
*/
float16::float16(const int& s, const int& e, const int& d)
{
sign = s; sign = s;
exp = e; exp = e;
data = d; data = d;
} }
// default constructor /* initializes the 16bit floating point to 0
// This initializes the 16bit floating point to 0. */
float16::float16(){ float16::float16()
{
sign = 0;
exp = 0;
data = 0;
} }
/* constructor by other datatype
We convert the data to float and convert float to float16.
>> data - num
*/
template<class T> template<class T>
float16::float16(const T& data) float16::float16(const T& data)
{ {
...@@ -86,30 +118,37 @@ float16::float16(const T& data) ...@@ -86,30 +118,37 @@ float16::float16(const T& data)
template float16::float16 (const int &); template float16::float16 (const int &);
template float16::float16 (const double &); template float16::float16 (const double &);
/* build a float16 from a 32-bit float; the conversion itself is done by
   the float assignment operator
>> data - the 32-bit float to convert
*/
float16::float16(const float& data)
{
    this->operator=(data);
}
/* print the three raw fields (sign, exp, data) of this number to stdout */
void float16::Dump()
{
    /* copy the bit-fields into plain ints before handing them to printf */
    const int signBit = sign;
    const int expBits = exp;
    const int dataBits = data;
    printf("sign: %d\texp: %d\tdata: %d\n", signBit, expBits, dataBits);
}
/* /*
change float16 to flaot as you can see the result is a 32-bit floating point convert float16 to float and return
construct of 32-bit is construct of 32-bit is
the 31th bit present the sign the 31th bit present the sign
the 30th~23th bit present the exp, with 128 offset the 30th~23th bit present the exp, with 128 offset
rest 23th~0th store the data rest 23th~0th store the data
*/ */
float float16::Float() { float float16::Float()
{
int ret = 0; int ret = 0;
// cout<<this->IsOverlFlow()<<endl;
ret = IsOverlFlow() ? 0x7f800000 : ret = IsOverlFlow() ? 0x7f800000 :
(sign ? 0x80000000 : 0) | ((exp + 112) << 23) | (data << 13); (sign ? 0x80000000 : 0) | ((exp + 112) << 23) | (data << 13);
float p = *(float*)&ret; float p = *(float*)&ret;
return p; return p;
} }
// constructor by a 32-bit floating point // basic assignment function
_XINLINE_ float16::float16(const float& data) float16 float16::operator = (const float16& a)
{
*this = data;
}
//float assignment function is the basic function
_XINLINE_ float16 float16::operator = (const float16& a)
{ {
sign = a.sign; sign = a.sign;
exp = a.exp; exp = a.exp;
...@@ -117,24 +156,26 @@ _XINLINE_ float16 float16::operator = (const float16& a) ...@@ -117,24 +156,26 @@ _XINLINE_ float16 float16::operator = (const float16& a)
return *this; return *this;
} }
//float assignment function is the basic function // convert float to float16
_XINLINE_ float16 float16::operator = (const float& a) float16 float16::operator = (const float& a)
{ {
unsigned int p = *(unsigned int*)&a; unsigned int p = *(unsigned int*)&a;
sign = p & pow2[31] ? 1 : 0; sign = p & pow2[31] ? 1 : 0;
if (a > 65535 || a < -65535) return SetOverFlow();
if (a > 65535 || a < -65535)
return SetOverFlow();
exp = ((p >> 23)& (0xf)) | ((p >> 26 & 0x10)); exp = ((p >> 23)& (0xf)) | ((p >> 26 & 0x10));
data = (p >> 13); data = (p >> 13);
return *this; return *this;
} }
/* /* Template assignment function is force change other datetype to float,
template assignment function is force change other datetype to float then call the float assignment function.
then call the float assignment function Template assignment function now support int and double.
template assignment function now support int,double
*/ */
template <class T> template <class T>
_XINLINE_ float16 float16::operator = (const T& data) { float16 float16::operator = (const T& data)
{
*this = (float)data; *this = (float)data;
return *this; return *this;
} }
...@@ -142,24 +183,24 @@ template float16 float16:: operator = <int>(const int&); ...@@ -142,24 +183,24 @@ template float16 float16:: operator = <int>(const int&);
template float16 float16:: operator = <double>(const double&); template float16 float16:: operator = <double>(const double&);
/* /*
template for multy-datatype overlaod template for multi-datatype overload
operator is the overload operator. eg. <,= >> operator - the overload operator, e.g. <, =
return_type is the datetype of thr function's return, like int, float >> return_type - the returned datetype of function, e.g, int, float
expression is the expression of return >> expression - the returned expression
*/ */
#define _OVERLOAD_OPRATER_TEMPLATE(Operation, returnType, expression) \ #define _OVERLOAD_OPRATER_TEMPLATE(operation, returnType, expression) \
template<class T> \ template<class T> \
_XINLINE_ returnType float16::operator Operation (const T & data) \ returnType float16::operator operation (const T & data) \
{ \ { \
float16 rec=(float)data; \ float16 rec=(float)data; \
return expression; \ return expression; \
} \ } \
template returnType float16::operator Operation <int>(const int&); \ template returnType float16::operator operation <int>(const int&); \
template returnType float16::operator Operation <float>(const float&); \ template returnType float16::operator operation <float>(const float&); \
template returnType float16::operator Operation <double>(const double&); template returnType float16::operator operation <double>(const double&);
// overload operator (less than) eg. a<b // overload operator (less than) a<b
_XINLINE_ int float16::operator < (const float16& data) int float16::operator < (const float16& data)
{ {
if (sign < data.sign) if (sign < data.sign)
return 1; return 1;
...@@ -175,8 +216,8 @@ _XINLINE_ int float16::operator < (const float16& data) ...@@ -175,8 +216,8 @@ _XINLINE_ int float16::operator < (const float16& data)
} }
_OVERLOAD_OPRATER_TEMPLATE(< , int, *this < rec) _OVERLOAD_OPRATER_TEMPLATE(< , int, *this < rec)
// overload opertator <= (less or equal than) a<=b // overload opertator <= (less or equal than) a <= b
_XINLINE_ int float16::operator <= (const float16& data) int float16::operator <= (const float16& data)
{ {
if (sign < data.sign) if (sign < data.sign)
return 1; return 1;
...@@ -192,8 +233,8 @@ _XINLINE_ int float16::operator <= (const float16& data) ...@@ -192,8 +233,8 @@ _XINLINE_ int float16::operator <= (const float16& data)
} }
_OVERLOAD_OPRATER_TEMPLATE(<= , int, *this <= rec) _OVERLOAD_OPRATER_TEMPLATE(<= , int, *this <= rec)
// overload operator (greater than) eg. a>b // overload operator (greater than) a > b
_XINLINE_ int float16::operator > (const float16& data) int float16::operator > (const float16& data)
{ {
if (sign > data.sign) if (sign > data.sign)
return 1; return 1;
...@@ -209,8 +250,8 @@ _XINLINE_ int float16::operator > (const float16& data) ...@@ -209,8 +250,8 @@ _XINLINE_ int float16::operator > (const float16& data)
} }
_OVERLOAD_OPRATER_TEMPLATE(> , int, * this > rec) _OVERLOAD_OPRATER_TEMPLATE(> , int, * this > rec)
//overload opertator >= (greater or equal than) a>=b // overload opertator >= (greater or equal than) a >= b
_XINLINE_ int float16::operator >= (const float16& data) int float16::operator >= (const float16& data)
{ {
if (sign > data.sign) if (sign > data.sign)
return 1; return 1;
...@@ -226,8 +267,9 @@ _XINLINE_ int float16::operator >= (const float16& data) ...@@ -226,8 +267,9 @@ _XINLINE_ int float16::operator >= (const float16& data)
} }
// scalar overloads of >= (int, float, double): the operand is converted to
// float16 and compared with the float16 operator >= (the expression used to be
// "*this < rec", which returned the opposite answer)
_OVERLOAD_OPRATER_TEMPLATE(>= , int, *this >= rec)
// overide operator + // overload operator + (add) a + b
_XINLINE_ float16 float16::operator + (const float16& data) { float16 float16::operator + (const float16& data)
{
float16 ans; float16 ans;
// avoid overflow inf + anything = inf // avoid overflow inf + anything = inf
...@@ -235,6 +277,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) { ...@@ -235,6 +277,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
return *this; return *this;
if (data.IsOverlFlow()) if (data.IsOverlFlow())
return data; return data;
/* the greater number determine the sign and /* the greater number determine the sign and
the smaller should be >> to aligment to the greater one */ the smaller should be >> to aligment to the greater one */
if (AbsCompare(*this, data)) { if (AbsCompare(*this, data)) {
...@@ -259,12 +302,12 @@ _XINLINE_ float16 float16::operator + (const float16& data) { ...@@ -259,12 +302,12 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
recp--; recp--;
} }
} }
//if data==0, exp should be 0 // if data==0, exp should be 0
else else
recp = 0; recp = 0;
ans.data = recd; ans.data = recd;
//if overflow should set overflow // if overflow should set overflow
if (recp >= 31) if (recp >= 31)
ans.SetOverFlow(); ans.SetOverFlow();
else { else {
...@@ -272,7 +315,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) { ...@@ -272,7 +315,7 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
ans.data = recd; ans.data = recd;
} }
} }
//the same as above. while divided into two part? reduce assignment to increase efficent // same as above. while divided into two part? reduce assignment to increase efficent
else { else {
ans.sign = sign; ans.sign = sign;
int recp = exp; int recp = exp;
...@@ -289,8 +332,11 @@ _XINLINE_ float16 float16::operator + (const float16& data) { ...@@ -289,8 +332,11 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
recp--; recp--;
} }
} }
else recp = 0; else
if (recp >= 31) ans.SetOverFlow(); recp = 0;
if (recp >= 31)
ans.SetOverFlow();
else { else {
ans.exp = recp; ans.exp = recp;
ans.data = recd; ans.data = recd;
...@@ -301,21 +347,23 @@ _XINLINE_ float16 float16::operator + (const float16& data) { ...@@ -301,21 +347,23 @@ _XINLINE_ float16 float16::operator + (const float16& data) {
// scalar overloads of + (int, float, double): return the sum without
// assigning it back to *this, matching the scalar * and / overloads; the
// in-place form is operator +=
_OVERLOAD_OPRATER_TEMPLATE(+, float16, (*this) + rec)
//overide operator += //overide operator +=
_XINLINE_ float16 float16::operator+=(const float16& data) { float16 float16::operator+=(const float16& data) {
return *this = *this + data; return *this = *this + data;
} }
_OVERLOAD_OPRATER_TEMPLATE(+=, float16, *this = *this + rec) _OVERLOAD_OPRATER_TEMPLATE(+=, float16, *this = *this + rec)
//overide operator -(negetive) eg. -a //overide operator -(negetive) -a
_XINLINE_ float16 float16::operator - () { float16 float16::operator - ()
{
sign ^= 1; sign ^= 1;
float16 rec = *this; float16 rec = *this;
sign ^= 1; sign ^= 1;
return rec; return rec;
} }
//overide operator - (substraction) eg a-b //overide operator - (substraction) a-b
_XINLINE_ float16 float16::operator - (const float16& data) { float16 float16::operator - (const float16& data)
{
float16 ans; float16 ans;
if (this->IsOverlFlow()) if (this->IsOverlFlow())
return *this; return *this;
...@@ -377,49 +425,56 @@ _XINLINE_ float16 float16::operator - (const float16& data) { ...@@ -377,49 +425,56 @@ _XINLINE_ float16 float16::operator - (const float16& data) {
// scalar overloads of - (int, float, double): return the difference without
// assigning it back to *this, matching the scalar * and / overloads; the
// in-place form is operator -=
_OVERLOAD_OPRATER_TEMPLATE(-, float16, (*this) - rec)
// overide operator -= // overide operator -=
_XINLINE_ float16 float16::operator-=(const float16& data) float16 float16::operator-=(const float16& data)
{ {
return *this = *this - data; return *this = *this - data;
} }
_OVERLOAD_OPRATER_TEMPLATE(-=, float16, *this = *this - rec) _OVERLOAD_OPRATER_TEMPLATE(-=, float16, *this = *this - rec)
// overload operator * (multiple) eg a*b // overload operator * (multiple) a * b
_XINLINE_ float16 float16::operator * (const float16& data) float16 float16::operator * (const float16& data)
{ {
// if(IsOverlFlow()) return *this; //if(IsOverlFlow())
// if(data.IsOverlFlow()) return data; // return *this;
//if(data.IsOverlFlow())
// return data;
float16 ans; float16 ans;
// ^ to get zhe result sign different will be 1(negtive),same will be 0 positive; // ^ to get zhe result sign different will be 1(negtive), same will be 0 positive;
ans.sign = sign ^ data.sign; ans.sign = sign ^ data.sign;
// mul to get answer // mul to get answer
int rec = (data.data | pow2[10]) * (this->data | pow2[10]); int rec = (data.data | pow2[10]) * (this->data | pow2[10]);
//calculat the new exp
// calculat the new exp
int recp = exp + data.exp - 15 > 0 ? exp + data.exp - 15 : 0; int recp = exp + data.exp - 15 > 0 ? exp + data.exp - 15 : 0;
// if carryed, to fix the exp, and data
// if carryed, to fix the exp and data
rec >>= 10; rec >>= 10;
while (rec & mask[11]) { while (rec & mask[11]) {
++recp; ++recp;
rec >>= 1; rec >>= 1;
} }
if (recp >= 31) if (recp >= 31)
ans.SetOverFlow(); ans.SetOverFlow();
else { else {
ans.exp = recp; ans.exp = recp;
ans.data = rec;//assign data ans.data = rec;
} }
return ans; return ans;
} }
_OVERLOAD_OPRATER_TEMPLATE(*, float16, (*this)* rec) _OVERLOAD_OPRATER_TEMPLATE(*, float16, (*this)* rec)
// overide operator *= // overload operator *= (multiple) a *= b
_XINLINE_ float16 float16::operator*=(const float16& data) float16 float16::operator *= (const float16& data)
{ {
return *this = *this * data; return *this = *this * data;
} }
_OVERLOAD_OPRATER_TEMPLATE(*=, float16, *this = *this * rec) _OVERLOAD_OPRATER_TEMPLATE(*=, float16, *this = *this * rec)
// overload operator / (division) rg a/b // overload operator / (division) a / b
_XINLINE_ float16 float16::operator / (const float16& data) float16 float16::operator / (const float16& data)
{ {
float16 ans; float16 ans;
// ^ to get zhe result sign different will be 1(negtive),same will be 0 positive; // ^ to get zhe result sign different will be 1(negtive),same will be 0 positive;
...@@ -445,8 +500,10 @@ _XINLINE_ float16 float16::operator / (const float16& data) ...@@ -445,8 +500,10 @@ _XINLINE_ float16 float16::operator / (const float16& data)
} }
_OVERLOAD_OPRATER_TEMPLATE(/ , float16, (*this) / rec) _OVERLOAD_OPRATER_TEMPLATE(/ , float16, (*this) / rec)
//overide operator /= // overload operator /= (division) a /= b
_XINLINE_ float16 float16::operator/=(const float16& data) { float16 float16::operator /= (const float16& data) {
return *this = *this / data; return *this = *this / data;
} }
_OVERLOAD_OPRATER_TEMPLATE(/=, float16, *this = *this / rec) _OVERLOAD_OPRATER_TEMPLATE(/=, float16, *this = *this / rec)
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * float16.h - 16-bit floating point type
 *
 * $Created by: Guan Huhao 2020-02-05
 * $Updated by: Xu Chen (email: hello_master1954@163.com) 2020-05-01
 */
#ifndef FLOAT16_H
#define FLOAT16_H

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
A 16-bit floating point number stored as three bit-fields
(1 sign bit, 5 exponent bits, 10 fraction bits). It can be built from
and converted to float, and overloads the comparison and arithmetic
operators. Only declarations live here; the definitions are elsewhere.
*/
struct float16
{
private:
    /* lookup tables used by the arithmetic operators, e.g. operator * ORs
       pow2[10] into the fraction and tests the product against mask[11].
       NOTE(review): the table contents are defined outside this header. */
    static unsigned int mask[32];
    static unsigned int pow2[32];

    //int FindHighOne(const int &num, int &l, int &r);

    /* NOTE(review): presumably compares a and b by absolute value - confirm */
    int AbsCompare(const float16 & a,const float16 & b);

public:
    /*
    data is the fraction, similar to ieee-754: the leading 1 is implicit and not stored
    exp  is the biased exponent
         (NOTE(review): operator * subtracts 15 when adding two exponents,
          so the bias appears to be 15 rather than 16 - confirm)
    sign is the sign bit: 1 means negative, 0 means positive
    */
    unsigned short data : 10;
    unsigned short exp : 5;
    unsigned short sign : 1;

    /* mark the number as overflowed; operator * calls this when the
       result exponent reaches 31 */
    float16 SetOverFlow();

    // judge whether the number is in the overflow state
    int IsOverlFlow() const;

    /* constructor by (sign, exp, data),
       similar to the ieee 32-bit floating point layout
       sign: 1bit
       exp:  5bit
       data: 10bit */
    float16(const int& s, const int& e, const int& d);

    /* default constructor
       This initializes the 16bit floating point to 0. */
    float16();

    // constructor by a 32-bit float num
    float16(const float& data);

    // constructor by other datatype
    template<class T> float16(const T& data);

    // NOTE(review): presumably prints the internal representation - confirm
    void Dump();

    // convert float16 to float and return
    float Float();

    /* assignment function and template function
       Float assignment function is the basic function.
       Template assignment function force-converts another datatype to float,
       then calls the float assignment function.
       Template assignment function now supports int and double. */
    float16 operator = (const float& data);
    float16 operator = (const float16& data);
    template<class T> float16 operator = (const T& data);

    // overload operator < (less than) a < b
    int operator < (const float16& data);
    template<class T> int operator < (const T& data);

    // overload operator <= (less than or equal) a <= b
    int operator <= (const float16& data);
    template<class T> int operator <= (const T& data);

    // overload operator > (greater than) a > b
    int operator > (const float16& data);
    template<class T> int operator > (const T& data);

    // overload operator >= (greater than or equal) a >= b
    int operator >= (const float16& data);
    template<class T> int operator >= (const T& data);

    // overload operator + (add) a + b
    float16 operator + (const float16& data);
    template<class T> float16 operator + (const T& data);

    // overload operator += (add) a += b
    float16 operator += (const float16& data);
    template<class T> float16 operator += (const T& data);

    // overload operator - (negative) -a
    float16 operator - ();

    // overload operator - (subtraction) a - b
    float16 operator - (const float16& data);
    template<class T> float16 operator - (const T& data);

    // overload operator -= (subtraction) a -= b
    float16 operator -= (const float16& data);
    template<class T> float16 operator -= (const T& data);

    // overload operator * (multiply) a * b
    float16 operator * (const float16& data);
    template<class T> float16 operator * (const T& data);

    // overload operator *= (multiply) a *= b
    float16 operator *= (const float16& data);
    template<class T> float16 operator *= (const T& data);

    // NOTE(review): presumably returns the reciprocal of this number - confirm
    float16 GetInverse() const;

    // overload operator / (division) a / b
    float16 operator / (const float16& data);
    template<class T> float16 operator / (const T& data);

    // overload operator /= (division) a /= b
    float16 operator /= (const float16& data);
    template<class T> float16 operator /= (const T& data);
};

} // namespace nts(NiuTrans.Tensor)

#endif /* FLOAT16_H */
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论