Commit 0d96c2a0 by xuchen

update interface of convertdatatype again

parent 6fd2a671
......@@ -60,4 +60,25 @@ TENSOR_DATA_TYPE GetDataType(const char * typeName)
}
}
/*
Below is for calling CPU BLAS for fast matrix operations
I'm not sure how fast it is. But it seems that other
guys are crazy about this. So I decided to have a try.
*/
/* float -> float16 */
_XINLINE_ unsigned short FloatToFloat16(float f)
{
unsigned int x = *((unsigned int*)&f);
unsigned short h = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
return h;
}
/* float16 -> float */
_XINLINE_ float Float16ToFloat(unsigned short h)
{
float f = float(((h&0x8000)<<16) | (((h&0x7c00)+0x1C000)<<13) | ((h&0x03FF)<<13));
return f;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -46,6 +46,9 @@ enum MATRIX_TRANS_TYPE{X_TRANS, X_NOTRANS};
extern const char * GetDataTypeName(TENSOR_DATA_TYPE type);
extern TENSOR_DATA_TYPE GetDataType(const char * typeName);
unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h);
} /* end of the nts (NiuTrans.Tensor) namespace */
#endif
\ No newline at end of file
......@@ -28,27 +28,6 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
Below is for calling CPU BLAS for fast matrix operations
I'm not sure how fast it is. But it seems that other
guys are crazy about this. So I decided to have a try.
*/
/* float -> float16 */
_XINLINE_ unsigned short FloatToFloat16(float f)
{
unsigned int x = *((unsigned int*)&f);
unsigned short h = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
return h;
}
/* float16 -> float */
_XINLINE_ float Float16ToFloat(unsigned short h)
{
float f = float(((h&0x8000)<<16) | (((h&0x7c00)+0x1C000)<<13) | ((h&0x03FF)<<13));
return f;
}
/*
data type conversion
>> devID - device id
>> s - source data array
......
......@@ -28,8 +28,6 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* data conversion (for lower precision computation) */
unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h);
void ConvertDataType(int devID,
void * s, TENSOR_DATA_TYPE typeS,
void * t, TENSOR_DATA_TYPE typeT, int size);
......
......@@ -23,6 +23,7 @@
#include "../../XUtility.h"
#include "CopyValues.h"
#include "CopyValues.cuh"
#include "../getandset/ConvertDataType.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论