Commit 4acc236a by xiaotong

new code for XLink

parent 85ee1664
...@@ -56,7 +56,59 @@ void XLink::Reset() ...@@ -56,7 +56,59 @@ void XLink::Reset()
params = NULL; params = NULL;
tailNum = 0; tailNum = 0;
paramNum = 0; paramNum = 0;
type[0] = 0; type[0] = 0;
}
/* clear it */
void XLink::Clear()
{
head = NULL;
tailNum = 0;
paramNum = 0;
type[0] = 0;
}
/* reset tails */
void XLink::ClearTail()
{
tailNum = 0;
}
/*
clear the incoming node list of tensor node
>> node - the node to be cleared
*/
void XLink::ClearIncoming(XTensor * node)
{
if(node == NULL)
return;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
/* for a incoming node */
XTensor * child = income.tails[i];
XLink &childOutgo = child->outgo;
CheckNTErrors(childOutgo.tailNum > 0, "The node must have outgoing edges!");
/* we check for each child node and remove the link to current node */
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == node){
memcpy(childOutgo.tails + j,
childOutgo.tails + j + 1,
(childOutgo.tailNum - 1 - j) * sizeof(XTensor*));
childOutgo.tailNum--;
}
}
if(childOutgo.tailNum == 0)
delete child;
}
income.ClearTail();
income.tailNum = 0;
} }
/* /*
...@@ -229,6 +281,59 @@ void XLink::AddParamToHeadInt(XTensor * h, int param) ...@@ -229,6 +281,59 @@ void XLink::AddParamToHeadInt(XTensor * h, int param)
return; return;
h->income.AddParam(&param, sizeof(int)); h->income.AddParam(&param, sizeof(int));
} }
/*
replace a node with another, i.e., we redirect the links to the new node
>> oldOne - the node to be replaced
>> newOne - the new node
*/
void XLink::Replace(XTensor * oldOne, XTensor * newOne)
{
if(oldOne == NULL || newOne == NULL)
return;
XLink::ClearIncoming(newOne);
XLink &newIncome = newOne->income;
XLink &newOutgo = newOne->outgo;
delete[] newIncome.tails;
/* incoming nodes for the new node */
newIncome.tailNum = oldOne->income.tailNum;
newIncome.tails = new XTensor*[newIncome.tailNum];
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
/* update the link to each child node */
for(int i = 0; i < newIncome.tailNum; i++){
XTensor * child = newIncome.tails[i];
XLink &childOutgo = child->outgo;
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == oldOne){
childOutgo.tails[j] = newOne;
}
}
}
/* outgoing nodes for the new node */
newOutgo.tailNum = oldOne->income.tailNum;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
memcpy(newOutgo.tails, oldOne->income.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */
for(int i = 0; i < newOutgo.tailNum; i++){
XTensor * parent = newOutgo.tails[i];
XLink &parentIncome = parent->income;
for(int j = 0; j < parentIncome.tailNum; j++){
if(parentIncome.tails[j] == oldOne){
parentIncome.tails[j] = newOne;
}
}
}
XLink &oldOutgo = oldOne->outgo;
ClearIncoming(oldOne);
oldOne->outgo.tailNum = 0;
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -86,6 +86,16 @@ struct XLink ...@@ -86,6 +86,16 @@ struct XLink
/* reset it */ /* reset it */
void Reset(); void Reset();
/* clear it */
void Clear();
/* clear tails */
void ClearTail();
/* clear the incoming node list of tensor node */
static
void ClearIncoming(XTensor * node);
/* set edge type id and name */ /* set edge type id and name */
void SetType(int id); void SetType(int id);
...@@ -119,6 +129,10 @@ struct XLink ...@@ -119,6 +129,10 @@ struct XLink
/* add an integer parameter */ /* add an integer parameter */
static static
void AddParamToHeadInt(XTensor * h, int param); void AddParamToHeadInt(XTensor * h, int param);
/* replace a node with another, i.e., we redirect the links to the new node */
static
void Replace(XTensor * oldOne, XTensor * newOne);
}; };
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -81,6 +81,7 @@ XTensor::XTensor() ...@@ -81,6 +81,7 @@ XTensor::XTensor()
isDefaultDType = true; isDefaultDType = true;
isInGlobalMem = false; isInGlobalMem = false;
isInit = false; isInit = false;
isTmp = false;
} }
/* constructor */ /* constructor */
...@@ -93,6 +94,7 @@ XTensor::XTensor(XTensor * reference) ...@@ -93,6 +94,7 @@ XTensor::XTensor(XTensor * reference)
denseRatio = 1.0F; denseRatio = 1.0F;
isDefaultDType = true; isDefaultDType = true;
isInit = false; isInit = false;
isTmp = false;
InitTensor(this, reference); InitTensor(this, reference);
} }
...@@ -127,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem) ...@@ -127,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
isDefaultDType = true; isDefaultDType = true;
isInGlobalMem = false; isInGlobalMem = false;
isInit = false; isInit = false;
isTmp = false;
} }
/* /*
...@@ -157,6 +160,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP ...@@ -157,6 +160,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
isDefaultDType = true; isDefaultDType = true;
isInGlobalMem = false; isInGlobalMem = false;
isInit = false; isInit = false;
isTmp = false;
Resize(myOrder, myDimSize, myDataType, myDenseRatio); Resize(myOrder, myDimSize, myDataType, myDenseRatio);
} }
...@@ -168,6 +172,7 @@ XTensor::~XTensor() ...@@ -168,6 +172,7 @@ XTensor::~XTensor()
data = NULL; data = NULL;
dataHost = NULL; dataHost = NULL;
mem = NULL; mem = NULL;
XLink::ClearIncoming(this);
} }
/* delete data arrays */ /* delete data arrays */
...@@ -186,9 +191,32 @@ void XTensor::DestroyData() ...@@ -186,9 +191,32 @@ void XTensor::DestroyData()
dataHost = NULL; dataHost = NULL;
} }
/*
shallow copy of tensor
Note that we do not copy data array here
>> tensor - the source tensor
*/
void XTensor::ShallowCopy(const XTensor &tensor)
{
order = tensor.order;
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
dataType = tensor.dataType;
unitSize = tensor.unitSize;
unitNum = tensor.unitNum;
isSparse = tensor.isSparse;
unitNumNonZero = tensor.unitNumNonZero;
denseRatio = tensor.denseRatio;
isShared = tensor.isShared;
isDefaultDType = tensor.isDefaultDType;
isInGlobalMem = tensor.isInGlobalMem;
memcpy(isAllValued, tensor.isAllValued, sizeof(bool) * MAX_TENSOR_DIM_NUM);
}
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& XTensor::operator = (const XTensor& tensor) XTensor& XTensor::operator = (XTensor& tensor)
{ {
/* hard copy of data array */
int size = unitNum * unitSize; int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse && if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize && size == tensor.unitNum * tensor.unitSize &&
...@@ -217,20 +245,14 @@ XTensor& XTensor::operator = (const XTensor& tensor) ...@@ -217,20 +245,14 @@ XTensor& XTensor::operator = (const XTensor& tensor)
XMemCopy(data, devID, tensor.data, tensor.devID, size); XMemCopy(data, devID, tensor.data, tensor.devID, size);
} }
order = tensor.order; /* copy member variables */
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM); ShallowCopy(tensor);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
dataType = tensor.dataType;
unitSize = tensor.unitSize;
unitNum = tensor.unitNum;
isSparse = tensor.isSparse;
unitNumNonZero = tensor.unitNumNonZero;
denseRatio = tensor.denseRatio;
isShared = tensor.isShared;
isDefaultDType = tensor.isDefaultDType;
isInGlobalMem = tensor.isInGlobalMem;
memcpy(isAllValued, tensor.isAllValued, sizeof(bool) * MAX_TENSOR_DIM_NUM);
isInit = true; isInit = true;
isTmp = false;
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
return *this; return *this;
} }
......
...@@ -130,6 +130,9 @@ struct XTensor ...@@ -130,6 +130,9 @@ struct XTensor
/* indicates whether the tensor is initialized or not */ /* indicates whether the tensor is initialized or not */
bool isInit; bool isInit;
/* indicates whether the tensor is created temporarily */
bool isTmp;
/* /*
the link used to form networks. Note that when we compute on tensors, we actually create a the link used to form networks. Note that when we compute on tensors, we actually create a
...@@ -167,8 +170,11 @@ struct XTensor ...@@ -167,8 +170,11 @@ struct XTensor
/* delete data arrays */ /* delete data arrays */
void DestroyData(); void DestroyData();
/* shallow copy of tensor */
void ShallowCopy(const XTensor &tensor);
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& operator = (const XTensor& tensor); XTensor& operator = (XTensor &tensor);
/* judge whether the two matrices are in the same type and size */ /* judge whether the two matrices are in the same type and size */
static static
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论