Commit 7a7dc4c6 by xiaotong

new tensor connections

parent 90c12836
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <stdio.h> #include <stdio.h>
#include "XLink.h" #include "XLink.h"
#include "XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -35,6 +36,7 @@ XLink::XLink() ...@@ -35,6 +36,7 @@ XLink::XLink()
tailNum = 0; tailNum = 0;
paramNum = 0; paramNum = 0;
type[0] = 0; type[0] = 0;
typeID = 0;
} }
/* deconstructor */ /* deconstructor */
...@@ -59,14 +61,14 @@ void XLink::Reset() ...@@ -59,14 +61,14 @@ void XLink::Reset()
/* /*
set edge type name set edge type name
>> typeName - type name in string >> id - id of the type
*/ */
void XLink::SetType(const char * typeName) void XLink::SetType(int id)
{ {
type[0] = 0; type[0] = 0;
if(typeName == NULL) strcpy(type, GetOPName(id));
return; typeID = id;
strcpy(type, typeName); CheckNTErrors(!strcmp(type, "NULL"), "illegal edge type name!");
} }
/* /*
...@@ -141,9 +143,9 @@ create a hyperedge with two input tensors and a output tensor ...@@ -141,9 +143,9 @@ create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor >> t1 - a tail tensor
>> t2 - another tail tensor >> t2 - another tail tensor
>> h - head tensor >> h - head tensor
>> typeName - name of edge type >> id - id of the edge type
*/ */
void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName) void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, int id)
{ {
if(h != NULL) if(h != NULL)
return; return;
...@@ -159,7 +161,7 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN ...@@ -159,7 +161,7 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
else{ else{
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
} }
income.SetType(typeName); income.SetType(id);
/* backward for t1 */ /* backward for t1 */
if(t1 != NULL){ if(t1 != NULL){
...@@ -180,15 +182,15 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN ...@@ -180,15 +182,15 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
create a hyper edge with a list of tensors and a output tensor create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors >> list - a list of input tensors
>> h - head tensor >> h - head tensor
>> typeName - name of edge type >> id - id of the edge type
*/ */
void XLink::MakeLink(XList * list, XTensor * h, const char * typeName) void XLink::MakeLink(XList * list, XTensor * h, int id)
{ {
/* forward */ /* forward */
XLink &income = h->income; XLink &income = h->income;
income.Reset(); income.Reset();
income.SetHead(h); income.SetHead(h);
income.SetType(typeName); income.SetType(id);
for(int i = 0; i < list->count; i++){ for(int i = 0; i < list->count; i++){
XTensor * t = (XTensor*)list->GetItem(i); XTensor * t = (XTensor*)list->GetItem(i);
......
...@@ -74,6 +74,9 @@ struct XLink ...@@ -74,6 +74,9 @@ struct XLink
/* name of the hyperedge type. e.g., sum, mul ... */ /* name of the hyperedge type. e.g., sum, mul ... */
char type[MAX_OP_NAME_LENGTH]; char type[MAX_OP_NAME_LENGTH];
/* type id */
int typeID;
/* constuctor */ /* constuctor */
XLink(); XLink();
...@@ -83,8 +86,8 @@ struct XLink ...@@ -83,8 +86,8 @@ struct XLink
/* reset it */ /* reset it */
void Reset(); void Reset();
/* set edge type name */ /* set edge type id and name */
void SetType(const char * typeName); void SetType(int id);
/* set head */ /* set head */
void SetHead(XTensor * h); void SetHead(XTensor * h);
...@@ -103,11 +106,11 @@ struct XLink ...@@ -103,11 +106,11 @@ struct XLink
/* create a hyper edge with two input tensors and a output tensor */ /* create a hyper edge with two input tensors and a output tensor */
static static
void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName); void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, int id);
/* create a hyper edge with a list of tensors and a output tensor */ /* create a hyper edge with a list of input tensors and a output tensor */
static static
void MakeLink(XList * list, XTensor * h, const char * typeName); void MakeLink(XList * list, XTensor * h, int id);
/* add a parameter */ /* add a parameter */
static static
...@@ -120,4 +123,4 @@ struct XLink ...@@ -120,4 +123,4 @@ struct XLink
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __XLINK_H__ #endif // __XLINK_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
*/
#ifndef __XNAME_H__
#define __XNAME_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ARITHMETIC 0x00001000
#define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_MULTIPLY MATH_SUM + 1
/* get operator name */
const char * GetOPName(int type)
{
if((type & MATH_ARITHMETIC) != 0){
if(type == MATH_SUM)
return "M_SUM";
else if(type == MATH_MULTIPLY)
return "M_MULTIPLY";
}
return "NULL";
}
} // namespace nts(NiuTrans.Tensor)
#endif // __XNAME_H__
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
* We define various names here * We define various names here
* *
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* It was really HOT these days. I can't imagine what a hot day here in Shenyang! * It was really HOT these days. I can't imagine it is SO hot here in Shenyang!
*/ */
#ifndef __XNAME_H__ #ifndef __XNAME_H__
...@@ -28,22 +28,13 @@ ...@@ -28,22 +28,13 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL" #define MATH_ARITHMETIC 10000
#define MATH_CONCATENATESOLY "M_CONCATENATESOLY" #define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_COPYVALUES "M_COPYVALUES" #define MATH_MULTIPLY MATH_SUM + 1
#define MATH_MATRIXMUL "M_MATRIXMUL"
#define MATH_MATRIXMUL2D "M_MATRIXMUL2D" /* get operator name */
#define MATH_MATRIXMULBATCHED "M_MATRIXMULBATCHED" const char * GetOPName(int type);
#define MATH_MERGE "M_MERGE"
#define MATH_MULTIPLY "M_MULTIPLY"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM"
#define MATH_TOPK "M_TOPK"
#define MATH_UNSQUEEZE "M_UNSQUEEZE"
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __XNAME_H__ #endif // __XNAME_H__
\ No newline at end of file
...@@ -40,9 +40,9 @@ where trans() return the transposed matrix if the flag is fired ...@@ -40,9 +40,9 @@ where trans() return the transposed matrix if the flag is fired
>> parallelRunner - parallel processing module >> parallelRunner - parallel processing module
*/ */
void MatrixMul2DParallel(XTensor * a, MATRIX_TRANS_TYPE transposedA, void MatrixMul2DParallel(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XTensor * c, DTYPE alpha, DTYPE beta,
XPRunner * parallelRunner) XPRunner * parallelRunner)
{ {
CheckNTErrors((a && b && c), "Empty input tensors!"); CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2), CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
...@@ -28,12 +29,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -28,12 +29,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* /*
tensor summation c = a + b * \beta tensor summation c = a + b * \beta
return a pointer
>> a - a tensor >> a - a tensor
>> b - another tensor >> b - another tensor
>> c - where we put a+b*\beta. we save it in a if c is NULL >> c - where we put a+b*\beta. we save it in a if c is NULL
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) void _Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
{ {
if (c == NULL) if (c == NULL)
c = a; c = a;
...@@ -59,17 +61,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) ...@@ -59,17 +61,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
ShowNTErrors("Cannot run this method on multiple devices simultaneously!"); ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
} }
else else
CudaSum(a, b, c, beta); _CudaSum(a, b, c, beta);
} }
else else
CudaSum(a, b, c, beta); _CudaSum(a, b, c, beta);
#endif #endif
} }
else { else {
if (!a->isSparse && !b->isSparse) { if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
"Illegal use of sparse matrix in addition!");
if (a->dataType == DEFAULT_DTYPE && if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE &&
...@@ -112,5 +113,38 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) ...@@ -112,5 +113,38 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
} }
} }
} }
/*
tensor summation a = a + b * \beta
do it on site
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void _SumMe(XTensor * a, XTensor * b, DTYPE beta)
{
_Sum(a, b, a, beta);
}
/*
tensor summation a = a + b * \beta
return a XTensor structure
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
XTensor Sum(XTensor &a, XTensor &b, DTYPE beta)
{
XTensor c(&a);
/* computation */
_Sum(&a, &b, &c, beta);
/* tensor connections */
//XLink::MakeLink(&a, &b, &c, MATH_SUM);
//XLink::AddParamToHead(&c, beta);
return c;
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -51,7 +51,7 @@ tensor summation c = a + b * \beta (cuda version) ...@@ -51,7 +51,7 @@ tensor summation c = a + b * \beta (cuda version)
>> c - where we put a+b*\beta. we save it in a if c is NULL >> c - where we put a+b*\beta. we save it in a if c is NULL
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void CudaSum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) void _CudaSum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
{ {
if (c == NULL) if (c == NULL)
c = a; c = a;
...@@ -124,7 +124,7 @@ tensor summation c = a + b * \beta (cuda version) with an input handle ...@@ -124,7 +124,7 @@ tensor summation c = a + b * \beta (cuda version) with an input handle
>> size - size of the array >> size - size of the array
>> beta - the coefficient >> beta - the coefficient
*/ */
void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta) void _CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{ {
if (size == 0) if (size == 0)
return; return;
...@@ -160,4 +160,4 @@ void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, ...@@ -160,4 +160,4 @@ void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b,
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -34,14 +34,14 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1. ...@@ -34,14 +34,14 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.
/* tensor summation c = a + b * \beta (cuda version) */ /* tensor summation c = a + b * \beta (cuda version) */
extern "C" extern "C"
void CudaSum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0); void _CudaSum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta (cuda version) with an input handle */ /* tensor summation c = a + b * \beta (cuda version) with an input handle */
extern "C" extern "C"
void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0); void _CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __SUM_CUH__ #endif // __SUM_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University. * Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/ */
#ifndef __SUM_H__ #ifndef __SUM_H__
#define __SUM_H__ #define __SUM_H__
...@@ -27,9 +27,14 @@ ...@@ -27,9 +27,14 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor summation c = a + b * \beta */ /* tensor summation c = a + b * \beta */
extern "C" void _Sum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
void Sum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor summation a = a + b * \beta (return a pointer) */
void _SumMe(XTensor * a, XTensor * b, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta (return a structure) */
XTensor Sum(XTensor &a, XTensor &b, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __SUM_H__ #endif // __SUM_H__
\ No newline at end of file
...@@ -42,9 +42,6 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream) ...@@ -42,9 +42,6 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!"); CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!");
CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!"); CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!");
/* make tensor connections */
XLink::MakeLink(s, NULL, t, MATH_COPYVALUES);
if ((s->dataType == X_FLOAT16 && t->dataType == X_FLOAT) || if ((s->dataType == X_FLOAT16 && t->dataType == X_FLOAT) ||
(s->dataType == X_FLOAT && t->dataType == X_FLOAT16)) { (s->dataType == X_FLOAT && t->dataType == X_FLOAT16)) {
CheckNTErrors(((s->devID < 0 && t->devID < 0) || s->devID == t->devID), CheckNTErrors(((s->devID < 0 && t->devID < 0) || s->devID == t->devID),
......
...@@ -37,10 +37,6 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim) ...@@ -37,10 +37,6 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
{ {
CheckNTErrors((big->order > dim && dim >= 0), "Illegal dimension to concatenate!"); CheckNTErrors((big->order > dim && dim >= 0), "Illegal dimension to concatenate!");
/* make tensor connections */
XLink::MakeLink(smalls, big, MATH_CONCATENATESOLY);
XLink::AddParamToHeadInt(big, dim);
int catDimSize = 0; int catDimSize = 0;
int dimRDI = big->order - dim - 1; int dimRDI = big->order - dim - 1;
...@@ -102,4 +98,4 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim) ...@@ -102,4 +98,4 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
delete sourceArrays; delete sourceArrays;
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -36,7 +36,7 @@ segment a 2d tensor (i.e., matrix) into blocks and run jobs in parallel ...@@ -36,7 +36,7 @@ segment a 2d tensor (i.e., matrix) into blocks and run jobs in parallel
>> ... - arguments of the jobs >> ... - arguments of the jobs
*/ */
void RunParallel2D(XPRunner * parallelRunner, void * job, void RunParallel2D(XPRunner * parallelRunner, void * job,
int opNum, int rowNum, int colNum, int argNum, ...) int opNum, int rowNum, int colNum, int argNum, ...)
{ {
if (rowNum == 0 || colNum == 0) if (rowNum == 0 || colNum == 0)
return; return;
......
...@@ -55,7 +55,7 @@ bool TestSum1() ...@@ -55,7 +55,7 @@ bool TestSum1()
b->SetData(bData, unitNum); b->SetData(bData, unitNum);
/* call sum function */ /* call sum function */
Sum(a, b); _Sum(a, b);
/* check results */ /* check results */
cpuTest = a->CheckData(answer, unitNum); cpuTest = a->CheckData(answer, unitNum);
...@@ -131,7 +131,7 @@ bool TestSum2() ...@@ -131,7 +131,7 @@ bool TestSum2()
c->SetZeroAll(); c->SetZeroAll();
/* call Sum function */ /* call Sum function */
Sum(a, b, c, beta); _Sum(a, b, c, beta);
/* check results */ /* check results */
cpuTest = c->CheckData(answer, unitNum); cpuTest = c->CheckData(answer, unitNum);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论