Commit 4acc236a by xiaotong

new code for XLink

parent 85ee1664
......@@ -56,7 +56,59 @@ void XLink::Reset()
params = NULL;
tailNum = 0;
paramNum = 0;
type[0] = 0;
type[0] = 0;
}
/* clear it */
void XLink::Clear()
{
head = NULL;
tailNum = 0;
paramNum = 0;
type[0] = 0;
}
/* reset tails */
void XLink::ClearTail()
{
tailNum = 0;
}
/*
clear the incoming node list of tensor node
>> node - the node to be cleared
*/
void XLink::ClearIncoming(XTensor * node)
{
if(node == NULL)
return;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
/* for a incoming node */
XTensor * child = income.tails[i];
XLink &childOutgo = child->outgo;
CheckNTErrors(childOutgo.tailNum > 0, "The node must have outgoing edges!");
/* we check for each child node and remove the link to current node */
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == node){
memcpy(childOutgo.tails + j,
childOutgo.tails + j + 1,
(childOutgo.tailNum - 1 - j) * sizeof(XTensor*));
childOutgo.tailNum--;
}
}
if(childOutgo.tailNum == 0)
delete child;
}
income.ClearTail();
income.tailNum = 0;
}
/*
......@@ -229,6 +281,59 @@ void XLink::AddParamToHeadInt(XTensor * h, int param)
return;
h->income.AddParam(&param, sizeof(int));
}
/*
replace a node with another, i.e., we redirect the links to the new node
>> oldOne - the node to be replaced
>> newOne - the new node
*/
void XLink::Replace(XTensor * oldOne, XTensor * newOne)
{
if(oldOne == NULL || newOne == NULL)
return;
XLink::ClearIncoming(newOne);
XLink &newIncome = newOne->income;
XLink &newOutgo = newOne->outgo;
delete[] newIncome.tails;
/* incoming nodes for the new node */
newIncome.tailNum = oldOne->income.tailNum;
newIncome.tails = new XTensor*[newIncome.tailNum];
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
/* update the link to each child node */
for(int i = 0; i < newIncome.tailNum; i++){
XTensor * child = newIncome.tails[i];
XLink &childOutgo = child->outgo;
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == oldOne){
childOutgo.tails[j] = newOne;
}
}
}
/* outgoing nodes for the new node */
newOutgo.tailNum = oldOne->income.tailNum;
newOutgo.tails = new XTensor*[newOutgo.tailNum];
memcpy(newOutgo.tails, oldOne->income.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */
for(int i = 0; i < newOutgo.tailNum; i++){
XTensor * parent = newOutgo.tails[i];
XLink &parentIncome = parent->income;
for(int j = 0; j < parentIncome.tailNum; j++){
if(parentIncome.tails[j] == oldOne){
parentIncome.tails[j] = newOne;
}
}
}
XLink &oldOutgo = oldOne->outgo;
ClearIncoming(oldOne);
oldOne->outgo.tailNum = 0;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -86,6 +86,16 @@ struct XLink
/* reset it */
void Reset();
/* clear it */
void Clear();
/* clear tails */
void ClearTail();
/* clear the incoming node list of tensor node */
static
void ClearIncoming(XTensor * node);
/* set edge type id and name */
void SetType(int id);
......@@ -119,6 +129,10 @@ struct XLink
/* add an integer parameter */
static
void AddParamToHeadInt(XTensor * h, int param);
/* replace a node with another, i.e., we redirect the links to the new node */
static
void Replace(XTensor * oldOne, XTensor * newOne);
};
} // namespace nts(NiuTrans.Tensor)
......
......@@ -81,6 +81,7 @@ XTensor::XTensor()
isDefaultDType = true;
isInGlobalMem = false;
isInit = false;
isTmp = false;
}
/* constructor */
......@@ -93,6 +94,7 @@ XTensor::XTensor(XTensor * reference)
denseRatio = 1.0F;
isDefaultDType = true;
isInit = false;
isTmp = false;
InitTensor(this, reference);
}
......@@ -127,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
isDefaultDType = true;
isInGlobalMem = false;
isInit = false;
isTmp = false;
}
/*
......@@ -157,6 +160,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
isDefaultDType = true;
isInGlobalMem = false;
isInit = false;
isTmp = false;
Resize(myOrder, myDimSize, myDataType, myDenseRatio);
}
......@@ -168,6 +172,7 @@ XTensor::~XTensor()
data = NULL;
dataHost = NULL;
mem = NULL;
XLink::ClearIncoming(this);
}
/* delete data arrays */
......@@ -186,9 +191,32 @@ void XTensor::DestroyData()
dataHost = NULL;
}
/*
shallow copy of tensor
Note that we do not copy data array here
>> tensor - the source tensor
*/
void XTensor::ShallowCopy(const XTensor &tensor)
{
order = tensor.order;
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
dataType = tensor.dataType;
unitSize = tensor.unitSize;
unitNum = tensor.unitNum;
isSparse = tensor.isSparse;
unitNumNonZero = tensor.unitNumNonZero;
denseRatio = tensor.denseRatio;
isShared = tensor.isShared;
isDefaultDType = tensor.isDefaultDType;
isInGlobalMem = tensor.isInGlobalMem;
memcpy(isAllValued, tensor.isAllValued, sizeof(bool) * MAX_TENSOR_DIM_NUM);
}
/* overloading of the equal-sign */
XTensor& XTensor::operator = (const XTensor& tensor)
XTensor& XTensor::operator = (XTensor& tensor)
{
/* hard copy of data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
......@@ -217,20 +245,14 @@ XTensor& XTensor::operator = (const XTensor& tensor)
XMemCopy(data, devID, tensor.data, tensor.devID, size);
}
order = tensor.order;
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
dataType = tensor.dataType;
unitSize = tensor.unitSize;
unitNum = tensor.unitNum;
isSparse = tensor.isSparse;
unitNumNonZero = tensor.unitNumNonZero;
denseRatio = tensor.denseRatio;
isShared = tensor.isShared;
isDefaultDType = tensor.isDefaultDType;
isInGlobalMem = tensor.isInGlobalMem;
memcpy(isAllValued, tensor.isAllValued, sizeof(bool) * MAX_TENSOR_DIM_NUM);
/* copy member variables */
ShallowCopy(tensor);
isInit = true;
isTmp = false;
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
return *this;
}
......
......@@ -130,6 +130,9 @@ struct XTensor
/* indicates whether the tensor is initialized or not */
bool isInit;
/* indicates whether the tensor is created temporarily */
bool isTmp;
/*
the link used to form networks. Note that when we compute on tensors, we actually create a
......@@ -167,8 +170,11 @@ struct XTensor
/* delete data arrays */
void DestroyData();
/* shallow copy of tensor */
void ShallowCopy(const XTensor &tensor);
/* overloading of the equal-sign */
XTensor& operator = (const XTensor& tensor);
XTensor& operator = (XTensor &tensor);
/* judge whether the two matrices are in the same type and size */
static
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论