#define __XDATATYPE_H__
#define __XDATATYPE_H__
#include "XGlobal.h"
#include "core/utilities/Float16.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
namespace nts{
extern const char * GetDataTypeName(TENSOR_DATA_TYPE type);
extern TENSOR_DATA_TYPE GetDataType(const char * typeName);
/* data conversion (for lower precision computation) */
unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h);
unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h);
for (int i = 0; i < input->unitNum; i++)
outputData[i] = (float)inputData[i];
else if (input->dataType == X_FLOAT && output->dataType == X_FLOAT16) {
float* inputData = (float*)input->data;
float16* outputData = (float16*)output->data;
for (int i = 0; i < input->unitNum; i++)
outputData[i] = (float16)inputData[i];
else if (input->dataType == X_FLOAT16 && output->dataType == X_FLOAT) {
float16* inputData = (float16*)input->data;
float* outputData = (float*)output->data;
for (int i = 0; i < input->unitNum; i++)
outputData[i] = inputData[i].Float();
ShowNTErrors("Unsupported data types for conversion!");
// float16.cpp
// 16bit
// Created by 管胡昊 on 2020/2/5.
// Copyright © 2020 管胡昊. All rights reserved.
#include "../../XGlobal.h"
#include "Float16.h"
#include <cstring>
int float16::IsOverlFlow() const
return exp==31;
/* mark this value as overflowed
   IsOverlFlow() tests exp == 31, so the marker is written into the exponent
   field; sign and data are left untouched
   >> returns a copy of the updated value */
_XINLINE_ float16 float16::SetOverFlow()
{
    exp = 31;
    return *this;
}
// mask for calculate the highest 1
unsigned int float16::mask[32] = {
// to calculate the power of 2
unsigned int float16::pow2[32] = {
/* compare two values by absolute value (the sign bits are ignored)
   >> a - left operand
   >> b - right operand
   >> returns 1 if |a| < |b|, 0 otherwise */
_XINLINE_ int float16::AbsCompare(const float16 & a, const float16 & b)
{
    /* the exponent dominates; the fraction only breaks ties */
    if (a.exp != b.exp)
        return a.exp < b.exp ? 1 : 0;
    return a.data < b.data ? 1 : 0;
}
// Compute the reciprocal of this value, i.e. a value r intended to satisfy a * r == 1.
// The sign is kept, the exponent becomes 29 - exp, and the fraction is derived from
// pow2[31] divided by the significand (data with pow2[10] OR-ed in).
// NOTE(review): presumably pow2[i] == 1 << i; the table's initializer is not shown here - confirm.
_XINLINE_ float16 float16::GetInverse() const
float16 ans;
ans.sign = sign;
ans.exp = 29 - exp;
// NOTE(review): pow2 holds unsigned int while rec is a signed int - confirm pow2[31] fits as intended
int rec = pow2[31];
rec /= (this->data | pow2[10]); // divide pow2[31] by the significand
// shift left once when the pow2[21] bit of the quotient is clear
if (!(rec & pow2[21])) {
rec <<= 1;
// NOTE(review): the assignment after the shift has no left-hand side as written - restore before building
rec >>= 10; = rec;
return ans;
/* construct directly from the three bit fields
   >> s - sign (stored in 1 bit)
   >> e - exponent (stored in 5 bits)
   >> d - fraction (stored in 10 bits)
   each argument is truncated to the width of its field */
_XINLINE_ float16::float16(const int& s, const int& e, const int& d)
    : data(d), exp(e), sign(s)
{
}
// default constructor
// This initializes the 16bit floating point to 0.
/* converting constructor for other arithmetic types
   the argument is first cast to float and then stored through the float
   assignment operator; int and double are explicitly instantiated below
   >> data - source value */
template<class T>
float16::float16(const T& data)
{
    const float asFloat = (float)data;
    *this = asFloat;
}
template float16::float16 (const int &);
template float16::float16 (const double &);
convert float16 to float; the result is a 32-bit floating point value
the layout of the 32-bit value is
bit 31 holds the sign
bits 30~23 hold the exponent, with an offset (bias) of 127
bits 22~0 hold the fraction (data)
/* convert this value to a 32-bit float
   The binary32 bit pattern is assembled in an unsigned word:
     - overflow marker (exp == 31) -> infinity carrying this value's sign
     - exp == 0 and data == 0      -> zero carrying this value's sign
     - otherwise                   -> sign | (exp + 112) << 23 | data << 13,
       where 112 re-biases the 5-bit exponent (bias 15) to binary32 (bias 127)
   >> returns the converted value */
float float16::Float()
{
    const unsigned int signBit = sign ? 0x80000000u : 0u;
    unsigned int bits = signBit;

    if (IsOverlFlow()) {
        bits |= 0x7f800000u;
    }
    else if (exp != 0 || data != 0) {
        bits |= static_cast<unsigned int>(exp + 112) << 23;
        bits |= static_cast<unsigned int>(data) << 13;
    }

    /* copy the bits instead of casting the pointer: reading an integer
       object through a float lvalue violates the aliasing rules */
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}
/* construct from a 32-bit float by delegating to the float assignment operator
   >> data - source value */
_XINLINE_ float16::float16(const float& data)
{
    this->operator=(data);
}
/* copy assignment: copies the sign, exponent and fraction fields of a
   >> a - source value
   >> returns a copy of the updated value */
_XINLINE_ float16 float16::operator = (const float16& a)
{
    sign = a.sign;
    exp = a.exp;
    data = a.data;
    return *this;
}
/* assignment from a 32-bit float; every other conversion goes through here
   - |a| > 65535 (including infinities) sets the overflow marker
   - inputs whose binary32 exponent lies below the smallest one this format
     can hold (biased exponent < 112, i.e. |a| < 2^-15, including 0) are
     stored as zero instead of wrapping to an unrelated exponent
   - otherwise the exponent is re-biased into 5 bits and the top 10 fraction
     bits are kept (truncation, no rounding)
   >> a - source value
   >> returns a copy of the updated value */
_XINLINE_ float16 float16::operator = (const float& a)
{
    /* copy the bits instead of casting the pointer: reading a float object
       through an unsigned int lvalue violates the aliasing rules */
    unsigned int p;
    std::memcpy(&p, &a, sizeof(p));

    sign = (p >> 31) & 1u;
    if (a > 65535 || a < -65535) return SetOverFlow();

    const unsigned int biased = (p >> 23) & 0xffu;
    if (biased < 112u) {
        exp = 0;
        data = 0;
        return *this;
    }

    /* low four exponent bits plus the top exponent bit give biased - 112 */
    exp = (biased & 0xfu) | ((biased >> 3) & 0x10u);
    data = (p >> 13);
    return *this;
}
the template assignment function converts other data types to float,
then calls the float assignment function
the template assignment function currently supports int and double
template <class T>
_XINLINE_ float16 float16::operator = (const T& data) {
*this = (float)data;
return *this;
template float16 float16:: operator = <int>(const int&);
template float16 float16:: operator = <double>(const double&);
template for multi-datatype overloads
Operation is the overloaded operator, e.g. <, =
returnType is the data type of the function's return value, like int, float
expression is the expression to return
/* Defines the template overload of a float16 member operator for a generic
   right-hand operand and explicitly instantiates it for int, float and
   double. The operand is cast to float and wrapped in a float16 local named
   `rec`, which `expression` may refer to.
   >> Operation  - operator token, e.g. < or +=
   >> returnType - return type of the generated operator
   >> expression - expression returned by the generated operator */
#define _OVERLOAD_OPRATER_TEMPLATE(Operation, returnType, expression) \
template<class T> \
_XINLINE_ returnType float16::operator Operation (const T & data) \
{ \
float16 rec=(float)data; \
return expression; \
} \
template returnType float16::operator Operation <int>(const int&); \
template returnType float16::operator Operation <float>(const float&); \
template returnType float16::operator Operation <double>(const double&);
/* less-than comparison, e.g. a < b
   Both operands are mapped to signed integers: the exponent and fraction
   bits form the magnitude and a set sign bit (1 = negative) negates it.
   A negative value is therefore below any positive one, and between two
   negatives the larger magnitude is the smaller value. +0 and -0 compare
   equal.
   >> data - right-hand operand
   >> returns 1 if *this < data, 0 otherwise */
_XINLINE_ int float16::operator < (const float16& data)
{
    const int lhsMagnitude = (exp << 10) | this->data;
    const int rhsMagnitude = (data.exp << 10) | data.data;
    const int lhs = sign ? -lhsMagnitude : lhsMagnitude;
    const int rhs = data.sign ? -rhsMagnitude : rhsMagnitude;
    return lhs < rhs ? 1 : 0;
}
_OVERLOAD_OPRATER_TEMPLATE(< , int, *this < rec)
/* less-or-equal comparison, e.g. a <= b
   Both operands are mapped to signed integers: the exponent and fraction
   bits form the magnitude and a set sign bit (1 = negative) negates it.
   A negative value is therefore below any positive one, and between two
   negatives the larger magnitude is the smaller value. +0 and -0 compare
   equal.
   >> data - right-hand operand
   >> returns 1 if *this <= data, 0 otherwise */
_XINLINE_ int float16::operator <= (const float16& data)
{
    const int lhsMagnitude = (exp << 10) | this->data;
    const int rhsMagnitude = (data.exp << 10) | data.data;
    const int lhs = sign ? -lhsMagnitude : lhsMagnitude;
    const int rhs = data.sign ? -rhsMagnitude : rhsMagnitude;
    return lhs <= rhs ? 1 : 0;
}
_OVERLOAD_OPRATER_TEMPLATE(<= , int, *this <= rec)
/* greater-than comparison, e.g. a > b
   Both operands are mapped to signed integers: the exponent and fraction
   bits form the magnitude and a set sign bit (1 = negative) negates it.
   A positive value is therefore above any negative one, and between two
   negatives the smaller magnitude is the larger value. +0 and -0 compare
   equal.
   >> data - right-hand operand
   >> returns 1 if *this > data, 0 otherwise */
_XINLINE_ int float16::operator > (const float16& data)
{
    const int lhsMagnitude = (exp << 10) | this->data;
    const int rhsMagnitude = (data.exp << 10) | data.data;
    const int lhs = sign ? -lhsMagnitude : lhsMagnitude;
    const int rhs = data.sign ? -rhsMagnitude : rhsMagnitude;
    return lhs > rhs ? 1 : 0;
}
_OVERLOAD_OPRATER_TEMPLATE(> , int, * this > rec)
/* greater-or-equal comparison, e.g. a >= b
   Both operands are mapped to signed integers: the exponent and fraction
   bits form the magnitude and a set sign bit (1 = negative) negates it.
   A positive value is therefore above any negative one, and between two
   negatives the smaller magnitude is the larger value. +0 and -0 compare
   equal.
   >> data - right-hand operand
   >> returns 1 if *this >= data, 0 otherwise */
_XINLINE_ int float16::operator >= (const float16& data)
{
    const int lhsMagnitude = (exp << 10) | this->data;
    const int rhsMagnitude = (data.exp << 10) | data.data;
    const int lhs = sign ? -lhsMagnitude : lhsMagnitude;
    const int rhs = data.sign ? -rhsMagnitude : rhsMagnitude;
    return lhs >= rhs ? 1 : 0;
}
// template overloads of >= for int, float and double right-hand operands;
// the returned expression must use >= (it previously evaluated *this < rec)
_OVERLOAD_OPRATER_TEMPLATE(>= , int, *this >= rec)
// Addition, e.g. a + b.
// An overflowed operand is returned unchanged. Otherwise the operand with the larger
// absolute value supplies the sign and the starting exponent; the other operand's
// significand (fraction with pow2[10] OR-ed in) is shifted right by the exponent
// difference and then added, or subtracted when the signs differ.
// NOTE(review): several statements in this body appear truncated (missing operands,
// missing left-hand sides, unbalanced braces) - restore them before building.
_XINLINE_ float16 float16::operator + (const float16& data) {
float16 ans;
// overflow is sticky: overflow + anything = overflow
if (this->IsOverlFlow())
return *this;
if (data.IsOverlFlow())
return data;
/* the operand with the greater magnitude determines the sign, and
the smaller one is shifted right to align with it */
if (AbsCompare(*this, data)) {
// |*this| < |data|: data supplies the sign and the base exponent
ans.sign = data.sign;
// working copy of the exponent
int recp = data.exp;
// combine the aligned significands (the first operand is missing as written)
int recd = ( | (pow2[10])) +
((data.sign ^ sign) ? -1 : 1) *
(((pow2[10]) | this->data) >> (data.exp - exp));
// the sum may carry or cancel, so the significand is renormalized when it is non-zero
if (recd) {
// shift right while bits selected by mask[10] are set
while (mask[10] & recd) {
recd >>= 1;
// shift left while no bit selected by mask[10] is set
while (!(mask[10] & recd)) {
recd <<= 1;
// a zero result uses exponent 0
recp = 0; = recd;
// an exponent of 31 or more means overflow
if (recp >= 31)
else {
ans.exp = recp; = recd;
// mirror of the branch above with the operand roles swapped (*this has the larger or equal magnitude)
else {
ans.sign = sign;
int recp = exp;
int recd = (this->data | (pow2[10])) +
((sign ^ data.sign) ? -1 : 1) *
(((pow2[10]) | >> (exp - data.exp));
if (recd) {
while (mask[10] & recd) {
recd >>= 1;
while (!(mask[10] & recd)) {
recd <<= 1;
else recp = 0;
if (recp >= 31) ans.SetOverFlow();
else {
ans.exp = recp; = recd;
return ans;
// template overloads of + for int, float and double right-hand operands;
// binary + must leave the left operand untouched, so the sum is returned
// without being assigned back to *this
_OVERLOAD_OPRATER_TEMPLATE(+, float16, *this + rec)
//overide operator +=
_XINLINE_ float16 float16::operator+=(const float16& data) {
return *this = *this + data;
_OVERLOAD_OPRATER_TEMPLATE(+=, float16, *this = *this + rec)
/* unary negation, e.g. -a
   >> returns a copy of this value with the sign bit flipped; *this is unchanged */
_XINLINE_ float16 float16::operator - ()
{
    float16 negated = *this;
    negated.sign ^= 1;
    return negated;
}
// Subtraction, e.g. a - b.
// Same structure as addition: the operand with the larger absolute value supplies the
// starting exponent and the other significand is shifted right by the exponent
// difference; here the significands are added when the signs differ and subtracted
// when they match.
// NOTE(review): several statements in this body appear truncated (missing operands,
// missing left-hand sides, unbalanced braces) - restore them before building.
_XINLINE_ float16 float16::operator - (const float16& data) {
float16 ans;
if (this->IsOverlFlow())
return *this;
// NOTE(review): data is returned with its sign unchanged although it is being subtracted - confirm
if (data.IsOverlFlow())
return data;
/* same as addition except for the sign handling:
a positive number minus a greater number is negative. */
if (AbsCompare(*this, data)) {
// |*this| < |data|: the result takes the opposite of data's sign
ans.sign = !data.sign;
int recp = data.exp;
int recd = ( | (pow2[10])) +
((data.sign ^ sign) ? 1 : -1) *
(((pow2[10]) | this->data) >> (data.exp - exp));
// renormalize a non-zero significand; a zero result uses exponent 0
if (recd) {
while (mask[10] & recd) {
recd >>= 1;
while (!(mask[10] & recd)) {
recd <<= 1;
else recp = 0;
// an exponent of 31 or more means overflow
if (recp >= 31)
else { = recd;
ans.exp = recp;
// *this has the larger or equal magnitude: the result keeps its sign
else {
ans.sign = sign;
int recp = exp;
int recd = (this->data | (pow2[10])) +
((sign ^ data.sign) ? 1 : -1) *
(((pow2[10]) | >> (exp - data.exp));
if (recd) {
while (mask[10] & recd) {
recd >>= 1;
while (!(mask[10] & recd)) {
recd <<= 1;
else recp = 0;
if (recp >= 31)
else { = recd;
ans.exp = recp;
return ans;
// template overloads of binary - for int, float and double right-hand
// operands; binary - must leave the left operand untouched, so the
// difference is returned without being assigned back to *this
_OVERLOAD_OPRATER_TEMPLATE(-, float16, *this - rec)
/* in-place subtraction, e.g. a -= b
   >> data - value subtracted from *this
   >> returns a copy of the updated value */
_XINLINE_ float16 float16::operator-=(const float16& data)
{
    const float16 difference = *this - data;
    *this = difference;
    return *this;
}
_OVERLOAD_OPRATER_TEMPLATE(-=, float16, *this = *this - rec)
// Multiplication, e.g. a * b.
// The result sign is the XOR of the operand signs, the significands (fraction with
// pow2[10] OR-ed in) are multiplied, and the exponents are added with one bias of 15
// removed, clamped below at 0.
// NOTE(review): the overflow checks on the operands are commented out, and several
// statements appear truncated (missing operand / left-hand side) - restore before building.
_XINLINE_ float16 float16::operator * (const float16& data)
// if(IsOverlFlow()) return *this;
// if(data.IsOverlFlow()) return data;
float16 ans;
// XOR of the signs: 1 (negative) when they differ, 0 (positive) when they match
ans.sign = sign ^ data.sign;
// multiply the significands (the first operand is missing as written)
int rec = ( | pow2[10]) * (this->data | pow2[10]);
// new exponent: sum of the exponents minus the bias, never below 0
int recp = exp + data.exp - 15 > 0 ? exp + data.exp - 15 : 0;
// drop the low 10 bits of the product, then shift right while bits selected by mask[11] remain
rec >>= 10;
while (rec & mask[11]) {
rec >>= 1;
// an exponent of 31 or more means overflow
if (recp >= 31)
else {
ans.exp = recp; = rec;// store the fraction
return ans;
_OVERLOAD_OPRATER_TEMPLATE(*, float16, (*this)* rec)
/* in-place multiplication, e.g. a *= b
   >> data - value *this is multiplied by
   >> returns a copy of the updated value */
_XINLINE_ float16 float16::operator*=(const float16& data)
{
    const float16 product = *this * data;
    *this = product;
    return *this;
}
_OVERLOAD_OPRATER_TEMPLATE(*=, float16, *this = *this * rec)
// Division, e.g. a / b.
// The result sign is the XOR of the operand signs and the exponent starts at
// exp - data.exp + 14. The dividend fraction is shifted left by 21 bits and combined
// with pow2[31] before the integer division so that quotient bits are not lost.
// NOTE(review): several statements appear truncated (missing operand / left-hand side,
// unbalanced braces) - restore before building.
_XINLINE_ float16 float16::operator / (const float16& data)
float16 ans;
// XOR of the signs: 1 (negative) when they differ, 0 (positive) when they match
ans.sign = sign ^ data.sign;
// starting exponent of the quotient
int recp = exp - data.exp + 14;
// shift the dividend left before dividing to avoid precision loss
// NOTE(review): pow2 holds unsigned int while recd is a signed int - confirm pow2[31] fits as intended
int recd = (this->data << 21) | pow2[31];
recd /= ( | pow2[10]);
// shift right once when the pow2[21] bit of the quotient is set
if (recd & pow2[21]) {
recd >>= 1;
// an exponent of 31 or more means overflow
if (recp >= 31)
else {
recd >>= 10; = recd;
ans.exp = recp;
return ans;
_OVERLOAD_OPRATER_TEMPLATE(/ , float16, (*this) / rec)
//overide operator /=
_XINLINE_ float16 float16::operator/=(const float16& data) {
return *this = *this / data;
_OVERLOAD_OPRATER_TEMPLATE(/=, float16, *this = *this / rec)
// float16.h
// 16bit
// Created by 管胡昊 on 2020/2/5.
// Copyright © 2020 管胡昊. All rights reserved.
#ifndef FLOAT16_H
#define FLOAT16_H
struct float16
//private member variable
sign is the sign bit: 1 means negative, 0 means positive
exp is the exponent with an offset (bias) of 15
data is the fraction; similar to IEEE 754, the leading 1 is implicit and not stored
// lookup tables shared by all instances: mask is used when locating the highest
// set bit, pow2 when a single power-of-two bit is needed
static unsigned int mask[32];
static unsigned int pow2[32];
// private helper functions
// NOTE(review): no definition of FindHighOne appears alongside the other member definitions - confirm it is still needed
int FindHighOne(const int &num, int &l, int &r);
// returns 1 if |a| < |b|, 0 otherwise
int AbsCompare(const float16 & a,const float16 & b);
// storage: 10-bit fraction, 5-bit exponent, 1-bit sign (16 bits in total)
unsigned short data : 10;
unsigned short exp : 5;
unsigned short sign : 1;
// mark this value as overflowed
float16 SetOverFlow();
// check whether this value is overflowed (exp == 31)
int IsOverlFlow() const;
/* constructor from sign, exp, data
sign: 1 bit, exp: 5 bits, data: 10 bits, laid out like an IEEE floating point value */
float16(const int& s, const int& e, const int& d);
/* default constructor
This initializes the 16-bit floating point value to 0. */
// constructor from a 32-bit floating point value
float16(const float& data);
// constructor from other data types (converted through float)
template<class T> float16(const T& data);
// (older declaration of the converting constructor, kept for reference)
//template<class T> float16(const T &data);
// convert float16 to float; the result is a 32-bit floating point value
float Float();
/* assignment functions
the float assignment function is the basic one
the template assignment function converts other data types to float,
then calls the float assignment function
the template assignment function currently supports int and double */
float16 operator = (const float16& data);
float16 operator = (const float& data);
template<class T> float16 operator = (const T& data);
// overloaded operator < (less than), e.g. a<b
int operator < (const float16& data);
template<class T> int operator <(const T& data);
// overloaded operator <= (less than or equal), e.g. a<=b
int operator <= (const float16& data);
template<class T> int operator <=(const T& data);
// overloaded operator > (greater than), e.g. a>b
int operator > (const float16& data);
template<class T> int operator >(const T& data);
// overloaded operator >= (greater than or equal), e.g. a>=b
int operator >= (const float16& data);
template<class T> int operator >=(const T& data);
// overloaded operator + (addition), e.g. a+b
float16 operator + (const float16& data);
template<class T> float16 operator +(const T& data);
// overloaded operator += (in-place addition), e.g. a+=b
float16 operator += (const float16& data);
template<class T> float16 operator +=(const T& data);
// overloaded unary operator - (negation), e.g. -a
float16 operator - ();
// overloaded operator - (subtraction), e.g. a-b
float16 operator - (const float16& data);
template<class T> float16 operator -(const T& data);
// overloaded operator -= (in-place subtraction), e.g. a-=b
float16 operator -= (const float16& data);
template<class T> float16 operator -=(const T& data);
// overloaded operator * (multiplication), e.g. a*b
float16 operator * (const float16& data);
template<class T> float16 operator *(const T& data);
// overloaded operator *= (in-place multiplication), e.g. a*=b
float16 operator *= (const float16& data);
template<class T> float16 operator *=(const T& data);
// reciprocal helper and overloaded operator / (division), e.g. a/b
float16 GetInverse() const;
float16 operator / (const float16& data);
template<class T> float16 operator /(const T& data);
// overloaded operator /= (in-place division), e.g. a/=b
float16 operator /= (const float16& data);
template<class T> float16 operator /=(const T& data);
#endif /* FLOAT16_H */
case 1: test ConvertDataType function.
In this case, the float32 data type is converted to int32 data type.
case 1: test ConvertDataType function.
In this case, the float32 data type is converted to int32 data type.
bool TestConvertDataType1()
a->SetData(data1, unitNum1);
a->SetData(data1, unitNum1);
/* call ConvertDataType function (We have not implemented this yet...) */
//_ConvertDataType(a, b);
//_ConvertDataType(b, c);
_ConvertDataType(a, b);
_ConvertDataType(b, c);
/* check results */
//cpuTest = _CheckData(a, data1, unitNum1, 1e-4F);
cpuTest = _CheckData(a, data1, unitNum1, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
_ConvertDataType(eGPU, fGPU);
_ConvertDataType(eGPU, fGPU);
/* check results */
gpuTest = _CheckData(fGPU, answer, unitNum3, 1e-4F);
//gpuTest = _CheckData(fGPU, answer, unitNum3, 1e-4F);
/* destroy variables */
delete a;
