Commit 07a5ae75 by xiaotong

new code for XLink

parent c2865de7
...@@ -63,6 +63,9 @@ set edge type name ...@@ -63,6 +63,9 @@ set edge type name
*/ */
void XLink::SetType(const char * typeName) void XLink::SetType(const char * typeName)
{ {
type[0] = 0;
if(typeName == NULL)
return;
strcpy(type, typeName); strcpy(type, typeName);
} }
...@@ -104,35 +107,86 @@ void XLink::AddTwoTails(XTensor * t1, XTensor * t2) ...@@ -104,35 +107,86 @@ void XLink::AddTwoTails(XTensor * t1, XTensor * t2)
} }
/* /*
add an integer parameter add a parameter
>> param - parameter in integer >> param - parameter in default type
*/ */
void XLink::AddParamInt(int param) void XLink::AddParam(DTYPE param)
{ {
void * ps = params; void * ps = params;
params = new char[paramNum + 1]; params = new char[paramNum + 1];
memcpy(params, ps, paramNum * paramSize); memcpy(params, ps, paramNum * paramSize);
int * p = (int*)((char*)params + paramNum * paramSize); DTYPE * p = (DTYPE*)((char*)params + paramNum * paramSize);
*p = param; *p = param;
paramNum++; paramNum++;
delete[] (char*)ps; delete[] (char*)ps;
} }
/* /*
add a float parameter add a parameter
>> param - parameter in float >> param - pointer to the parameter
>> size - size of the parameter
*/ */
void XLink::AddParamFloat(float param) void XLink::AddParam(void * param, int size)
{ {
void * ps = params; void * ps = params;
params = new char[paramNum + 1]; params = new char[paramNum + 1];
memcpy(params, ps, paramNum * paramSize); memcpy(params, ps, paramNum * paramSize);
float * p = (float*)((char*)params + paramNum * paramSize); char * p = (char*)params + paramNum * paramSize;
*p = param; memcpy(p, param, size);
paramNum++; paramNum++;
delete[] (char*)ps; delete[] (char*)ps;
} }
/*
create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
>> t2 - another tail tensor
>> h - head tensor
>> typeName - name of edge type
*/
void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName)
{
if(h != NULL)
return;
/* forward */
XLink &income = h->income;
income.Reset();
income.SetHead(h);
if(t1 != NULL && t2 != NULL)
income.AddTwoTails(t1, t2);
else if(t1 != NULL)
income.AddTail(t1);
else{
ShowNTErrors("TODO!");
}
income.SetType(typeName);
/* backward for t1 */
if(t1 != NULL){
XLink &outgo = t1->outgo;
CheckNTErrors(outgo.head != h, "Wrong head of the hyperedge!");
outgo.AddTail(h);
}
/* backward for t2 */
if(t2 != NULL){
XLink &outgo = t2->outgo;
CheckNTErrors(outgo.head != h, "Wrong head of the hyperedge!");
outgo.AddTail(h);
}
}
/*
add parameters
>> h - head
>> param - parameter we want introduce
*/
void XLink::AddParamToHead(XTensor * h, DTYPE param)
{
if(h != NULL)
return;
h->income.AddParam(param);
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#ifndef __XLINK_H__ #ifndef __XLINK_H__
#define __XLINK_H__ #define __XLINK_H__
#include "XGlobal.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
/* cross reference */ /* cross reference */
...@@ -94,10 +96,18 @@ struct XLink ...@@ -94,10 +96,18 @@ struct XLink
void AddTwoTails(XTensor * t1, XTensor * t2); void AddTwoTails(XTensor * t1, XTensor * t2);
/* add a integer parameter */ /* add a integer parameter */
void AddParamInt(int param); void AddParam(DTYPE param);
/* add a integer parameter */ /* add a integer parameter */
void AddParamFloat(float param); void AddParam(void * param, int size);
/* create a hyper edge with two input tensors and a output tensor */
static
void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeName);
/* add parameters */
static
void AddParamToHead(XTensor * h, DTYPE param);
}; };
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* We define various names here
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* It was really HOT these days. I can't imagine what a hot day in Shenyang!
*/
#ifndef __XNAME_H__
#define __XNAME_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_SUM "M_SUM"
#define MATH_MATMUL "M_MATMUL"
} // namespace nts(NiuTrans.Tensor)
#endif // __XNAME_H__
\ No newline at end of file
...@@ -139,10 +139,10 @@ struct XTensor ...@@ -139,10 +139,10 @@ struct XTensor
represents a network with three nodes (a, b and c) and a hyperedge that links a and b (tails) to c (head). represents a network with three nodes (a, b and c) and a hyperedge that links a and b (tails) to c (head).
Here "income" keeps which nodes (tensors) are used to form the current node (tensor). Here "income" keeps which nodes (tensors) are used to form the current node (tensor).
*/ */
XLink * income; XLink income;
/* It keeps which nodes (tensors) we go to from the current node (tensor). */ /* It keeps which nodes (tensors) we go to from the current node (tensor). */
XLink * outgo; XLink outgo;
/******************** /********************
XTensor untilities XTensor untilities
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../XTensor.h" #include "../XTensor.h"
#include "../XName.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
...@@ -37,15 +38,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) ...@@ -37,15 +38,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
if (c == NULL) if (c == NULL)
c = a; c = a;
CheckNTErrors((a && b && c), CheckNTErrors(a && b && c, "Empty tensors in addition!");
"Empty tensors in addition!"); CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in addition!"); "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Unmatched tensors in addition!"); "Unmatched tensors in addition!");
/* make tensor connections */
XLink::MakeLink(a, b, c, MATH_SUM);
XLink::AddParamToHead(c, beta);
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) { if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA #ifdef USE_CUDA
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论