Commit 8cae67f2 by xiaotong

1) fix the bug of unexpected breakpoints in tensor graph 2) add devID to XTensor…

1) fix the bug of unexpected breakpoints in tensor graph 2) add devID to XTensor constructor 3) avoid unneccessary data allocation in XLink
parent f8ddb15b
......@@ -45,8 +45,9 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78);
/* a tiny test */
if(1)
if(true)
SmallTest();
return 0;
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
......@@ -68,25 +69,31 @@ void SmallTest()
{
XTensor a;
XTensor b;
XTensor c;
XTensor d;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 1, 1);
b = Sum(a, Multiply(a, a));
XTensor c = a * b + a;
int nnn = 1;
//b = Sum(a, Multiply(a, a));
XTensor d = a + b + c.Lin(0.5F);
/* cannot write this !!!!!!!!!!!!! */
//XTensor c = a * b + a;
//XTensor d = a + b + c.Lin(0.5F);
c = a * b + a;
d = a + b + c.Lin(0.5F);
//d = a + b * b;
XLink::CheckNetwork(&d);
XLink::ShowNetwork(stderr, &b);
XLink::ShowNetwork(stderr, &d);
a.Dump(stderr, "a: ");
b.Dump(stderr, "b: ");
c.Dump(stderr, "c: ");
d.Dump(stderr, "d: ");
a.Dump(stderr, "a:");
b.Dump(stderr, "b:");
c.Dump(stderr, "c:");
d.Dump(stderr, "d:");
}
......@@ -75,6 +75,39 @@ void XLink::ClearTail()
}
/*
clear the outgoing node list of tensor node
>> node - the node to be cleared
*/
void XLink::ClearOutgoing(XTensor * node)
{
if(node == NULL)
return;
XLink &outgo = node->outgo;
for(int i = 0; i < outgo.tailNum; i++){
/* for each parent node */
XTensor * parent = outgo.tails[i];
XLink &parentIncome = parent->income;
CheckNTErrors(parentIncome.tailNum > 0, "The node must have incoming edges!");
/* we check for each parent node and remove the link to current node */
for(int j = 0; j < parentIncome.tailNum; j++){
if(parentIncome.tails[j] == node){
memcpy(parentIncome.tails + j, parentIncome.tails + j + 1,
sizeof(XTensor*) * (parentIncome.tailNum - 1 - j));
parentIncome.tailNum--;
break;
}
}
}
outgo.ClearTail();
}
/*
clear the incoming node list of tensor node
>> node - the node to be cleared
*/
......@@ -87,7 +120,7 @@ void XLink::ClearIncoming(XTensor * node)
for(int i = 0; i < income.tailNum; i++){
/* for a incoming node */
/* for each incoming node */
XTensor * child = income.tails[i];
XLink &childOutgo = child->outgo;
......@@ -96,9 +129,8 @@ void XLink::ClearIncoming(XTensor * node)
/* we check for each child node and remove the link to current node */
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == node){
memcpy(childOutgo.tails + j,
childOutgo.tails + j + 1,
(childOutgo.tailNum - 1 - j) * sizeof(XTensor*));
memcpy(childOutgo.tails + j, childOutgo.tails + j + 1,
sizeof(XTensor*) * (childOutgo.tailNum - 1 - j));
childOutgo.tailNum--;
break;
}
......@@ -109,7 +141,6 @@ void XLink::ClearIncoming(XTensor * node)
}
income.ClearTail();
income.tailNum = 0;
}
/*
......@@ -239,6 +270,7 @@ void XLink::MakeLink(XList * list, XTensor * h, int id)
XLink &outgo = t->outgo;
CheckNTErrors(outgo.head == NULL || outgo.head == t,
"Wrong head of the hyperedge!");
outgo.SetHead(t);
outgo.AddTail(h);
}
}
......@@ -276,17 +308,24 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
{
if(oldOne == NULL || newOne == NULL)
return;
XLink &newIncome = newOne->income;
XLink &newOutgo = newOne->outgo;
int incomeTailNum = newIncome.tailNum;
int outgoTailNum = newOutgo.tailNum;
XLink::ClearOutgoing(newOne);
XLink::ClearIncoming(newOne);
XLink &newIncome = newOne->income;
XLink &newOutgo = newOne->outgo;
delete[] newIncome.tails;
if(incomeTailNum < oldOne->income.tailNum){
delete[] newIncome.tails;
newIncome.tails = new XTensor*[newIncome.tailNum];
}
/* incoming nodes for the new node */
newIncome.SetType(oldOne->income.typeID);
newIncome.head = newOne;
newIncome.tailNum = oldOne->income.tailNum;
newIncome.tails = new XTensor*[newIncome.tailNum];
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
/* update the link to each child node */
......@@ -306,11 +345,15 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
CheckNTErrors(hit, "No proper node found in child.outgo edge!");
}
}
if(outgoTailNum < oldOne->outgo.tailNum){
delete[] newOutgo.tails;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
}
/* outgoing nodes for the new node */
newOutgo.head = newOne;
newOutgo.tailNum = oldOne->outgo.tailNum;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
memcpy(newOutgo.tails, oldOne->outgo.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */
......@@ -385,7 +428,6 @@ void XLink::CheckNetwork(XTensor * root)
}
CheckNTErrors(hit, "Wrong outgoing edge!");
}
}
XLink &outgo = root->outgo;
......@@ -397,15 +439,15 @@ void XLink::CheckNetwork(XTensor * root)
XTensor * parent = outgo.tails[i];
if(parent == NULL)
continue;
XLink & parentOutgo = parent->outgo;
XLink & parentIncome = parent->income;
bool hit = false;
for(int j = 0; j < parentOutgo.tailNum; j++){
if(parentOutgo.tails[j] == root){
for(int j = 0; j < parentIncome.tailNum; j++){
if(parentIncome.tails[j] == root){
hit = true;
break;
}
}
CheckNTErrors(hit, "Wrong outgoing edge!");
CheckNTErrors(hit, "Wrong incoming edge!");
}
}
......@@ -429,7 +471,7 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
fprintf(file, "income[%d]: null ", income.tailNum);
}
else{
fprintf(file, "income[%d]: ", income.tailNum);
fprintf(file, "income[%d, %s]: ", income.tailNum, GetOPName(income.typeID));
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
......@@ -438,13 +480,14 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
fprintf(file, "%d ", child->id);
}
}
fprintf(stderr, ", ");
XLink &outgo = root->outgo;
if(outgo.head == NULL){
if(outgo.head == NULL || outgo.tailNum == 0){
fprintf(file, "outgo[%d]: null ", outgo.tailNum);
}
else{
fprintf(file, "outgo[%d]: ", income.tailNum);
fprintf(file, "outgo[%d]: ", outgo.tailNum);
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = outgo.tails[i];
if(parent == NULL)
......
......@@ -95,6 +95,10 @@ struct XLink
/* clear the incoming node list of tensor node */
static
void ClearIncoming(XTensor * node);
/* clear the outgoing node list of tensor node */
static
void ClearOutgoing(XTensor * node);
/* set edge type id and name */
void SetType(int id);
......
......@@ -28,7 +28,7 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ARITHMETIC 10000
#define MATH_ARITHMETIC 0x00001000
#define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_MULTIPLY MATH_SUM + 1
#define MATH_SCALEANDSHIFT MATH_MULTIPLY + 1
......
......@@ -127,7 +127,7 @@ XTensor::XTensor(const XTensor * reference)
/*
constructor
>> myOrder - order of the tensor
>> myDevID - prefered device id
>> myDevID - device id
>> myMem - memory pool used to allocating the data array
*/
XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
......@@ -165,10 +165,11 @@ constructor
>> myDimSize - the size of each dimension
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
>> myDevID - device id
>> myMem - memory pool used to allocating the data array
*/
XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem)
const float myDenseRatio, int myDevID, XMem * myMem)
{
CheckNTErrors((myOrder > 0), "Illegal tensor order1");
......@@ -180,6 +181,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
mem = myMem;
devID = myMem != NULL ? myMem->devID : myDevID;
data = NULL;
dataHost = NULL;
dataType = DEFAULT_DTYPE;
......@@ -229,18 +231,33 @@ XTensor::XTensor(const XTensor &reference)
XLink::CopyIncoming(&reference, this);
}
isInit = false;
isTmp = false;
isInit = true;
isTmp = reference.isTmp;
}
/* de-constructor */
XTensor::~XTensor()
{
DestroyData();
data = NULL;
dataHost = NULL;
mem = NULL;
/* We make a hard copy of the tensor to keep
the connectivity of the graph. To kill memory
leak, we release the data of the new tensor
when its parent is deleted (see ClearIncoming). */
if(isTmp && outgo.tailNum > 0){
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, dimSize, order * sizeof(int));
dims[0] = -dims[0];
XTensor * newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->data = data;
data = NULL;
XLink::Replace(this, newTensor);
}
XLink::ClearOutgoing(this);
XLink::ClearIncoming(this);
DestroyData();
}
/* delete data arrays */
......@@ -284,7 +301,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor)
{
/* hard copy of data array */
/* hard copy of the data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
......@@ -1343,16 +1360,25 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int verbo
if (label != NULL)
fprintf(file, "%s ", label);
fprintf(file, "order=%d dimsize=", order);
for (int i = 0; i < order; i++) {
fprintf(file, "%d", dimSize[i]);
if (i < order - 1)
fprintf(file, ",");
if(isInit){
fprintf(file, "order=%d dimsize=", order);
for (int i = 0; i < order; i++) {
fprintf(file, "%d", dimSize[i]);
if (i < order - 1)
fprintf(file, ",");
}
}
else{
fprintf(file, "order=-1 dimsize=-1");
}
fprintf(file, " dtype=%s dense=%f\n", GetDataTypeName(dataType), denseRatio);
if(!isInit){
fprintf(file, "NULL");
}
if (!isSparse) {
if (dataType == DEFAULT_DTYPE) {
if (unitNum > 0) {
......@@ -1811,7 +1837,7 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio, const int myDevID, XMem * myMem)
{
if(myMem != NULL)
return new XTensor(myOrder, myDimSize, myDataType, myDenseRatio, myMem);
return new XTensor(myOrder, myDimSize, myDataType, myDenseRatio, myDevID, myMem);
else{
XTensor * tensor = new XTensor();
InitTensor(tensor, myOrder, myDimSize, myDataType, myDenseRatio, myDevID, myMem);
......@@ -1982,7 +2008,9 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
if(!isFilledData)
dims[0] = -dims[0];
XTensor * newTensor = new XTensor(a->order, dims, a->dataType, a->denseRatio, a->mem);
XTensor * newTensor = new XTensor(a->order, dims,
a->dataType, a->denseRatio,
a->devID, a->mem);
delete[] dims;
......
......@@ -167,7 +167,7 @@ public:
/* constructor */
XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem);
const float myDenseRatio, int myDevID, XMem * myMem);
/* copy constructor */
XTensor(const XTensor &reference);
......
......@@ -89,9 +89,9 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
void * ap = (char*)a->data + aRealBlockSize * p;
void * bp = (char*)b->data + bRealBlockSize * p;
void * cp = (char*)c->data + cRealBlockSize * p;
XTensor * ai = new XTensor(2, aDimSize, a->dataType, a->denseRatio, a->mem);
XTensor * bi = new XTensor(2, bDimSize, b->dataType, b->denseRatio, b->mem);
XTensor * ci = new XTensor(2, cDimSize, c->dataType, c->denseRatio, c->mem);
XTensor * ai = new XTensor(2, aDimSize, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * bi = new XTensor(2, bDimSize, b->dataType, b->denseRatio, b->devID, b->mem);
XTensor * ci = new XTensor(2, cDimSize, c->dataType, c->denseRatio, c->devID, c->mem);
ai->data = ap;
bi->data = bp;
ci->data = cp;
......
......@@ -220,7 +220,9 @@ void Merge(XList * smalls, XTensor * big, int whereToMerge)
dimSizeTMP[smallsItem0->order] = -mergeNum;
XMem * mem = smallsItem0->mem;
XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP, smallsItem0->dataType, smallsItem0->denseRatio, mem);
XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP,
smallsItem0->dataType, smallsItem0->denseRatio,
smallsItem0->devID, mem);
int size = mergeNum * itemSize;
void * dataTMP = NULL;
......
......@@ -197,7 +197,7 @@ void Split(XTensor * big, XList * smalls, int whereToSplit, int splitNum)
dimSizeTMP[big->order] = -splitNum;
XMem * mem = big->mem;
XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP, big->dataType, big->denseRatio, mem);
XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP, big->dataType, big->denseRatio, big->devID, mem);
int size = big->unitNum * big->unitSize;
void * dataTMP = NULL;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论