Commit 6df1ecc9 by xuchen

add the reciprocal function

parent e193b1c2
...@@ -68,6 +68,14 @@ T UnaryIsZero(T r) ...@@ -68,6 +68,14 @@ T UnaryIsZero(T r)
return (r == 0.0) ? (T)1.0 : (T)0.0; return (r == 0.0) ? (T)1.0 : (T)0.0;
} }
/*
reciprocal of a single value
>> r - the input value
<< return - 1 / r, converted back to T
a zero input is reported through ShowNTErrors
*/
template<class T>
T UnaryReciprocal(T r)
{
    /* zero has no reciprocal, so report it before the division */
    if (r == 0)
        ShowNTErrors("Zero does not have reciprocal value.");

    T inverse = (T)(1 / r);
    return inverse;
}
/* define three macros separately, specify the respective function names */ /* define three macros separately, specify the respective function names */
#ifdef USE_CUDA #ifdef USE_CUDA
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \ #define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
...@@ -186,6 +194,7 @@ _SIMPLE_UNARY_FUNCTION(_Square, _CudaSquare, UnarySquare) ...@@ -186,6 +194,7 @@ _SIMPLE_UNARY_FUNCTION(_Square, _CudaSquare, UnarySquare)
_SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin) _SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin)
_SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos) _SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos)
_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan) _SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
/* reciprocal (USE_CUDA build): _funcName = _Reciprocal,
   _cudaFuncName = _CudaReciprocal, origFunc = UnaryReciprocal */
_SIMPLE_UNARY_FUNCTION(_Reciprocal, _CudaReciprocal, UnaryReciprocal)
#else #else
_SIMPLE_UNARY_FUNCTION(_Absolute, fabs) _SIMPLE_UNARY_FUNCTION(_Absolute, fabs)
_SIMPLE_UNARY_FUNCTION(_Ceil, ceil) _SIMPLE_UNARY_FUNCTION(_Ceil, ceil)
...@@ -202,6 +211,7 @@ _SIMPLE_UNARY_FUNCTION(_Square, UnarySquare) ...@@ -202,6 +211,7 @@ _SIMPLE_UNARY_FUNCTION(_Square, UnarySquare)
_SIMPLE_UNARY_FUNCTION(_Sin, sin) _SIMPLE_UNARY_FUNCTION(_Sin, sin)
_SIMPLE_UNARY_FUNCTION(_Cos, cos) _SIMPLE_UNARY_FUNCTION(_Cos, cos)
_SIMPLE_UNARY_FUNCTION(_Tan, tan) _SIMPLE_UNARY_FUNCTION(_Tan, tan)
/* reciprocal (build without USE_CUDA): the two-argument form of the macro,
   given the function name _Reciprocal and the per-element function UnaryReciprocal */
_SIMPLE_UNARY_FUNCTION(_Reciprocal, UnaryReciprocal)
#endif #endif
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute) _SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
...@@ -279,4 +289,9 @@ SIMPLE_UNARY_FUNCTION_ME(TanMe, _Tan) ...@@ -279,4 +289,9 @@ SIMPLE_UNARY_FUNCTION_ME(TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN) SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
SIMPLE_UNARY_FUNCTION_VOID(Tan, _Tan, MATH_TAN) SIMPLE_UNARY_FUNCTION_VOID(Tan, _Tan, MATH_TAN)
/* reciprocal: instantiate _ReciprocalMe, ReciprocalMe and the two Reciprocal
   overloads, all of them built on top of _Reciprocal.
   NOTE(review): MATH_RECIPROCAL is presumably the operation id defined
   elsewhere in the project - confirm that it exists */
