Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
36903fdb
Commit
36903fdb
authored
Sep 11, 2019
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bug of marco compile
parent
d3a0b984
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
118 行增加
和
13 行删除
+118
-13
source/tensor/core/math/Binary.cpp
+47
-5
source/tensor/core/math/Compare.cpp
+23
-7
source/tensor/core/math/Unary.cpp
+48
-1
没有找到文件。
source/tensor/core/math/Binary.cpp
查看文件 @
36903fdb
...
...
@@ -67,17 +67,49 @@ int BinaryMod(int x, int num)
}
/* define three marco separately, specify the respective function names */
#ifdef USE_CUDA
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
template<class T> \
void _funcName(const XTensor * a, XTensor * b, T num) \
{ \
/* run it on GPUs */
\
if (a->devID >= 0) { \
if (useCUDA) { \
_cudaFuncName(a, b, num); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same data type!"); \
if (a->dataType == X_INT) { \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (int)origFunc((int)d[i], (T)num); \
} \
else if (a->dataType == X_FLOAT) { \
float * d = (float*)a->data; \
float * db = (float*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (float)origFunc((float)d[i], (T)num); \
} \
else if (a->dataType == X_DOUBLE) { \
double * d = (double*)a->data; \
double * db = (double*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (double)origFunc((double)d[i], (T)num); \
} \
else \
ShowNTErrors("TO DO!"); \
} \
template void _funcName<int>(const XTensor*, XTensor*, int); \
template void _funcName<float>(const XTensor*, XTensor*, float); \
template void _funcName<double>(const XTensor*, XTensor*, double);
#else
#define _SIMPLE_BINARY_FUNCTION(_funcName, origFunc) \
template<class T> \
void _funcName(const XTensor * a, XTensor * b, T num) \
{ \
/* run it on GPUs */
\
if (a->devID >= 0) { \
ShowNTErrors("No GPU devices support!") \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
...
...
@@ -106,6 +138,7 @@ void _funcName(const XTensor * a, XTensor * b, T num)
template void _funcName<int>(const XTensor*, XTensor*, int); \
template void _funcName<float>(const XTensor*, XTensor*, float); \
template void _funcName<double>(const XTensor*, XTensor*, double);
#endif
#define _SIMPLE_BINARY_FUNCTION_ME(_funcNameMe, _funcName) \
template<class T> \
...
...
@@ -159,31 +192,40 @@ template void funcName<int>(const XTensor&, XTensor&, int);
template void funcName<float>(const XTensor&, XTensor&, float); \
template void funcName<double>(const XTensor&, XTensor&, double);
#ifdef USE_CUDA
_SIMPLE_BINARY_FUNCTION
(
_Descale
,
_CudaDescale
,
BinaryDescale
)
_SIMPLE_BINARY_FUNCTION
(
_Mod
,
_CudaMod
,
BinaryMod
)
_SIMPLE_BINARY_FUNCTION
(
_Power
,
_CudaPower
,
BinaryPower
)
_SIMPLE_BINARY_FUNCTION
(
_Scale
,
_CudaScale
,
BinaryScale
)
_SIMPLE_BINARY_FUNCTION
(
_Shift
,
_CudaShift
,
BinaryShift
)
#else
_SIMPLE_BINARY_FUNCTION
(
_Descale
,
BinaryDescale
)
_SIMPLE_BINARY_FUNCTION
(
_Mod
,
BinaryMod
)
_SIMPLE_BINARY_FUNCTION
(
_Power
,
BinaryPower
)
_SIMPLE_BINARY_FUNCTION
(
_Scale
,
BinaryScale
)
_SIMPLE_BINARY_FUNCTION
(
_Shift
,
BinaryShift
)
#endif
_SIMPLE_BINARY_FUNCTION_ME
(
_DescaleMe
,
_Descale
)
SIMPLE_BINARY_FUNCTION_ME
(
DescaleMe
,
_Descale
)
SIMPLE_BINARY_FUNCTION
(
Descale
,
_Descale
,
MATH_DESCALE
)
SIMPLE_BINARY_FUNCTION_VOID
(
Descale
,
_Descale
,
MATH_DESCALE
)
_SIMPLE_BINARY_FUNCTION
(
_Mod
,
_CudaMod
,
BinaryMod
)
_SIMPLE_BINARY_FUNCTION_ME
(
_ModMe
,
_Mod
)
SIMPLE_BINARY_FUNCTION_ME
(
ModMe
,
_Mod
)
SIMPLE_BINARY_FUNCTION
(
Mod
,
_Mod
,
MATH_MOD
)
SIMPLE_BINARY_FUNCTION_VOID
(
Mod
,
_Mod
,
MATH_MOD
)
_SIMPLE_BINARY_FUNCTION
(
_Power
,
_CudaPower
,
BinaryPower
)
_SIMPLE_BINARY_FUNCTION_ME
(
_PowerMe
,
_Power
)
SIMPLE_BINARY_FUNCTION_ME
(
PowerMe
,
_Power
)
SIMPLE_BINARY_FUNCTION
(
Power
,
_Power
,
MATH_POWER
)
SIMPLE_BINARY_FUNCTION_VOID
(
Power
,
_Power
,
MATH_POWER
)
_SIMPLE_BINARY_FUNCTION
(
_Scale
,
_CudaScale
,
BinaryScale
)
_SIMPLE_BINARY_FUNCTION_ME
(
_ScaleMe
,
_Scale
)
SIMPLE_BINARY_FUNCTION_ME
(
ScaleMe
,
_Scale
)
SIMPLE_BINARY_FUNCTION
(
Scale
,
_Scale
,
MATH_SCALE
)
SIMPLE_BINARY_FUNCTION_VOID
(
Scale
,
_Scale
,
MATH_SCALE
)
_SIMPLE_BINARY_FUNCTION
(
_Shift
,
_CudaShift
,
BinaryShift
)
_SIMPLE_BINARY_FUNCTION_ME
(
_ShiftMe
,
_Shift
)
SIMPLE_BINARY_FUNCTION_ME
(
ShiftMe
,
_Shift
)
SIMPLE_BINARY_FUNCTION
(
Shift
,
_Shift
,
MATH_SHIFT
)
...
...
source/tensor/core/math/Compare.cpp
查看文件 @
36903fdb
...
...
@@ -36,8 +36,8 @@ DTYPE myIsNotEqual(DTYPE a, DTYPE b)
return
(
a
!=
b
?
1.0
F
:
0.0
F
);
}
#ifdef USE_CUDA
/* define three marco separately, specify the respective function names */
#ifdef USE_CUDA
#define _SIMPLE_COMPARE_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, DTYPE number) \
{ \
...
...
@@ -46,11 +46,23 @@ void _funcName(const XTensor * a, XTensor * b, DTYPE number)
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
/* run it on GPUs */
\
if (a->devID >= 0) { \
if (useCUDA) { \
_cudaFuncName(a, b, number); \
return; \
} \
else \
DTYPE * d = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i], number); \
}
#else
#define _SIMPLE_COMPARE_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, DTYPE number) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
/* run it on GPUs */
\
if (a->devID >= 0) { \
ShowNTErrors("No GPU devices support!") \
} \
DTYPE * d = (DTYPE*)a->data; \
...
...
@@ -58,6 +70,7 @@ void _funcName(const XTensor * a, XTensor * b, DTYPE number)
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i], number); \
}
#endif
#define _SIMPLE_COMPARE_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a, DTYPE number) \
...
...
@@ -92,18 +105,22 @@ void funcName(const XTensor &a, XTensor &b, DTYPE number)
// I think we needn't to make link.
// XLink::MakeLink(&a, NULL, &b, operationId);
#ifdef USE_CUDA
_SIMPLE_COMPARE_FUNCTION
(
_Equal
,
_CudaEqual
,
myIsEqual
)
_SIMPLE_COMPARE_FUNCTION
(
_NotEqual
,
_CudaNotEqual
,
myIsNotEqual
)
#else
_SIMPLE_COMPARE_FUNCTION
(
_Equal
,
myIsEqual
)
_SIMPLE_COMPARE_FUNCTION
(
_NotEqual
,
myIsNotEqual
)
#endif
_SIMPLE_COMPARE_FUNCTION_ME
(
_EqualMe
,
_Equal
)
SIMPLE_COMPARE_FUNCTION_ME
(
EqualMe
,
_Equal
)
SIMPLE_COMPARE_FUNCTION
(
Equal
,
_Equal
,
MATH_EQUAL
)
SIMPLE_COMPARE_FUNCTION_VOID
(
Equal
,
_Equal
,
MATH_EQUAL
)
_SIMPLE_COMPARE_FUNCTION
(
_NotEqual
,
_CudaNotEqual
,
myIsNotEqual
)
_SIMPLE_COMPARE_FUNCTION_ME
(
_NotEqualMe
,
_NotEqual
)
SIMPLE_COMPARE_FUNCTION_ME
(
NotEqualMe
,
_NotEqual
)
SIMPLE_COMPARE_FUNCTION
(
NotEqual
,
_NotEqual
,
MATH_NOTEQUAL
)
SIMPLE_COMPARE_FUNCTION_VOID
(
NotEqual
,
_NotEqual
,
MATH_NOTEQUAL
)
#endif
}
//
namespace
nts
(
NiuTrans
.
Tensor
)
\ No newline at end of file
source/tensor/core/math/Unary.cpp
查看文件 @
36903fdb
...
...
@@ -68,16 +68,44 @@ T UnaryIsZero(T r)
}
/* define three marco separately, specify the respective function names */
#ifdef USE_CUDA
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b) \
{ \
/* run it on GPUs */
\
if (a->devID >= 0) { \
if (useCUDA) { \
_cudaFuncName(a, b); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
if (a->dataType == X_INT) { \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (int)origFunc(d[i]); \
} \
else if (a->dataType == X_FLOAT) { \
float * d = (float*)a->data; \
float * db = (float*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (float)origFunc(d[i]); \
} \
else if (a->dataType == X_DOUBLE) { \
double * d = (double*)a->data; \
double * db = (double*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (double)origFunc(d[i]); \
} \
else \
ShowNTErrors("TO DO!"); \
}
#else
#define _SIMPLE_UNARY_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, XTensor * b) \
{ \
/* run it on GPUs */
\
if (a->devID >= 0) { \
ShowNTErrors("No GPU devices support!") \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
...
...
@@ -103,6 +131,7 @@ void _funcName(const XTensor * a, XTensor * b)
else \
ShowNTErrors("TO DO!"); \
}
#endif
#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a) \
...
...
@@ -138,6 +167,7 @@ void funcName(const XTensor & a, XTensor & b)
} \
}
#ifdef USE_CUDA
_SIMPLE_UNARY_FUNCTION
(
_Absolute
,
_CudaAbsolute
,
fabs
)
_SIMPLE_UNARY_FUNCTION
(
_Ceil
,
_CudaCeil
,
ceil
)
_SIMPLE_UNARY_FUNCTION
(
_Exp
,
_CudaExp
,
exp
)
...
...
@@ -153,6 +183,23 @@ _SIMPLE_UNARY_FUNCTION(_Square, _CudaSquare, UnarySquare)
_SIMPLE_UNARY_FUNCTION
(
_Sin
,
_CudaSin
,
sin
)
_SIMPLE_UNARY_FUNCTION
(
_Cos
,
_CudaCos
,
cos
)
_SIMPLE_UNARY_FUNCTION
(
_Tan
,
_CudaTan
,
tan
)
#else
_SIMPLE_UNARY_FUNCTION
(
_Absolute
,
fabs
)
_SIMPLE_UNARY_FUNCTION
(
_Ceil
,
ceil
)
_SIMPLE_UNARY_FUNCTION
(
_Exp
,
exp
)
_SIMPLE_UNARY_FUNCTION
(
_Floor
,
floor
)
_SIMPLE_UNARY_FUNCTION
(
_IsNonZero
,
UnaryIsNonZero
)
_SIMPLE_UNARY_FUNCTION
(
_IsZero
,
UnaryIsZero
)
_SIMPLE_UNARY_FUNCTION
(
_Log
,
log
)
_SIMPLE_UNARY_FUNCTION
(
_Negate
,
UnaryNegate
)
_SIMPLE_UNARY_FUNCTION
(
_Round
,
round
)
_SIMPLE_UNARY_FUNCTION
(
_Sign
,
UnarySign
)
_SIMPLE_UNARY_FUNCTION
(
_Sqrt
,
sqrt
)
_SIMPLE_UNARY_FUNCTION
(
_Square
,
UnarySquare
)
_SIMPLE_UNARY_FUNCTION
(
_Sin
,
sin
)
_SIMPLE_UNARY_FUNCTION
(
_Cos
,
cos
)
_SIMPLE_UNARY_FUNCTION
(
_Tan
,
tan
)
#endif
_SIMPLE_UNARY_FUNCTION_ME
(
_AbsoluteMe
,
_Absolute
)
SIMPLE_UNARY_FUNCTION_ME
(
AbsoluteMe
,
_Absolute
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论