Commit 2918c894 by liyinqiao

Merge with XU Chen branch (Don't use this! It's an incomplete version)

Clean the codes and fix some warning issues.
parent 44e2fd1a
...@@ -44,7 +44,7 @@ T1 BinaryPower(T1 x, T2 num) ...@@ -44,7 +44,7 @@ T1 BinaryPower(T1 x, T2 num)
return x * x; return x * x;
else { else {
if (x == 0 && num < 0) if (x == 0 && num < 0)
return (T1)1e20F; return (T1)1e9F;
else else
return (T1)pow(x, num); return (T1)pow(x, num);
} }
...@@ -62,9 +62,10 @@ T1 BinaryShift(T1 x, T2 num) ...@@ -62,9 +62,10 @@ T1 BinaryShift(T1 x, T2 num)
return (T1)(x + num); return (T1)(x + num);
} }
int BinaryMod(int x, int num) template<class T1, class T2>
int BinaryMod(T1 x, T2 num)
{ {
return x % num; return (int)x % (int)num;
} }
/* define three marco separately, specify the respective function names */ /* define three marco separately, specify the respective function names */
...@@ -170,7 +171,7 @@ XTensor funcName(const XTensor &a, T num) ...@@ -170,7 +171,7 @@ XTensor funcName(const XTensor &a, T num)
_funcName(&a, &b, num); \ _funcName(&a, &b, num); \
if(a.enableGrad){ \ if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \ XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, num); \ XLink::AddParamToHead(&b, (DTYPE)num); \
} \ } \
return b; \ return b; \
} \ } \
...@@ -188,7 +189,7 @@ void funcName(const XTensor &a, XTensor &b, T num) ...@@ -188,7 +189,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
_funcName(&a, &b, num); \ _funcName(&a, &b, num); \
if (a.enableGrad) { \ if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \ XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, num); \ XLink::AddParamToHead(&b, (DTYPE)num); \
} \ } \
} \ } \
template void funcName<int>(const XTensor&, XTensor&, int); \ template void funcName<int>(const XTensor&, XTensor&, int); \
......
...@@ -56,7 +56,7 @@ T1 BinaryCudaPower(T1 x, T2 num) ...@@ -56,7 +56,7 @@ T1 BinaryCudaPower(T1 x, T2 num)
return (T1)(x * x); return (T1)(x * x);
else { else {
if (x == 0 && num < 0) if (x == 0 && num < 0)
return (T1)1e20F; return (T1)1e9F;
else else
return (T1)pow((float)x, (float)num); return (T1)pow((float)x, (float)num);
} }
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* calculate result between a tensor and a constant */
/* descale tensor entires /* descale tensor entires
b = a / num */ b = a / num */
template<class T> template<class T>
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论