_SIMPLE_UNARY_FUNCTION_ME(_ReciprocalMe, _Reciprocal)
SIMPLE_UNARY_FUNCTION_ME(ReciprocalMe, _Reciprocal)
SIMPLE_UNARY_FUNCTION(Reciprocal, _Reciprocal, MATH_RECIPROCAL)
SIMPLE_UNARY_FUNCTION_VOID(Reciprocal, _Reciprocal, MATH_RECIPROCAL)
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -142,6 +142,15 @@ T UnaryCudaTan(T x) ...@@ -142,6 +142,15 @@ T UnaryCudaTan(T x)
return (T)tan((float)x); return (T)tan((float)x);
} }
/*
reciprocal of a single value (device function)
>> x - the input value
<< return - 1 / x, converted back to T
NOTE: the zero check done by UnaryReciprocal (via ShowNTErrors) is disabled
in this version, so a zero input is divided as is
*/
template<class T>
__device__
T UnaryCudaReciprocal(T x)
{
    T inverse = (T)(1 / x);
    return inverse;
}
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \ #define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
template<class T> \ template<class T> \
...@@ -155,7 +164,7 @@ void Kernel##funcName(T * a, T * b, int size) \ ...@@ -155,7 +164,7 @@ void Kernel##funcName(T * a, T * b, int size) \
} \ } \
void _Cuda##funcName(const XTensor * a, XTensor * b) \ void _Cuda##funcName(const XTensor * a, XTensor * b) \
{ \ { \
CheckNTErrors((_IsSameShaped(a, b)), \ CheckNTErrors((_IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \ "Input tensors should have the same type!"); \
CheckNTErrors(a->isSparse == false, "TODO!"); \ CheckNTErrors(a->isSparse == false, "TODO!"); \
\ \
...@@ -208,6 +217,8 @@ SIMPLE_UNARY_FUNCTION_GPU(Sin, UnaryCudaSin) ...@@ -208,6 +217,8 @@ SIMPLE_UNARY_FUNCTION_GPU(Sin, UnaryCudaSin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, UnaryCudaCos) SIMPLE_UNARY_FUNCTION_GPU(Cos, UnaryCudaCos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, UnaryCudaTan) SIMPLE_UNARY_FUNCTION_GPU(Tan, UnaryCudaTan)
/* reciprocal: funcName = Reciprocal, origFunc = UnaryCudaReciprocal,
   i.e. this yields KernelReciprocal and _CudaReciprocal */
SIMPLE_UNARY_FUNCTION_GPU(Reciprocal, UnaryCudaReciprocal)
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -75,6 +75,9 @@ void _CudaCos(const XTensor * a, XTensor * b); ...@@ -75,6 +75,9 @@ void _CudaCos(const XTensor * a, XTensor * b);
/* set each entry to its tangent value */ /* set each entry to its tangent value */
void _CudaTan(const XTensor * a, XTensor * b); void _CudaTan(const XTensor * a, XTensor * b);
/* set each entry to its reciprocal value (CUDA version)
   >> a - the input tensor
   >> b - the output tensor (presumably receives the result - confirm against the macro body) */
void _CudaReciprocal(const XTensor * a, XTensor * b);
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -236,6 +236,20 @@ XTensor Tan(const XTensor & a); ...@@ -236,6 +236,20 @@ XTensor Tan(const XTensor & a);
/* set every entry to its tangent value */ /* set every entry to its tangent value */
void Tan(const XTensor & a, XTensor & b); void Tan(const XTensor & a, XTensor & b);
/* set every entry to its reciprocal value
   >> a - the input tensor
   >> b - the output tensor (presumably receives the result - confirm against the macro body) */
void _Reciprocal(const XTensor * a, XTensor * b);
/* set every entry to its reciprocal value (done in place)
   keep the result in the input tensor a and return nothing
   (pointer interface) */
void _ReciprocalMe(XTensor * a);
/* set every entry to its reciprocal value (done in place)
   keep the result in the input tensor a and return nothing
   (reference interface) */
void ReciprocalMe(XTensor & a);
/* set every entry to its reciprocal value (return an XTensor structure)
   make a new tensor to keep the result and return it */
XTensor Reciprocal(const XTensor & a);
/* set every entry to its reciprocal value
   >> a - the input tensor
   >> b - the output tensor (presumably receives the result - confirm against the macro body) */
void Reciprocal(const XTensor & a, XTensor & b);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // end __UNARY_H__ #endif // end __UNARY_H__
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论