Commit 4a0b5fed by xiaotong

modify the arguments in Select

parent 07a5ae75
......@@ -19,8 +19,9 @@
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-04
*/
#include "Select.h"
#include "../XUtility.h"
#include "../XName.h"
#include "Select.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -28,13 +29,13 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
generate a tensor with seleccted data in range[low,high] along the given dimension
c = select(a)
>> a - input tensor
>> c - result tensor
>> dim - the dimension along with which we do the job
>> low - lower bound
>> high - higher bound.
Note that range [1,3] means that we select 1 and 2.
>> c - result tensor
*/
void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c)
void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high)
{
CheckNTErrors(a != NULL && c != NULL, "empty tensors!");
CheckNTErrors(a->order == c->order, "The input and output tensors must in the same order!");
......@@ -54,6 +55,12 @@ void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c)
}
}
/* make tensor connections */
XLink::MakeLink(a, NULL, c, MATH_SELECTRANGE);
XLink::AddParamToHeadInt(c, dim);
XLink::AddParamToHeadInt(c, low);
XLink::AddParamToHeadInt(c, high);
int stride = 1;
for(int i = 0; i < dim; i++)
stride *= a->dimSizeRDI[i];
......
......@@ -28,12 +28,12 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* generate a tensor with seleccted data c = select(a) */
extern "C"
void Select(XTensor * a, XTensor * indexCPU, XTensor * c);
void Select(XTensor * a, XTensor * c, XTensor * indexCPU);
/* generate a tensor with seleccted data in range[low,high] along the given dimension
c = select(a) */
extern "C"
void SelectRange(XTensor * a, int dim, int low, int high, XTensor * c);
void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论