Commit 12be734f by xiaotong

add tensor connections

parent e02c6b92
...@@ -164,14 +164,42 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN ...@@ -164,14 +164,42 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
/* backward for t1 */ /* backward for t1 */
if(t1 != NULL){ if(t1 != NULL){
XLink &outgo = t1->outgo; XLink &outgo = t1->outgo;
CheckNTErrors(outgo.head != h, "Wrong head of the hyperedge!"); CheckNTErrors(outgo.head != t1, "Wrong head of the hyperedge!");
outgo.AddTail(h); outgo.AddTail(h);
} }
/* backward for t2 */ /* backward for t2 */
if(t2 != NULL){ if(t2 != NULL){
XLink &outgo = t2->outgo; XLink &outgo = t2->outgo;
CheckNTErrors(outgo.head != h, "Wrong head of the hyperedge!"); CheckNTErrors(outgo.head != t2, "Wrong head of the hyperedge!");
outgo.AddTail(h);
}
}
/*
create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors
>> h - head tensor
>> typeName - name of edge type
*/
void XLink::MakeLink(XList * list, XTensor * h, const char * typeName)
{
/* forward */
XLink &income = h->income;
income.Reset();
income.SetHead(h);
income.SetType(typeName);
for(int i = 0; i < list->count; i++){
XTensor * t = (XTensor*)list->GetItem(i);
income.AddTail(t);
}
/* backward */
for(int i = 0; i < list->count; i++){
XTensor * t = (XTensor*)list->GetItem(i);
XLink &outgo = t->outgo;
CheckNTErrors(outgo.head != t, "Wrong head of the hyperedge!");
outgo.AddTail(h); outgo.AddTail(h);
} }
} }
......
...@@ -105,6 +105,10 @@ struct XLink ...@@ -105,6 +105,10 @@ struct XLink
static static
void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName); void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName);
/* create a hyper edge with a list of tensors and a output tensor */
static
void MakeLink(XList * list, XTensor * h, const char * typeName);
/* add a parameter */ /* add a parameter */
static static
void AddParamToHead(XTensor * h, DTYPE param); void AddParamToHead(XTensor * h, DTYPE param);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
* We define various names here * We define various names here
* *
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* It was really HOT these days. I can't imagine what a hot day in Shenyang! * It was really HOT these days. I can't imagine what a hot day here in Shenyang!
*/ */
#ifndef __XNAME_H__ #ifndef __XNAME_H__
...@@ -29,6 +29,13 @@ ...@@ -29,6 +29,13 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL" #define MATH_MATMUL "M_MATMUL"
#define MATH_CONCATENATESOLY "M_CONCATENATESOLY"
#define MATH_COPYVALUES "M_COPYVALUES"
#define MATH_MATRIXMUL "M_MATRIXMUL"
#define MATH_MATRIXMUL2D "M_MATRIXMUL2D"
#define MATH_MATRIXMULBATCHED "M_MATRIXMULBATCHED"
#define MATH_MERGE "M_MERGE"
#define MATH_MULTIPLY "M_MULTIPLY"
#define MATH_REDUCEMAX "M_REDUCEMAX" #define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM" #define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE" #define MATH_SELECTRANGE "M_SELECTRANGE"
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../XTensor.h" #include "../XTensor.h"
#include "../XUtility.h" #include "../XUtility.h"
#include "../XName.h"
#include "ConcatenateSolely.h" #include "ConcatenateSolely.h"
#include "MergeBlockLists.h" #include "MergeBlockLists.h"
...@@ -36,6 +37,10 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim) ...@@ -36,6 +37,10 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
{ {
CheckNTErrors((big->order > dim && dim >= 0), "Illegal dimension to concatenate!"); CheckNTErrors((big->order > dim && dim >= 0), "Illegal dimension to concatenate!");
/* make tensor connections */
XLink::MakeLink(smalls, big, MATH_CONCATENATESOLY);
XLink::AddParamToHeadInt(big, dim);
int catDimSize = 0; int catDimSize = 0;
int dimRDI = big->order - dim - 1; int dimRDI = big->order - dim - 1;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/ */
#include "../XName.h"
#include "CopyValues.h" #include "CopyValues.h"
#include "CopyValues.cuh" #include "CopyValues.cuh"
...@@ -41,6 +42,9 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream) ...@@ -41,6 +42,9 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!"); CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!");
CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!"); CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!");
/* make tensor connections */
XLink::MakeLink(s, NULL, t, MATH_COPYVALUES);
if ((s->dataType == X_FLOAT16 && t->dataType == X_FLOAT) || if ((s->dataType == X_FLOAT16 && t->dataType == X_FLOAT) ||
(s->dataType == X_FLOAT && t->dataType == X_FLOAT16)) { (s->dataType == X_FLOAT && t->dataType == X_FLOAT16)) {
CheckNTErrors(((s->devID < 0 && t->devID < 0) || s->devID == t->devID), CheckNTErrors(((s->devID < 0 && t->devID < 0) || s->devID == t->devID),
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../XTensor.h" #include "../XTensor.h"
#include "../XDevice.h" #include "../XDevice.h"
#include "../XName.h"
#include "MatrixMul.h" #include "MatrixMul.h"
#include "MatrixMul2D.h" #include "MatrixMul2D.h"
#include "MatrixMULBatchedCPU.h" #include "MatrixMULBatchedCPU.h"
...@@ -58,6 +59,13 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -58,6 +59,13 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2), CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!"); "Input tensors must have a order > 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMUL);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = transposedA == X_TRANS ? a->dimSize[1] : a->dimSize[0]; int an = transposedA == X_TRANS ? a->dimSize[1] : a->dimSize[0];
int am = transposedA == X_TRANS ? a->dimSize[0] : a->dimSize[1]; int am = transposedA == X_TRANS ? a->dimSize[0] : a->dimSize[1];
int bn = transposedB == X_TRANS ? b->dimSize[1] : b->dimSize[0]; int bn = transposedB == X_TRANS ? b->dimSize[1] : b->dimSize[0];
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../XTensor.h" #include "../XTensor.h"
#include "../XName.h"
#include "MatrixMul2D.h" #include "MatrixMul2D.h"
#include "MatrixMul2D.cuh" #include "MatrixMul2D.cuh"
#include "MatrixMul2DParallel.h" #include "MatrixMul2DParallel.h"
...@@ -51,6 +52,13 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -51,6 +52,13 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2), CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
"Input tensors must have a order = 2!"); "Input tensors must have a order = 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMUL2D);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = a->dimSize[0], am = a->dimSize[1]; int an = a->dimSize[0], am = a->dimSize[1];
int bn = b->dimSize[0], bm = b->dimSize[1]; int bn = b->dimSize[0], bm = b->dimSize[1];
int cn = c->dimSize[0], cm = c->dimSize[1]; int cn = c->dimSize[0], cm = c->dimSize[1];
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../XTensor.h" #include "../XTensor.h"
#include "../XDevice.h" #include "../XDevice.h"
#include "../XName.h"
#include "MatrixMulBatched.h" #include "MatrixMulBatched.h"
#include "MatrixMULBatchedCPU.h" #include "MatrixMULBatchedCPU.h"
#include "XTensorBLAS.h" #include "XTensorBLAS.h"
...@@ -52,6 +53,13 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -52,6 +53,13 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2), CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!"); "Input tensors must have a order > 2!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadInt(c, transposedA);
XLink::AddParamToHeadInt(c, transposedB);
XLink::AddParamToHead(c, alpha);
XLink::AddParamToHead(c, beta);
int an = transposedA == X_TRANS ? a->dimSize[1] : a->dimSize[0]; int an = transposedA == X_TRANS ? a->dimSize[1] : a->dimSize[0];
int am = transposedA == X_TRANS ? a->dimSize[0] : a->dimSize[1]; int am = transposedA == X_TRANS ? a->dimSize[0] : a->dimSize[1];
int bn = transposedB == X_TRANS ? b->dimSize[1] : b->dimSize[0]; int bn = transposedB == X_TRANS ? b->dimSize[1] : b->dimSize[0];
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../XTensor.h" #include "../XTensor.h"
#include "../XUtility.h" #include "../XUtility.h"
#include "../XName.h"
#include "Merge.h" #include "Merge.h"
#include "MakeMergeBlockIndex.h" #include "MakeMergeBlockIndex.h"
#include "CopyBlocksOnSite.h" #include "CopyBlocksOnSite.h"
...@@ -63,6 +64,11 @@ void Merge(XTensor * s, XTensor * t, int whereToMerge, int leadingDim) ...@@ -63,6 +64,11 @@ void Merge(XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
} }
} }
/* make tensor connections */
XLink::MakeLink(s, NULL, t, MATH_MERGE);
XLink::AddParamToHeadInt(t, whereToMerge);
XLink::AddParamToHeadInt(t, leadingDim);
int blockSize = 1; int blockSize = 1;
int blockNum = 1; int blockNum = 1;
int gridSize = 1; int gridSize = 1;
......
...@@ -27,12 +27,12 @@ ...@@ -27,12 +27,12 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* /*
merge data by blocks merge data by blocks
>> sourceList - list of source data array >> sourceList - list of source data array
>> blockSizes - list of the block size for each source data array >> blockSizes - list of the block size for each source data array
>> blockNum - number of blocks kept in each data array >> blockNum - number of blocks kept in each data array
>> target - target data array >> target - target data array
>> myMem - memory pool >> myMem - memory pool
*/ */
void MergeBlockLists(XList * sourceList, int * blockSizes, int blockNum, void * target, XMem * myMem) void MergeBlockLists(XList * sourceList, int * blockSizes, int blockNum, void * target, XMem * myMem)
{ {
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../XTensor.h" #include "../XTensor.h"
#include "../XName.h"
#include "Multiply.h" #include "Multiply.h"
#include "Multiply.cuh" #include "Multiply.cuh"
...@@ -41,6 +42,11 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha ...@@ -41,6 +42,11 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha
"Unmatched tensors in multiplication!"); "Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!"); CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_MULTIPLY);
XLink::AddParamToHeadInt(c, leadingDim);
XLink::AddParamToHead(c, alpha);
#ifdef USE_CUDA #ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) { if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
CudaMultiply(a, b, c, leadingDim, alpha); CudaMultiply(a, b, c, leadingDim, alpha);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论