Commit 8cae67f2 by xiaotong

1) fix the bug of unexpected breakpoints in tensor graph 2) add devID to XTensor…

1) fix the bug of unexpected breakpoints in tensor graph 2) add devID to XTensor constructor 3) avoid unnecessary data allocation in XLink
parent f8ddb15b
...@@ -45,8 +45,9 @@ int main( int argc, const char ** argv ) ...@@ -45,8 +45,9 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78); //_CrtSetBreakAlloc(78);
/* a tiny test */ /* a tiny test */
if(1) if(true)
SmallTest(); SmallTest();
return 0;
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
...@@ -68,25 +69,31 @@ void SmallTest() ...@@ -68,25 +69,31 @@ void SmallTest()
{ {
XTensor a; XTensor a;
XTensor b; XTensor b;
XTensor c;
XTensor d;
InitTensor2D(&a, 2, 2); InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
a.SetZeroAll(); a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0); a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 1, 1); a.Set2D(2.0F, 1, 1);
b = Sum(a, Multiply(a, a)); //b = Sum(a, Multiply(a, a));
XTensor c = a * b + a;
int nnn = 1; /* cannot write this !!!!!!!!!!!!! */
//XTensor c = a * b + a;
//XTensor d = a + b + c.Lin(0.5F);
XTensor d = a + b + c.Lin(0.5F); c = a * b + a;
d = a + b + c.Lin(0.5F);
//d = a + b * b;
XLink::CheckNetwork(&d); XLink::CheckNetwork(&d);
XLink::ShowNetwork(stderr, &b); XLink::ShowNetwork(stderr, &d);
a.Dump(stderr, "a: "); a.Dump(stderr, "a:");
b.Dump(stderr, "b: "); b.Dump(stderr, "b:");
c.Dump(stderr, "c: "); c.Dump(stderr, "c:");
d.Dump(stderr, "d: "); d.Dump(stderr, "d:");
} }
...@@ -75,6 +75,39 @@ void XLink::ClearTail() ...@@ -75,6 +75,39 @@ void XLink::ClearTail()
} }
/* /*
clear the outgoing node list of tensor node
>> node - the node to be cleared
*/
void XLink::ClearOutgoing(XTensor * node)
{
    /* nothing to clear for a missing node */
    if(node == NULL)
        return;

    XLink &outgo = node->outgo;

    for(int i = 0; i < outgo.tailNum; i++){

        /* for each parent node */
        XTensor * parent = outgo.tails[i];
        XLink &parentIncome = parent->income;

        CheckNTErrors(parentIncome.tailNum > 0, "The node must have incoming edges!");

        /* we check for each parent node and remove the link to current node */
        for(int j = 0; j < parentIncome.tailNum; j++){
            if(parentIncome.tails[j] == node){
                /* Close the gap by shifting the remaining entries down by one slot.
                   The source and destination ranges overlap, so memmove must be
                   used here; memcpy has undefined behavior on overlapping memory. */
                memmove(parentIncome.tails + j, parentIncome.tails + j + 1,
                        sizeof(XTensor*) * (parentIncome.tailNum - 1 - j));
                parentIncome.tailNum--;

                /* only the first matching entry is removed per parent */
                break;
            }
        }
    }

    /* finally drop the node's own list of outgoing tails */
    outgo.ClearTail();
}
/*
clear the incoming node list of tensor node clear the incoming node list of tensor node
>> node - the node to be cleared >> node - the node to be cleared
*/ */
...@@ -87,7 +120,7 @@ void XLink::ClearIncoming(XTensor * node) ...@@ -87,7 +120,7 @@ void XLink::ClearIncoming(XTensor * node)
for(int i = 0; i < income.tailNum; i++){ for(int i = 0; i < income.tailNum; i++){
/* for a incoming node */ /* for each incoming node */
XTensor * child = income.tails[i]; XTensor * child = income.tails[i];
XLink &childOutgo = child->outgo; XLink &childOutgo = child->outgo;
...@@ -96,9 +129,8 @@ void XLink::ClearIncoming(XTensor * node) ...@@ -96,9 +129,8 @@ void XLink::ClearIncoming(XTensor * node)
/* we check for each child node and remove the link to current node */ /* we check for each child node and remove the link to current node */
for(int j = 0; j < childOutgo.tailNum; j++){ for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == node){ if(childOutgo.tails[j] == node){
memcpy(childOutgo.tails + j, memcpy(childOutgo.tails + j, childOutgo.tails + j + 1,
childOutgo.tails + j + 1, sizeof(XTensor*) * (childOutgo.tailNum - 1 - j));
(childOutgo.tailNum - 1 - j) * sizeof(XTensor*));
childOutgo.tailNum--; childOutgo.tailNum--;
break; break;
} }
...@@ -109,7 +141,6 @@ void XLink::ClearIncoming(XTensor * node) ...@@ -109,7 +141,6 @@ void XLink::ClearIncoming(XTensor * node)
} }
income.ClearTail(); income.ClearTail();
income.tailNum = 0;
} }
/* /*
...@@ -239,6 +270,7 @@ void XLink::MakeLink(XList * list, XTensor * h, int id) ...@@ -239,6 +270,7 @@ void XLink::MakeLink(XList * list, XTensor * h, int id)
XLink &outgo = t->outgo; XLink &outgo = t->outgo;
CheckNTErrors(outgo.head == NULL || outgo.head == t, CheckNTErrors(outgo.head == NULL || outgo.head == t,
"Wrong head of the hyperedge!"); "Wrong head of the hyperedge!");
outgo.SetHead(t);
outgo.AddTail(h); outgo.AddTail(h);
} }
} }
...@@ -277,16 +309,23 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -277,16 +309,23 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
if(oldOne == NULL || newOne == NULL) if(oldOne == NULL || newOne == NULL)
return; return;
XLink::ClearIncoming(newOne);
XLink &newIncome = newOne->income; XLink &newIncome = newOne->income;
XLink &newOutgo = newOne->outgo; XLink &newOutgo = newOne->outgo;
int incomeTailNum = newIncome.tailNum;
int outgoTailNum = newOutgo.tailNum;
XLink::ClearOutgoing(newOne);
XLink::ClearIncoming(newOne);
if(incomeTailNum < oldOne->income.tailNum){
delete[] newIncome.tails; delete[] newIncome.tails;
newIncome.tails = new XTensor*[newIncome.tailNum];
}
/* incoming nodes for the new node */ /* incoming nodes for the new node */
newIncome.SetType(oldOne->income.typeID);
newIncome.head = newOne; newIncome.head = newOne;
newIncome.tailNum = oldOne->income.tailNum; newIncome.tailNum = oldOne->income.tailNum;
newIncome.tails = new XTensor*[newIncome.tailNum];
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum); memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
/* update the link to each child node */ /* update the link to each child node */
...@@ -307,10 +346,14 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -307,10 +346,14 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
} }
} }
if(outgoTailNum < oldOne->outgo.tailNum){
delete[] newOutgo.tails;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
}
/* outgoing nodes for the new node */ /* outgoing nodes for the new node */
newOutgo.head = newOne; newOutgo.head = newOne;
newOutgo.tailNum = oldOne->outgo.tailNum; newOutgo.tailNum = oldOne->outgo.tailNum;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
memcpy(newOutgo.tails, oldOne->outgo.tails, sizeof(XTensor*) * newOutgo.tailNum); memcpy(newOutgo.tails, oldOne->outgo.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */ /* update the link to each parent node */
...@@ -385,7 +428,6 @@ void XLink::CheckNetwork(XTensor * root) ...@@ -385,7 +428,6 @@ void XLink::CheckNetwork(XTensor * root)
} }
CheckNTErrors(hit, "Wrong outgoing edge!"); CheckNTErrors(hit, "Wrong outgoing edge!");
} }
} }
XLink &outgo = root->outgo; XLink &outgo = root->outgo;
...@@ -397,15 +439,15 @@ void XLink::CheckNetwork(XTensor * root) ...@@ -397,15 +439,15 @@ void XLink::CheckNetwork(XTensor * root)
XTensor * parent = outgo.tails[i]; XTensor * parent = outgo.tails[i];
if(parent == NULL) if(parent == NULL)
continue; continue;
XLink & parentOutgo = parent->outgo; XLink & parentIncome = parent->income;
bool hit = false; bool hit = false;
for(int j = 0; j < parentOutgo.tailNum; j++){ for(int j = 0; j < parentIncome.tailNum; j++){
if(parentOutgo.tails[j] == root){ if(parentIncome.tails[j] == root){
hit = true; hit = true;
break; break;
} }
} }
CheckNTErrors(hit, "Wrong outgoing edge!"); CheckNTErrors(hit, "Wrong incoming edge!");
} }
} }
...@@ -429,7 +471,7 @@ void XLink::ShowNetwork(FILE * file, XTensor * root) ...@@ -429,7 +471,7 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
fprintf(file, "income[%d]: null ", income.tailNum); fprintf(file, "income[%d]: null ", income.tailNum);
} }
else{ else{
fprintf(file, "income[%d]: ", income.tailNum); fprintf(file, "income[%d, %s]: ", income.tailNum, GetOPName(income.typeID));
for(int i = 0; i < income.tailNum; i++){ for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i]; XTensor * child = income.tails[i];
if(child == NULL) if(child == NULL)
...@@ -438,13 +480,14 @@ void XLink::ShowNetwork(FILE * file, XTensor * root) ...@@ -438,13 +480,14 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
fprintf(file, "%d ", child->id); fprintf(file, "%d ", child->id);
} }
} }
fprintf(stderr, ", ");
XLink &outgo = root->outgo; XLink &outgo = root->outgo;
if(outgo.head == NULL){ if(outgo.head == NULL || outgo.tailNum == 0){
fprintf(file, "outgo[%d]: null ", outgo.tailNum); fprintf(file, "outgo[%d]: null ", outgo.tailNum);
} }
else{ else{
fprintf(file, "outgo[%d]: ", income.tailNum); fprintf(file, "outgo[%d]: ", outgo.tailNum);
for(int i = 0; i < outgo.tailNum; i++){ for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = outgo.tails[i]; XTensor * parent = outgo.tails[i];
if(parent == NULL) if(parent == NULL)
......
...@@ -96,6 +96,10 @@ struct XLink ...@@ -96,6 +96,10 @@ struct XLink
static static
void ClearIncoming(XTensor * node); void ClearIncoming(XTensor * node);
/* clear the outgoing node list of tensor node */
static
void ClearOutgoing(XTensor * node);
/* set edge type id and name */ /* set edge type id and name */
void SetType(int id); void SetType(int id);
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ARITHMETIC 10000 #define MATH_ARITHMETIC 0x00001000
#define MATH_SUM MATH_ARITHMETIC + 1 #define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_MULTIPLY MATH_SUM + 1 #define MATH_MULTIPLY MATH_SUM + 1
#define MATH_SCALEANDSHIFT MATH_MULTIPLY + 1 #define MATH_SCALEANDSHIFT MATH_MULTIPLY + 1
......
...@@ -127,7 +127,7 @@ XTensor::XTensor(const XTensor * reference) ...@@ -127,7 +127,7 @@ XTensor::XTensor(const XTensor * reference)
/* /*
constructor constructor
>> myOrder - order of the tensor >> myOrder - order of the tensor
>> myDevID - prefered device id >> myDevID - device id
>> myMem - memory pool used to allocating the data array >> myMem - memory pool used to allocating the data array
*/ */
XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem) XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
...@@ -165,10 +165,11 @@ constructor ...@@ -165,10 +165,11 @@ constructor
>> myDimSize - the size of each dimension >> myDimSize - the size of each dimension
>> myDataType - unit size (e.g., int, float, and double) >> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value >> myDenseRatio - how often an element has non-zero value
>> myDevID - device id
>> myMem - memory pool used to allocating the data array >> myMem - memory pool used to allocating the data array
*/ */
XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType, XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem) const float myDenseRatio, int myDevID, XMem * myMem)
{ {
CheckNTErrors((myOrder > 0), "Illegal tensor order1"); CheckNTErrors((myOrder > 0), "Illegal tensor order1");
...@@ -180,6 +181,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP ...@@ -180,6 +181,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM); memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
mem = myMem; mem = myMem;
devID = myMem != NULL ? myMem->devID : myDevID;
data = NULL; data = NULL;
dataHost = NULL; dataHost = NULL;
dataType = DEFAULT_DTYPE; dataType = DEFAULT_DTYPE;
...@@ -229,18 +231,33 @@ XTensor::XTensor(const XTensor &reference) ...@@ -229,18 +231,33 @@ XTensor::XTensor(const XTensor &reference)
XLink::CopyIncoming(&reference, this); XLink::CopyIncoming(&reference, this);
} }
isInit = false; isInit = true;
isTmp = false; isTmp = reference.isTmp;
} }
/* de-constructor */ /* de-constructor */
XTensor::~XTensor() XTensor::~XTensor()
{ {
DestroyData(); /* We make a hard copy of the tensor to keep
the connectivity of the graph. To kill memory
leak, we release the data of the new tensor
when its parent is deleted (see ClearIncoming). */
if(isTmp && outgo.tailNum > 0){
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, dimSize, order * sizeof(int));
dims[0] = -dims[0];
XTensor * newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->data = data;
data = NULL; data = NULL;
dataHost = NULL;
mem = NULL; XLink::Replace(this, newTensor);
}
XLink::ClearOutgoing(this);
XLink::ClearIncoming(this); XLink::ClearIncoming(this);
DestroyData();
} }
/* delete data arrays */ /* delete data arrays */
...@@ -284,7 +301,7 @@ void XTensor::ShallowCopy(const XTensor &tensor) ...@@ -284,7 +301,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor) XTensor& XTensor::operator= (const XTensor& tensor)
{ {
/* hard copy of data array */ /* hard copy of the data array */
int size = unitNum * unitSize; int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse && if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize && size == tensor.unitNum * tensor.unitSize &&
...@@ -1343,16 +1360,25 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int verbo ...@@ -1343,16 +1360,25 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int verbo
if (label != NULL) if (label != NULL)
fprintf(file, "%s ", label); fprintf(file, "%s ", label);
if(isInit){
fprintf(file, "order=%d dimsize=", order); fprintf(file, "order=%d dimsize=", order);
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
fprintf(file, "%d", dimSize[i]); fprintf(file, "%d", dimSize[i]);
if (i < order - 1) if (i < order - 1)
fprintf(file, ","); fprintf(file, ",");
} }
}
else{
fprintf(file, "order=-1 dimsize=-1");
}
fprintf(file, " dtype=%s dense=%f\n", GetDataTypeName(dataType), denseRatio); fprintf(file, " dtype=%s dense=%f\n", GetDataTypeName(dataType), denseRatio);
if(!isInit){
fprintf(file, "NULL");
}
if (!isSparse) { if (!isSparse) {
if (dataType == DEFAULT_DTYPE) { if (dataType == DEFAULT_DTYPE) {
if (unitNum > 0) { if (unitNum > 0) {
...@@ -1811,7 +1837,7 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_ ...@@ -1811,7 +1837,7 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio, const int myDevID, XMem * myMem) const float myDenseRatio, const int myDevID, XMem * myMem)
{ {
if(myMem != NULL) if(myMem != NULL)
return new XTensor(myOrder, myDimSize, myDataType, myDenseRatio, myMem); return new XTensor(myOrder, myDimSize, myDataType, myDenseRatio, myDevID, myMem);
else{ else{
XTensor * tensor = new XTensor(); XTensor * tensor = new XTensor();
InitTensor(tensor, myOrder, myDimSize, myDataType, myDenseRatio, myDevID, myMem); InitTensor(tensor, myOrder, myDimSize, myDataType, myDenseRatio, myDevID, myMem);
...@@ -1982,7 +2008,9 @@ XTensor * NewTensor(XTensor * a, bool isFilledData) ...@@ -1982,7 +2008,9 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
if(!isFilledData) if(!isFilledData)
dims[0] = -dims[0]; dims[0] = -dims[0];
XTensor * newTensor = new XTensor(a->order, dims, a->dataType, a->denseRatio, a->mem); XTensor * newTensor = new XTensor(a->order, dims,
a->dataType, a->denseRatio,
a->devID, a->mem);
delete[] dims; delete[] dims;
......
...@@ -167,7 +167,7 @@ public: ...@@ -167,7 +167,7 @@ public:
/* constructor */ /* constructor */
XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType, XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem); const float myDenseRatio, int myDevID, XMem * myMem);
/* copy constructor */ /* copy constructor */
XTensor(const XTensor &reference); XTensor(const XTensor &reference);
......
...@@ -89,9 +89,9 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -89,9 +89,9 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
void * ap = (char*)a->data + aRealBlockSize * p; void * ap = (char*)a->data + aRealBlockSize * p;
void * bp = (char*)b->data + bRealBlockSize * p; void * bp = (char*)b->data + bRealBlockSize * p;
void * cp = (char*)c->data + cRealBlockSize * p; void * cp = (char*)c->data + cRealBlockSize * p;
XTensor * ai = new XTensor(2, aDimSize, a->dataType, a->denseRatio, a->mem); XTensor * ai = new XTensor(2, aDimSize, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * bi = new XTensor(2, bDimSize, b->dataType, b->denseRatio, b->mem); XTensor * bi = new XTensor(2, bDimSize, b->dataType, b->denseRatio, b->devID, b->mem);
XTensor * ci = new XTensor(2, cDimSize, c->dataType, c->denseRatio, c->mem); XTensor * ci = new XTensor(2, cDimSize, c->dataType, c->denseRatio, c->devID, c->mem);
ai->data = ap; ai->data = ap;
bi->data = bp; bi->data = bp;
ci->data = cp; ci->data = cp;
......
...@@ -220,7 +220,9 @@ void Merge(XList * smalls, XTensor * big, int whereToMerge) ...@@ -220,7 +220,9 @@ void Merge(XList * smalls, XTensor * big, int whereToMerge)
dimSizeTMP[smallsItem0->order] = -mergeNum; dimSizeTMP[smallsItem0->order] = -mergeNum;
XMem * mem = smallsItem0->mem; XMem * mem = smallsItem0->mem;
XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP, smallsItem0->dataType, smallsItem0->denseRatio, mem); XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP,
smallsItem0->dataType, smallsItem0->denseRatio,
smallsItem0->devID, mem);
int size = mergeNum * itemSize; int size = mergeNum * itemSize;
void * dataTMP = NULL; void * dataTMP = NULL;
......
...@@ -197,7 +197,7 @@ void Split(XTensor * big, XList * smalls, int whereToSplit, int splitNum) ...@@ -197,7 +197,7 @@ void Split(XTensor * big, XList * smalls, int whereToSplit, int splitNum)
dimSizeTMP[big->order] = -splitNum; dimSizeTMP[big->order] = -splitNum;
XMem * mem = big->mem; XMem * mem = big->mem;
XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP, big->dataType, big->denseRatio, mem); XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP, big->dataType, big->denseRatio, big->devID, mem);
int size = big->unitNum * big->unitSize; int size = big->unitNum * big->unitSize;
void * dataTMP = NULL; void * dataTMP = NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论