Commit 4a0b5fed by xiaotong

modify the arguments in Select

parent 07a5ae75
...@@ -19,8 +19,9 @@ ...@@ -19,8 +19,9 @@
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-04 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-04
*/ */
#include "Select.h"
#include "../XUtility.h" #include "../XUtility.h"
#include "../XName.h"
#include "Select.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -28,13 +29,13 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -28,13 +29,13 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
generate a tensor with seleccted data in range[low,high] along the given dimension generate a tensor with seleccted data in range[low,high] along the given dimension
c = select(a) c = select(a)
>> a - input tensor >> a - input tensor
>> c - result tensor
>> dim - the dimension along with which we do the job >> dim - the dimension along with which we do the job
>> low - lower bound >> low - lower bound
>> high - higher bound. >> high - higher bound.
Note that range [1,3] means that we select 1 and 2. Note that range [1,3] means that we select 1 and 2.
>> c - result tensor
*/ */
void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c) void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high)
{ {
CheckNTErrors(a != NULL && c != NULL, "empty tensors!"); CheckNTErrors(a != NULL && c != NULL, "empty tensors!");
CheckNTErrors(a->order == c->order, "The input and output tensors must in the same order!"); CheckNTErrors(a->order == c->order, "The input and output tensors must in the same order!");
...@@ -54,6 +55,12 @@ void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c) ...@@ -54,6 +55,12 @@ void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c)
} }
} }
/* make tensor connections */
XLink::MakeLink(a, NULL, c, MATH_SELECTRANGE);
XLink::AddParamToHeadInt(c, dim);
XLink::AddParamToHeadInt(c, low);
XLink::AddParamToHeadInt(c, high);
int stride = 1; int stride = 1;
for(int i = 0; i < dim; i++) for(int i = 0; i < dim; i++)
stride *= a->dimSizeRDI[i]; stride *= a->dimSizeRDI[i];
......
...@@ -28,12 +28,12 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -28,12 +28,12 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* generate a tensor with seleccted data c = select(a) */ /* generate a tensor with seleccted data c = select(a) */
extern "C" extern "C"
void Select(XTensor * a, XTensor * indexCPU, XTensor * c); void Select(XTensor * a, XTensor * c, XTensor * indexCPU);
/* generate a tensor with seleccted data in range[low,high] along the given dimension /* generate a tensor with seleccted data in range[low,high] along the given dimension
c = select(a) */ c = select(a) */
extern "C" extern "C"
void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c); void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论