Commit 0d6a5f7d by xiaotong

add tensor connections

parent 4a0b5fed
......@@ -188,5 +188,17 @@ void XLink::AddParamToHead(XTensor * h, DTYPE param)
h->income.AddParam(param);
}
/*
add an integer parameter
>> h - head
>> param - parameter we want introduce
*/
void XLink::AddParamToHeadInt(XTensor * h, int param)
{
if(h != NULL)
return;
h->income.AddParam(&param, sizeof(int));
}
} // namespace nts(NiuTrans.Tensor)
......@@ -105,9 +105,13 @@ struct XLink
static
void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName);
/* add parameters */
/* add a parameter */
static
void AddParamToHead(XTensor * h, DTYPE param);
/* add an integer parameter */
static
void AddParamToHeadInt(XTensor * h, int param);
};
} // namespace nts(NiuTrans.Tensor)
......
......@@ -28,8 +28,13 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_SUM "M_SUM"
#define MATH_MATMUL "M_MATMUL"
#define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM"
#define MATH_TOPK "M_TOPK"
#define MATH_UNSQUEEZE "M_UNSQUEEZE"
} // namespace nts(NiuTrans.Tensor)
......
......@@ -21,6 +21,7 @@
#include "../XTensor.h"
#include "../XUtility.h"
#include "../XName.h"
#include "Sort.h"
#include "Sort.cuh"
......@@ -38,6 +39,10 @@ void Sort(XTensor * a, XTensor * index, int dim)
CheckNTErrors((a->order == index->order), "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!");
/* make tensor connections */
XLink::MakeLink(a, NULL, index, MATH_SORT);
XLink::AddParamToHeadInt(index, dim);
int dimRDI = a->order - dim - 1;
/* make the index tensor */
index->SetAscendingOrder(dim);
......
......@@ -20,6 +20,7 @@
*/
#include "../XTensor.h"
#include "../XName.h"
#include "TopK.h"
#include "TopK.cuh"
......@@ -39,6 +40,11 @@ void TopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k)
CheckNTErrors((index == NULL || a->order == index->order), "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!");
/* make tensor connections */
XLink::MakeLink(a, b, index, MATH_TOPK);
XLink::AddParamToHeadInt(index, dim);
XLink::AddParamToHeadInt(index, k);
int dimRDI = a->order - dim - 1;
for (int i = 0; i < a->order; i++) {
if (i == dimRDI) {
......
......@@ -20,6 +20,7 @@
*/
#include "../XTensor.h"
#include "../XName.h"
#include "Unsqueeze.h"
#include "MergeBlockLists.h"
#include "Unsqueeze.cuh"
......@@ -39,6 +40,11 @@ void Unsqueeze(XTensor * a, XTensor * b, int dim, int dSize)
CheckNTErrors((a->order == b->order - 1), "Unmatched tensors!");
CheckNTErrors((a->unitSize == b->unitSize), "Unmatched tensors!");
/* make tensor connections */
XLink::MakeLink(a, NULL, b, MATH_UNSQUEEZE);
XLink::AddParamToHeadInt(b, dim);
XLink::AddParamToHeadInt(b, dSize);
int dimRDI = b->order - dim - 1;
for (int i = 0; i < b->order; i++) {
if (i < dimRDI) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论