Commit dcd95238 by xiaotong

update Sum function for easy use

parent c3b9f35a
@@ -52,10 +52,15 @@ int main( int argc, const char ** argv )
         a.Set2D(1.0F, 0, 0);
         a.Set2D(1.0F, 1, 1);
-        b = Sum(a, a);
+        b = Sum(a, Sum(a, a));
+        XTensor c = b;
+        a.Dump(stderr, "a: ");
         b.Dump(stderr, "b: ");
     }
+    return 0;
     if(argc > 1 && !strcmp(argv[1], "-test"))
         Test();
     else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
......
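The hunk above is what the commit title refers to: because Sum() now takes const references (see the Sum.cpp and Sum.h hunks further down) and the copy constructor and operator= accept const XTensor&, the temporary returned by the inner Sum(a, a) can be fed straight into the outer call and the result can be copied into another tensor. A minimal usage sketch reusing the names from this hunk (InitTensor2D is assumed to exist alongside the InitTensor/InitTensor5D helpers shown later in this diff):

XTensor a, b;
InitTensor2D(&a, 2, 2);            /* assumed 2D initializer               */
InitTensor2D(&b, 2, 2);
a.Set2D(1.0F, 0, 0);
a.Set2D(1.0F, 1, 1);

b = Sum(a, Sum(a, a));             /* b = a + (a + a); the inner temporary
                                      binds to "const XTensor &"           */
XTensor c = b;                     /* copy construction from a const source */
b.Dump(stderr, "b: ");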
@@ -197,37 +197,16 @@ create a hyperedge with two input tensors and a output tensor
 >> h - head tensor
 >> id - id of the edge type
 */
-void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, int id)
+void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id)
 {
     if(h != NULL)
         return;
-    /* forward */
-    XLink &income = h->income;
-    income.Reset();
-    income.SetHead(h);
-    if(t1 != NULL && t2 != NULL)
-        income.AddTwoTails(t1, t2);
-    else if(t1 != NULL)
-        income.AddTail(t1);
-    else{
-        ShowNTErrors("TODO!");
-    }
-    income.SetType(id);
-    /* backward for t1 */
-    if(t1 != NULL){
-        XLink &outgo = t1->outgo;
-        CheckNTErrors(outgo.head != t1, "Wrong head of the hyperedge!");
-        outgo.AddTail(h);
-    }
-    /* backward for t2 */
-    if(t2 != NULL){
-        XLink &outgo = t2->outgo;
-        CheckNTErrors(outgo.head != t2, "Wrong head of the hyperedge!");
-        outgo.AddTail(h);
-    }
+    XList list(2);
+    list.Add(t1);
+    list.Add(t2);
+    MakeLink(&list, h, id);
 }
 /*
@@ -287,7 +266,7 @@ replace a node with another, i.e., we redirect the links to the new node
 >> oldOne - the node to be replaced
 >> newOne - the new node
 */
-void XLink::Replace(XTensor * oldOne, XTensor * newOne)
+void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
 {
     if(oldOne == NULL || newOne == NULL)
         return;
@@ -329,10 +308,6 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
             }
         }
     }
-    XLink &oldOutgo = oldOne->outgo;
-    ClearIncoming(oldOne);
-    oldOne->outgo.tailNum = 0;
 }
 } // namespace nts(NiuTrans.Tensor)
......
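With this rewrite the two-tensor MakeLink no longer maintains the income/outgo bookkeeping itself; it packs t1 and t2 into an XList (which is why XList::Add now accepts const pointers, next hunk) and forwards to the list-based MakeLink declared in XLink.h. Operation wrappers keep calling the same API. A sketch taken from the Sum hunk further down (MATH_SUM is the edge-type id of tensor summation, beta its scaling factor):

DTYPE beta = 1.0F;
XTensor c(&a);                          /* result node shaped like a        */
c.SetTMP();                             /* mark c as a temporary graph node */
_Sum(&a, &b, &c, beta);                 /* the actual computation           */
XLink::MakeLink(&a, &b, &c, MATH_SUM);  /* hyperedge {a, b} -> c            */
XLink::AddParamToHead(&c, beta);        /* remember beta on the head node   */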
@@ -116,7 +116,7 @@ struct XLink
 /* create a hyper edge with two input tensors and a output tensor */
 static
-void MakeLink(XTensor * t1, XTensor * t2, XTensor * h, int id);
+void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id);
 /* create a hyper edge with a list of input tensors and a output tensor */
 static
@@ -132,7 +132,7 @@ struct XLink
 /* replace a node with another, i.e., we redirect the links to the new node */
 static
-void Replace(XTensor * oldOne, XTensor * newOne);
+void Replace(const XTensor * oldOne, XTensor * newOne);
 };
 } // namespace nts(NiuTrans.Tensor)
......
@@ -111,7 +111,7 @@ void XList::Create(int myMaxNum, XMem * myMem)
 add an item into the list
 >> item - pointer to the item
 */
-void XList::Add(void * item)
+void XList::Add(const void * item)
 {
     if( count == maxNum ){
         void ** newItems;
@@ -126,7 +126,8 @@ void XList::Add(const void * item)
         maxNum = maxNum * 2 + 1;
     }
-    items[count++] = item;
+    MTYPE p = (MTYPE)item;
+    items[count++] = (MTYPE*)p;
 }
......
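Taking const void* keeps Add() const-correct for callers such as MakeLink above, but the backing items array is still a plain void**, so the const qualifier has to be dropped before storing the pointer. The hunk does this by round-tripping the pointer through MTYPE (used here as a pointer-sized integer, an assumption about the typedef in this codebase). An equivalent, more explicit sketch of the same step:

void XList::Add(const void * item)
{
    /* ... grow "items" exactly as in the original code ... */

    /* the const qualifier is dropped on purpose: the list only stores
       the pointer and never writes through it */
    items[count++] = const_cast<void *>(item);
}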
@@ -69,7 +69,7 @@ public:
 /* utilities */
 void Create(int myMaxNum, XMem * myMem);
-void Add(void * item);
+void Add(const void * item);
 void Add(void ** inputItems, int inputItemCount);
 void AddList(XList * l);
 void AddInt(int i);
......
@@ -81,6 +81,7 @@ constructor
 XTensor::XTensor()
 {
     memset(this, 0, sizeof(XTensor));
+    SetDataPointer();
     id = MakeTensorID();
     order = -1;
@@ -104,9 +105,10 @@ XTensor::XTensor()
 }
 /* constructor */
-XTensor::XTensor(XTensor * reference)
+XTensor::XTensor(const XTensor * reference)
 {
     memset(this, 0, sizeof(XTensor));
+    SetDataPointer();
     id = MakeTensorID();
     dataType = DEFAULT_DTYPE;
@@ -129,6 +131,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
 {
     CheckNTErrors((myOrder > 0), "Illegal tensor order1");
+    SetDataPointer();
     id = MakeTensorID();
     order = myOrder;
     memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
@@ -166,6 +169,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
 {
     CheckNTErrors((myOrder > 0), "Illegal tensor order1");
+    SetDataPointer();
     id = MakeTensorID();
     order = myOrder;
     memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
@@ -188,8 +192,9 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
 }
 /* copy constructor */
-XTensor::XTensor(XTensor &reference)
+XTensor::XTensor(const XTensor &reference)
 {
+    SetDataPointer();
     id = MakeTensorID();
     ShallowCopy(reference);
     isInit = false;
@@ -200,7 +205,13 @@ XTensor::XTensor(const XTensor &reference)
         devID = reference.devID;
         mem = reference.mem;
         data = reference.data;
-        reference.data = NULL;
+        /* what we really want to do is "reference.data = NULL;"
+           As "reference" is constant, we cannot reset reference.data
+           here. So we save the ADDRESS of reference.data in
+           reference.dataP, and do this work by updating "*reference.dataP".
+           This is VERY trick and might not be the best solution :) */
+        *reference.dataP = NULL;
     }
     else{
         DestroyData();
@@ -263,7 +274,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
 }
 /* overloading of the equal-sign */
-XTensor& XTensor::operator= (XTensor& tensor)
+XTensor& XTensor::operator= (const XTensor& tensor)
 {
     /* hard copy of data array */
     int size = unitNum * unitSize;
@@ -284,7 +295,6 @@ XTensor& XTensor::operator= (const XTensor& tensor)
     }
     Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
     CopyValues(&tensor, this);
 }
@@ -621,6 +631,12 @@ bool XTensor::CheckData(const void * d, int num, int beg)
     return true;
 }
+/* set the pointer to "data" */
+void XTensor::SetDataPointer()
+{
+    dataP = &data;
+}
 bool XTensor::CheckData(const void * d, int num, float tolerance, int beg)
 {
     if (data == NULL || d == NULL)
@@ -969,7 +985,7 @@ increase the value of a cell in a 2d tensor
 }
 /* get the number of non-zero elements (in a sparse tensor) */
-int XTensor::GetNonzeroSize()
+const int XTensor::GetNonzeroSize()
 {
     if(!isSparse){
         XPRINT(1, stderr, "WARNING! Counting non-zero elements in a dense tensor might be slow!\n");
@@ -1736,7 +1752,7 @@ initialize a tensor with a reference tensor
 >> tensor - the tensor we intend to initialize
 >> reference - the reference tensor
 */
-void InitTensor(XTensor * tensor, XTensor * reference)
+void InitTensor(XTensor * tensor, const XTensor * reference)
 {
     if(reference->order < 0)
         return;
......
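The SetDataPointer()/dataP additions above exist only to support the const copy constructor: every constructor records the address of its own data field (right after the memset, which would otherwise wipe dataP), and when a temporary is copied from, the source's buffer is detached through *reference.dataP because reference.data itself cannot be assigned through a const reference. A self-contained toy version of the pattern (names invented; only the mechanism matches the commit):

#include <cstddef>

struct Buf {
    void *  data;
    void ** dataP;                   /* always points at this->data         */

    Buf() : data(NULL) { dataP = &data; }

    Buf(const Buf & ref)
    {
        dataP = &data;
        data  = ref.data;            /* take over the source's allocation   */
        *ref.dataP = NULL;           /* detach it from the (const!) source  */
    }
};

In C++11 and later the same ownership transfer is normally written as a move constructor taking Buf&&, which avoids writing through a const object.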
@@ -56,11 +56,11 @@ struct XLink;
 #define FAST_MATRIX
 /*
-We implemente the tensor class here though we have defined the class of XMatrix. It
-is the parent class of XMatrix.
+We implemente the tensor class here.
 */
-struct XTensor
+class XTensor
 {
+public:
     /* id */
     int id;
@@ -74,6 +74,10 @@ struct XTensor
        when the matrix is operated on GPUs */
     void * dataHost;
+    /* a pointer to data (i.e., a pointer to the address of "data".
+       This is for reset "data" when XTensor is used as a const variable. */
+    void ** dataP;
     /*
     device id
     <0: CPU memory
@@ -158,7 +162,7 @@ struct XTensor
     XTensor();
     /* constructor */
-    XTensor(XTensor * reference);
+    XTensor(const XTensor * reference);
     /* constructor */
     XTensor(const int myOrder, int myDevID, XMem * myMem);
@@ -168,7 +172,7 @@ struct XTensor
               const float myDenseRatio, XMem * myMem);
     /* copy constructor */
-    XTensor(XTensor &reference);
+    XTensor(const XTensor &reference);
     /* de-constructor */
     ~XTensor();
@@ -180,7 +184,7 @@ struct XTensor
     void ShallowCopy(const XTensor &tensor);
     /* overloading of the equal-sign */
-    XTensor& operator= (XTensor &tensor);
+    XTensor& operator= (const XTensor &tensor);
     /* judge whether the two matrices are in the same type and size */
     static
@@ -226,6 +230,9 @@ struct XTensor
     /* check whether the data array is the same as the answer */
     bool CheckData(const void * answer, int num, float tolerance, int beg = 0);
+    /* set the pointer to "data" */
+    void SetDataPointer();
     /* set the cell to the ascending order along a given dimension */
     void SetAscendingOrder(int dim);
@@ -275,7 +282,7 @@ struct XTensor
     bool Add2D(DTYPE value, int ni, int mi);
     /* get the number of non-zero elements (in a sparse tensor) */
-    int GetNonzeroSize();
+    const int GetNonzeroSize();
     /* set the tensor as "temporary" */
     void SetTMP(bool myIsTmp = true);
@@ -350,7 +357,7 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co
                   const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
 /* initialize a tensor with a reference tensor */
-void InitTensor(XTensor * tensor, XTensor * reference);
+void InitTensor(XTensor * tensor, const XTensor * reference);
 /* generate a XTensor */
 XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType = X_FLOAT,
......
@@ -35,11 +35,8 @@ return a pointer
 >> c - where we put a+b*\beta. we save it in a if c is NULL
 >> beta - the scaling factor
 */
-void _Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
+void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
 {
-    if (c == NULL)
-        c = a;
     CheckNTErrors(a && b && c, "Empty tensors in addition!");
     CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
                   "Unmatched tensors in addition!");
@@ -121,7 +118,7 @@ do it on site
 >> b - another tensor
 >> beta - the scaling factor
 */
-void _SumMe(XTensor * a, XTensor * b, DTYPE beta)
+void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
 {
     _Sum(a, b, a, beta);
 }
@@ -133,7 +130,7 @@ return a XTensor structure
 >> b - another tensor
 >> beta - the scaling factor
 */
-XTensor Sum(XTensor &a, XTensor &b, DTYPE beta)
+XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
 {
     XTensor c(&a);
     c.SetTMP();
@@ -141,6 +138,8 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
     /* computation */
     _Sum(&a, &b, &c, beta);
+    c.Dump(stderr, "c: ");
     /* tensor connections */
     XLink::MakeLink(&a, &b, &c, MATH_SUM);
     XLink::AddParamToHead(&c, beta);
......
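After this hunk the three entry points divide the work cleanly: _Sum always needs an explicit output (the old "use a when c is NULL" fallback is gone, which is also what allows its inputs to be const), _SumMe keeps the in-place form by forwarding a as the output, and Sum builds a temporary result tensor and records the MATH_SUM hyperedge. Note that the added c.Dump(stderr, "c: ") prints every intermediate result to stderr. A short usage sketch (tensors assumed already initialized and shaped alike):

_Sum(&a, &b, &c);            /* c = a + b, output passed explicitly       */
_SumMe(&a, &b, 0.5F);        /* a = a + b * 0.5, in place                 */
XTensor d = Sum(a, b);       /* d = a + b as a TMP tensor, plus the       */
                             /* MATH_SUM link recorded on d               */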
@@ -27,13 +27,13 @@
 namespace nts { // namespace nts(NiuTrans.Tensor)
 /* tensor summation c = a + b * \beta */
-void _Sum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
+void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
 /* tensor summation a = a + b * \beta (return a pointer) */
-void _SumMe(XTensor * a, XTensor * b, DTYPE beta = (DTYPE)1.0);
+void _SumMe(XTensor * a, const XTensor * b, DTYPE beta = (DTYPE)1.0);
 /* tensor summation c = a + b * \beta (return a structure) */
-XTensor Sum(XTensor &a, XTensor &b, DTYPE beta = (DTYPE)1.0);
+XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
 } // namespace nts(NiuTrans.Tensor)
......
@@ -32,7 +32,7 @@ copy s to t
 >> stream - the stream for creating the job pipeline
 << return - succeeded or not
 */
-bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
+bool CopyValues(const XTensor * s, XTensor * t, XStream * stream)
 {
     if (s == NULL || t == NULL)
         return false;
@@ -60,10 +60,10 @@ bool CopyValues(const XTensor * s, XTensor * t, XStream * stream)
         memcpy((char*)t->data, (char*)s->data, s->unitSize * s->unitNum);
     }
     else if (s->isSparse && t->isSparse) {
-        int d = s->GetNonzeroSize();
+        int d = s->unitNumNonZero;
         t->Resize(s);
         t->unitNumNonZero = d;
-        memcpy((char*)t->data, (char*)s->data, sizeof(int) + d *(sizeof(int) + sizeof(DTYPE)));
+        memcpy((char*)t->data, (char*)s->data, sizeof(int) + d *(sizeof(int) + t->unitSize));
     }
     else {
         ShowNTErrors("TODO!");
......
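The sparse branch now reads the nonzero count straight from s->unitNumNonZero (s is const here, and GetNonzeroSize() is not declared as a const member in this version) and sizes the copy with t->unitSize instead of a hard-coded sizeof(DTYPE). The memcpy size spells out the sparse layout this code assumes: one leading int holding the nonzero count, followed by d (index, value) pairs. A sketch of the arithmetic:

int    d     = s->unitNumNonZero;                    /* number of stored nonzeros */
size_t bytes = sizeof(int)                           /* leading count field       */
             + d * (sizeof(int) + t->unitSize);      /* d pairs of (index, value)  */
memcpy((char*)t->data, (char*)s->data, bytes);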
@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 /* copy s to t */
 extern "C"
-bool CopyValues(XTensor * s, XTensor * t, XStream * stream = NULL);
+bool CopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
 } // namespace nts(NiuTrans.Tensor)
......
@@ -55,7 +55,7 @@ bool TestSum1()
     b->SetData(bData, unitNum);
     /* call sum function */
-    _Sum(a, b);
+    _Sum(a, b, a);
     /* check results */
     cpuTest = a->CheckData(answer, unitNum);
......
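Since _Sum no longer falls back to a when c is omitted (the default argument was removed in Sum.h above), the unit test has to name the output tensor explicitly; the in-place wrapper reads the same way. A sketch of the two equivalent forms:

_Sum(a, b, a);    /* explicit output, as in the updated test            */
_SumMe(a, b);     /* equivalent: _SumMe forwards a as the output tensor */
cpuTest = a->CheckData(answer, unitNum);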