Commit f6fe574c by xiaotong

remove tensor connections

parent 9a0c89fb
...@@ -59,12 +59,6 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -59,12 +59,6 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2), CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!"); "Input tensors must have a order > 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMUL);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1]; int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0]; int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b->dimSizeRDI[0] : b->dimSizeRDI[1]; int bn = transposedB == X_TRANS ? b->dimSizeRDI[0] : b->dimSizeRDI[1];
......
...@@ -48,17 +48,10 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -48,17 +48,10 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XPRunner * parallelRunner, XStream * stream) XPRunner * parallelRunner, XStream * stream)
{ {
CheckNTErrors((a && b && c), "Empty input tensors!"); CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType), "Input tensors should have the same data type!"); CheckNTErrors((a->dataType == b->dataType), "Input tensors should have the same data type!");
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2), CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
"Input tensors must have a order = 2!"); "Input tensors must have a order = 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMUL2D);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = a->dimSize[0], am = a->dimSize[1]; int an = a->dimSize[0], am = a->dimSize[1];
int bn = b->dimSize[0], bm = b->dimSize[1]; int bn = b->dimSize[0], bm = b->dimSize[1];
int cn = c->dimSize[0], cm = c->dimSize[1]; int cn = c->dimSize[0], cm = c->dimSize[1];
......
...@@ -54,12 +54,6 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -54,12 +54,6 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2), CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!"); "Input tensors must have a order > 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1]; int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0]; int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b->dimSizeRDI[0] : b->dimSizeRDI[1]; int bn = transposedB == X_TRANS ? b->dimSizeRDI[0] : b->dimSizeRDI[1];
......
...@@ -43,11 +43,6 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha ...@@ -43,11 +43,6 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha
"Unmatched tensors in multiplication!"); "Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!"); CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MULTIPLY);
XLink::AddParamToHeadInt(c, leadingDim);
XLink::AddParamToHead(c, alpha);
#ifdef USE_CUDA #ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) { if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
CudaMultiply(a, b, c, leadingDim, alpha); CudaMultiply(a, b, c, leadingDim, alpha);
...@@ -123,4 +118,4 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha ...@@ -123,4 +118,4 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -44,10 +44,6 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) ...@@ -44,10 +44,6 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!"); "Unmatched tensors in addition!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_SUM);
XLink::AddParamToHead(c, beta);
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) { if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -117,4 +113,4 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) ...@@ -117,4 +113,4 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -55,12 +55,6 @@ void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high) ...@@ -55,12 +55,6 @@ void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high)
} }
} }
/* make tensor connections */
XLink::MakeLink(a, NULL, c, MATH_SELECTRANGE);
XLink::AddParamToHeadInt(c, dim);
XLink::AddParamToHeadInt(c, low);
XLink::AddParamToHeadInt(c, high);
int stride = 1; int stride = 1;
int dimRDI = a->order - dim - 1; int dimRDI = a->order - dim - 1;
for(int i = 0; i < dimRDI; i++) for(int i = 0; i < dimRDI; i++)
......
...@@ -53,10 +53,6 @@ void ReduceMax(XTensor * input, XTensor * output, int dim) ...@@ -53,10 +53,6 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
} }
} }
/* make tensor connections */
XLink::MakeLink(input, NULL, output, MATH_REDUCEMAX);
XLink::AddParamToHeadInt(output, dim);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceMax(input, output, dim); CudaReduceMax(input, output, dim);
...@@ -94,4 +90,4 @@ void ReduceMax(XTensor * input, XTensor * output, int dim) ...@@ -94,4 +90,4 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -59,12 +59,6 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP ...@@ -59,12 +59,6 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP
} }
} }
/* make tensor connections */
XLink::MakeLink(input, shift, output, MATH_REDUCESUM);
XLink::AddParamToHeadInt(output, dim);
XLink::AddParamToHead(output, power);
XLink::AddParamToHeadInt(output, isExp);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceSum(input, output, dim, shift, power, isExp); CudaReduceSum(input, output, dim, shift, power, isExp);
......
...@@ -63,11 +63,6 @@ void Merge(XTensor * s, XTensor * t, int whereToMerge, int leadingDim) ...@@ -63,11 +63,6 @@ void Merge(XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
} }
} }
/* make tensor connections */
XLink::MakeLink(s, NULL, t, MATH_MERGE);
XLink::AddParamToHeadInt(t, whereToMerge);
XLink::AddParamToHeadInt(t, leadingDim);
int blockSize = 1; int blockSize = 1;
int blockNum = 1; int blockNum = 1;
int gridSize = 1; int gridSize = 1;
......
...@@ -40,11 +40,6 @@ void Unsqueeze(XTensor * a, XTensor * b, int dim, int dSize) ...@@ -40,11 +40,6 @@ void Unsqueeze(XTensor * a, XTensor * b, int dim, int dSize)
CheckNTErrors((a->order == b->order - 1), "Unmatched tensors!"); CheckNTErrors((a->order == b->order - 1), "Unmatched tensors!");
CheckNTErrors((a->unitSize == b->unitSize), "Unmatched tensors!"); CheckNTErrors((a->unitSize == b->unitSize), "Unmatched tensors!");
/* make tensor connections */
XLink::MakeLink(a, NULL, b, MATH_UNSQUEEZE);
XLink::AddParamToHeadInt(b, dim);
XLink::AddParamToHeadInt(b, dSize);
int dimRDI = b->order - dim - 1; int dimRDI = b->order - dim - 1;
for (int i = 0; i < b->order; i++) { for (int i = 0; i < b->order; i++) {
if (i < dimRDI) { if (i < dimRDI) {
...@@ -99,4 +94,4 @@ void Unsqueeze(XTensor * a, XTensor * b, int dim, int dSize) ...@@ -99,4 +94,4 @@ void Unsqueeze(XTensor * a, XTensor * b, int dim, int dSize)
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -39,10 +39,6 @@ void Sort(XTensor * a, XTensor * index, int dim) ...@@ -39,10 +39,6 @@ void Sort(XTensor * a, XTensor * index, int dim)
CheckNTErrors((a->order == index->order), "Unmatched input tensors!"); CheckNTErrors((a->order == index->order), "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!"); CheckNTErrors((index->dataType == X_INT), "Wrong data type!");
/* make tensor connections */
XLink::MakeLink(a, NULL, index, MATH_SORT);
XLink::AddParamToHeadInt(index, dim);
int dimRDI = a->order - dim - 1; int dimRDI = a->order - dim - 1;
/* make the index tensor */ /* make the index tensor */
index->SetAscendingOrder(dim); index->SetAscendingOrder(dim);
...@@ -81,4 +77,4 @@ void Sort(XTensor * a, XTensor * index, int dim) ...@@ -81,4 +77,4 @@ void Sort(XTensor * a, XTensor * index, int dim)
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -41,11 +41,6 @@ void TopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -41,11 +41,6 @@ void TopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k)
CheckNTErrors((index == NULL || a->order == index->order), "Unmatched input tensors!"); CheckNTErrors((index == NULL || a->order == index->order), "Unmatched input tensors!");
CheckNTErrors((index->dataType == X_INT), "Wrong data type!"); CheckNTErrors((index->dataType == X_INT), "Wrong data type!");
/* make tensor connections */
XLink::MakeLink(a, b, index, MATH_TOPK);
XLink::AddParamToHeadInt(index, dim);
XLink::AddParamToHeadInt(index, k);
int dimRDI = a->order - dim - 1; int dimRDI = a->order - dim - 1;
for (int i = 0; i < a->order; i++) { for (int i = 0; i < a->order; i++) {
if (i == dimRDI) { if (i == dimRDI) {
...@@ -110,4 +105,4 @@ void TopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -110,4 +105,4 @@ void TopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k)
} }
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论