Commit 04d23e39 by xiaotong

new comments

parent 0a9d9bd3
...@@ -28,17 +28,11 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -28,17 +28,11 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA #ifdef USE_CUDA
/* compare whether every entry is equal to the specified value (cuda kernel) */ /* check whether every entry is equal to the given value (cuda version) */
__global__ void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value);
void KernelEqual(DTYPE * a, DTYPE * b, DTYPE * number);
/* compare whether every entry is equal to the specified value (cuda version) */ /* check whether every entry is not equal to the given value (cuda version) */
void _CudaEqual(const XTensor * a, XTensor * b, DTYPE number); void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value);
/* compare whether every entry is not equal to the specified value (cuda kernel) */
__global__
void KernelNotEqual(DTYPE * a, DTYPE * b, DTYPE * number);
/* compare whether every entry is not equal to the specified value (cuda version) */
void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE number);
#endif // USE_CUDA #endif // USE_CUDA
......
...@@ -26,23 +26,23 @@ ...@@ -26,23 +26,23 @@
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
/* compare whether every entry is equal to the specified value */ /* check whether every entry is equal to the given value */
void _Equal(const XTensor * a, XTensor * b, DTYPE number); void _Equal(const XTensor * a, XTensor * b, DTYPE value);
/* compare whether every entry is equal to the specified value (do it on site)
keep the result in the input tensor a and return nothing */ /* check whether every entry is equal to the given value (do it on site) */
void _EqualMe(XTensor * a, DTYPE number); void _EqualMe(XTensor * a, DTYPE value);
/* compare whether every entry is equal to the specified value (return an XTensor structure)
make a new tensor to keep the result and return it */ /* check whether every entry is equal to the given value (return an XTensor structure) */
XTensor Equal(const XTensor & a, DTYPE number); XTensor Equal(const XTensor & a, DTYPE value);
/* compare whether every entry is not equal to the specified value */ /* check whether every entry is not equal to the given value */
void _NotEqual(const XTensor * a, XTensor * b, DTYPE number); void _NotEqual(const XTensor * a, XTensor * b, DTYPE value);
/* compare whether every entry is not equal to the specified value (do it on site)
keep the result in the input tensor a and return nothing */ /* check whether every entry is not equal to the given value (do it on site) */
void _NotEqualMe(XTensor * a, DTYPE number); void _NotEqualMe(XTensor * a, DTYPE value);
/* compare whether every entry is not equal to the specified value (return an XTensor structure)
make a new tensor to keep the result and return it */ /* check whether every entry is not equal to the given value (return an XTensor structure) */
XTensor NotEqual(const XTensor & a, DTYPE number); XTensor NotEqual(const XTensor & a, DTYPE value);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论