Commit bd05b21b by xuchen

fix the bug and add the function CheckDataType

parent 6c15668f
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
*/ */
#include <stdio.h> #include <stdio.h>
#include "XNet.h" #include "./network/XNet.h"
#include "../tensor/XUtility.h" #include "./tensor/XUtility.h"
#include "../tensor/function/FHeader.h" #include "./tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h" #include "./tensor/core/CHeader.h"
#include "../tensor/test/Test.h" #include "./tensor/test/Test.h"
#include "../sample/fnnlm/FNNLM.h" #include "./sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h" #include "./sample/transformer/Transformer.h"
//#define CRTDBG_MAP_ALLOC //#define CRTDBG_MAP_ALLOC
//#include <stdlib.h> //#include <stdlib.h>
......
...@@ -778,7 +778,7 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co ...@@ -778,7 +778,7 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
XTensor * NewTensorRange(int lower, int upper, int step, const TENSOR_DATA_TYPE myDataType, const int myDevID, const bool isEnableGrad) XTensor * NewTensorRange(int lower, int upper, int step, const TENSOR_DATA_TYPE myDataType, const int myDevID, const bool isEnableGrad)
{ {
int size = abs(upper - lower); int size = abs(upper - lower);
int unitNum = ceil(1.0 * size / abs(step)); int unitNum = (int)ceil(1.0 * size / abs(step));
XTensor * tensor = NewTensor1D(unitNum, myDataType, myDevID, isEnableGrad); XTensor * tensor = NewTensor1D(unitNum, myDataType, myDevID, isEnableGrad);
tensor->Range(lower, upper, step); tensor->Range(lower, upper, step);
......
...@@ -49,6 +49,15 @@ extern TENSOR_DATA_TYPE GetDataType(const char * typeName); ...@@ -49,6 +49,15 @@ extern TENSOR_DATA_TYPE GetDataType(const char * typeName);
unsigned short FloatToFloat16(float f); unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h); float Float16ToFloat(unsigned short h);
#define CheckDataType(a, b) \
{ \
if(GetDataTypeName(a) != GetDataTypeName(a)){ \
fprintf(stderr, "[ERROR] (%s line %d): we must run the code on the same datatype (%s vs %s)\n", \
__FILENAME__, __LINE__, GetDataTypeName(a), GetDataTypeName(b)); \
exit(1); \
} \
} \
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
#endif #endif
\ No newline at end of file
...@@ -1482,7 +1482,7 @@ void XMem::ShowMemUsage(FILE * file) ...@@ -1482,7 +1482,7 @@ void XMem::ShowMemUsage(FILE * file)
} }
fprintf(file, "mem:%.1fMB used:%.1fMB usage:%.3f\n", fprintf(file, "mem:%.1fMB used:%.1fMB usage:%.3f\n",
(DTYPE)used/MILLION, (DTYPE)total/MILLION, (DTYPE)used/total); (DTYPE)total/MILLION, (DTYPE)used/MILLION, (DTYPE)used/total);
} }
#ifdef USE_CUDA #ifdef USE_CUDA
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论