Commit dcabc2b0 by xiaotong

redesign the interface

parent f69bd017
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "Multiply.h" #include "Multiply.h"
#include "Negate.h" #include "Negate.h"
#include "Normalize.h" #include "Normalize.h"
#include "Permute.h"
#include "Power.h" #include "Power.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMean.h" #include "ReduceMean.h"
......
...@@ -26,12 +26,25 @@ ...@@ -26,12 +26,25 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* permute the tensor dimensions on site: a = permuted(a) */ #define permute _Permute_
void Permute_(XTensor * a, int * dimPermute);
/* generate the tensor with permuted dimensions: b = permuted(a) */ /* generate the tensor with permuted dimensions: b = permuted(a) */
extern "C"
void Permute(XTensor * a, XTensor * b, int * dimPermute); void Permute(XTensor * a, XTensor * b, int * dimPermute);
/* permute the tensor dimensions on site: a = permuted(a) */
extern "C"
void Permute_(XTensor * a, int * dimPermute);
/* make a tensor with permuted dimensions: b = permuted(a) and return its pointer */
extern "C"
XTensor * _Permute(XTensor *a, int * dimPermute);
/* make a tensor with permuted dimensions: b = permuted(a) and return its body */
extern "C"
XTensor& _Permute_(XTensor &a, int * dimPermute);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __PERMUTE_H__ #endif // __PERMUTE_H__
......
...@@ -27,12 +27,20 @@ ...@@ -27,12 +27,20 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* transpose a 1D/2D tensor on site: a = transposed(a) */ #define transpose _Transpose_
void Transpose_(XTensor * a);
/* generate a transposed 1D/2D tensor: b = transposed(a) */ /* generate a transposed 1D/2D tensor: b = transposed(a) */
void Transpose(XTensor * a, XTensor * b); void Transpose(XTensor * a, XTensor * b);
/* transpose a 1D/2D tensor on site: a = transposed(a) */
void Transpose_(XTensor * a);
/* make a transposed 1D/2D tensor: b = transposed(a) and return its pointer */
XTensor * _Transpose(XTensor * a);
/* make a transposed 1D/2D tensor: b = transposed(a) and return its body */
XTensor & _Transpose_(XTensor & a);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __TRANSPOSE_H__ #endif // __TRANSPOSE_H__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论