Commit c2865de7 by xiaotong

new code for XLink

parent 34c56119
...@@ -16,9 +16,7 @@ ...@@ -16,9 +16,7 @@
*/ */
/* /*
* some public functions are defined here
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-04 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-04
*
*/ */
#include <stdio.h> #include <stdio.h>
...@@ -26,5 +24,115 @@ ...@@ -26,5 +24,115 @@
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
int XLink::paramSize = 64;
/* constuctor */
XLink::XLink()
{
head = NULL;
tails = NULL;
params = NULL;
tailNum = 0;
paramNum = 0;
type[0] = 0;
}
/* deconstructor */
XLink::~XLink()
{
delete[] tails;
delete[] (char*)params;
}
/* reset it */
void XLink::Reset()
{
delete[] tails;
delete[] (char*)params;
head = NULL;
tails = NULL;
params = NULL;
tailNum = 0;
paramNum = 0;
type[0] = 0;
}
/*
set edge type name
>> typeName - type name in string
*/
void XLink::SetType(const char * typeName)
{
strcpy(type, typeName);
}
/*
set head
>> h - pointer to the head tensor
*/
void XLink::SetHead(XTensor * h)
{
head = h;
}
/*
add a tail
>> t - pointer to the tail tensor
*/
void XLink::AddTail(XTensor * t)
{
XTensor ** ts = tails;
tails = new XTensor*[tailNum + 1];
memcpy(tails, ts, sizeof(XTensor*) * tailNum);
tails[tailNum++] = t;
delete[] ts;
}
/*
add two tails in one time
>> t1 - pointer to the tail tensor
>> t2 - pointer to another tail tensor
*/
void XLink::AddTwoTails(XTensor * t1, XTensor * t2)
{
XTensor ** ts = tails;
tails = new XTensor*[tailNum + 2];
memcpy(tails, ts, sizeof(XTensor*) * tailNum);
tails[tailNum++] = t1;
tails[tailNum++] = t2;
delete[] ts;
}
/*
add an integer parameter
>> param - parameter in integer
*/
void XLink::AddParamInt(int param)
{
void * ps = params;
params = new char[paramNum + 1];
memcpy(params, ps, paramNum * paramSize);
int * p = (int*)((char*)params + paramNum * paramSize);
*p = param;
paramNum++;
delete[] (char*)ps;
}
/*
add a float parameter
>> param - parameter in float
*/
void XLink::AddParamFloat(float param)
{
void * ps = params;
params = new char[paramNum + 1];
memcpy(params, ps, paramNum * paramSize);
float * p = (float*)((char*)params + paramNum * paramSize);
*p = param;
paramNum++;
delete[] (char*)ps;
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -16,9 +16,7 @@ ...@@ -16,9 +16,7 @@
*/ */
/* /*
* some public functions are defined here
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-04 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-04
*
*/ */
#include <stdio.h> #include <stdio.h>
...@@ -33,6 +31,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -33,6 +31,8 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* cross reference */ /* cross reference */
struct XTensor; struct XTensor;
#define MAX_OP_NAME_LENGTH 16
/* /*
This defines the link among tensors in networks. XLink can be This defines the link among tensors in networks. XLink can be
cast as a hyperedge in a graph. when we compute on tensors, we actually create a cast as a hyperedge in a graph. when we compute on tensors, we actually create a
...@@ -53,12 +53,51 @@ struct XLink ...@@ -53,12 +53,51 @@ struct XLink
/* head of the hyperedge */ /* head of the hyperedge */
XTensor * head; XTensor * head;
/* tails of the hyperedge ∂*/ /* tails of the hyperedge */
XTensor * tails[]; XTensor ** tails;
/* number of tails */
int tailNum;
/* parameters used. e.g., c = a * b * \alpha
scalar \alpha is the parameter */
void * params;
/* number of parameters */
int paramNum;
/* size of each parameter */
static int paramSize;
/* name of the hyperedge type. e.g., sum, mul ... */
char type[MAX_OP_NAME_LENGTH];
/* constuctor */
XLink();
/* deconstructor */
~XLink();
/* reset it */
void Reset();
/* set edge type name */
void SetType(const char * typeName);
/* set head */
void SetHead(XTensor * h);
/* add a tail */
void AddTail(XTensor * t);
/* add two tails in one time */
void AddTwoTails(XTensor * t1, XTensor * t2);
XLink(){}; /* add a integer parameter */
void AddParamInt(int param);
~XLink(){}; /* add a integer parameter */
void AddParamFloat(float param);
}; };
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论