Commit f8ddb15b by xiaotong

new XLink methods

parent bc12dbd2
...@@ -78,8 +78,13 @@ void SmallTest() ...@@ -78,8 +78,13 @@ void SmallTest()
XTensor c = a * b + a; XTensor c = a * b + a;
int nnn = 1;
XTensor d = a + b + c.Lin(0.5F); XTensor d = a + b + c.Lin(0.5F);
XLink::CheckNetwork(&d);
XLink::ShowNetwork(stderr, &b);
a.Dump(stderr, "a: "); a.Dump(stderr, "a: ");
b.Dump(stderr, "b: "); b.Dump(stderr, "b: ");
c.Dump(stderr, "c: "); c.Dump(stderr, "c: ");
......
...@@ -100,10 +100,11 @@ void XLink::ClearIncoming(XTensor * node) ...@@ -100,10 +100,11 @@ void XLink::ClearIncoming(XTensor * node)
childOutgo.tails + j + 1, childOutgo.tails + j + 1,
(childOutgo.tailNum - 1 - j) * sizeof(XTensor*)); (childOutgo.tailNum - 1 - j) * sizeof(XTensor*));
childOutgo.tailNum--; childOutgo.tailNum--;
break;
} }
} }
if(childOutgo.tailNum == 0) if(child->isTmp && childOutgo.tailNum == 0)
delete child; delete child;
} }
...@@ -120,7 +121,7 @@ void XLink::SetType(int id) ...@@ -120,7 +121,7 @@ void XLink::SetType(int id)
type[0] = 0; type[0] = 0;
strcpy(type, GetOPName(id)); strcpy(type, GetOPName(id));
typeID = id; typeID = id;
CheckNTErrors(!strcmp(type, "NULL"), "illegal edge type name!"); CheckNTErrors(strcmp(type, "NULL"), "illegal edge type name!");
} }
/* /*
...@@ -199,7 +200,7 @@ create a hyperedge with two input tensors and a output tensor ...@@ -199,7 +200,7 @@ create a hyperedge with two input tensors and a output tensor
*/ */
void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id) void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id)
{ {
if(h != NULL) if(h == NULL)
return; return;
XList list(2); XList list(2);
...@@ -225,14 +226,19 @@ void XLink::MakeLink(XList * list, XTensor * h, int id) ...@@ -225,14 +226,19 @@ void XLink::MakeLink(XList * list, XTensor * h, int id)
for(int i = 0; i < list->count; i++){ for(int i = 0; i < list->count; i++){
XTensor * t = (XTensor*)list->GetItem(i); XTensor * t = (XTensor*)list->GetItem(i);
if(t == NULL)
continue;
income.AddTail(t); income.AddTail(t);
} }
/* backward */ /* backward */
for(int i = 0; i < list->count; i++){ for(int i = 0; i < list->count; i++){
XTensor * t = (XTensor*)list->GetItem(i); XTensor * t = (XTensor*)list->GetItem(i);
if(t == NULL)
continue;
XLink &outgo = t->outgo; XLink &outgo = t->outgo;
CheckNTErrors(outgo.head != t, "Wrong head of the hyperedge!"); CheckNTErrors(outgo.head == NULL || outgo.head == t,
"Wrong head of the hyperedge!");
outgo.AddTail(h); outgo.AddTail(h);
} }
} }
...@@ -278,6 +284,7 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -278,6 +284,7 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
delete[] newIncome.tails; delete[] newIncome.tails;
/* incoming nodes for the new node */ /* incoming nodes for the new node */
newIncome.head = newOne;
newIncome.tailNum = oldOne->income.tailNum; newIncome.tailNum = oldOne->income.tailNum;
newIncome.tails = new XTensor*[newIncome.tailNum]; newIncome.tails = new XTensor*[newIncome.tailNum];
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum); memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
...@@ -286,28 +293,173 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -286,28 +293,173 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
for(int i = 0; i < newIncome.tailNum; i++){ for(int i = 0; i < newIncome.tailNum; i++){
XTensor * child = newIncome.tails[i]; XTensor * child = newIncome.tails[i];
XLink &childOutgo = child->outgo; XLink &childOutgo = child->outgo;
bool hit = false;
for(int j = 0; j < childOutgo.tailNum; j++){ for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == oldOne){ if(childOutgo.tails[j] == oldOne){
childOutgo.tails[j] = newOne; childOutgo.tails[j] = newOne;
hit = true;
break;
}
} }
if(childOutgo.tailNum > 0){
CheckNTErrors(hit, "No proper node found in child.outgo edge!");
} }
} }
/* outgoing nodes for the new node */ /* outgoing nodes for the new node */
newOutgo.tailNum = oldOne->income.tailNum; newOutgo.head = newOne;
newOutgo.tailNum = oldOne->outgo.tailNum;
newOutgo.tails = new XTensor*[newOutgo.tailNum]; newOutgo.tails = new XTensor*[newOutgo.tailNum];
memcpy(newOutgo.tails, oldOne->income.tails, sizeof(XTensor*) * newOutgo.tailNum); memcpy(newOutgo.tails, oldOne->outgo.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */ /* update the link to each parent node */
for(int i = 0; i < newOutgo.tailNum; i++){ for(int i = 0; i < newOutgo.tailNum; i++){
XTensor * parent = newOutgo.tails[i]; XTensor * parent = newOutgo.tails[i];
XLink &parentIncome = parent->income; XLink &parentIncome = parent->income;
bool hit = false;
for(int j = 0; j < parentIncome.tailNum; j++){ for(int j = 0; j < parentIncome.tailNum; j++){
if(parentIncome.tails[j] == oldOne){ if(parentIncome.tails[j] == oldOne){
parentIncome.tails[j] = newOne; parentIncome.tails[j] = newOne;
hit = true;
}
}
if(parentIncome.tailNum > 0){
CheckNTErrors(hit, "No proper node found in parent.income edge!");
}
}
}
/*
copy incoming edges of a given node
>> reference - the node we copy from
>> target - where we copy to
*/
void XLink::CopyIncoming(const XTensor * reference, XTensor * target)
{
CheckNTErrors(reference && target, "Empty input tensors!");
ClearIncoming(target);
int tailNum = reference->income.tailNum;
XList tails(tailNum);
for(int i = 0; i < tailNum; i++){
XTensor * tail = (XTensor*)reference->income.tails[i];
tails.Add(tail);
}
MakeLink(&tails, target, reference->id);
int paraNum = reference->income.paramNum;
target->income.paramNum = paraNum;
delete[] (char*)target->income.params;
int size = paraNum * reference->income.paramSize;
target->income.params = new char[size];
memcpy(target->income.params, reference->income.params, size);
}
/*
check the correctness of the network encoded in a root node (tensor)
>> root - pointer to the root node
*/
void XLink::CheckNetwork(XTensor * root)
{
XLink &income = root->income;
if(income.head == NULL){
CheckNTErrors(income.tailNum == 0, "Wrong number of the incoming edge tails!");
}
else{
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
continue;
XLink & childOutgo = child->outgo;
bool hit = false;
for(int j = 0; j < childOutgo.tailNum; j++){
if(childOutgo.tails[j] == root){
hit = true;
break;
}
}
CheckNTErrors(hit, "Wrong outgoing edge!");
} }
}
XLink &outgo = root->outgo;
if(outgo.head == NULL){
CheckNTErrors(outgo.tailNum == 0, "Wrong number of the incoming edge tails!");
}
else{
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = outgo.tails[i];
if(parent == NULL)
continue;
XLink & parentOutgo = parent->outgo;
bool hit = false;
for(int j = 0; j < parentOutgo.tailNum; j++){
if(parentOutgo.tails[j] == root){
hit = true;
break;
}
}
CheckNTErrors(hit, "Wrong outgoing edge!");
}
}
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
CheckNetwork(child);
}
}
/*
show the network encoded in a root node (tensor)
>> file - file to dump information
>> root - pointer to the root node
*/
void XLink::ShowNetwork(FILE * file, XTensor * root)
{
fprintf(file, "node %d - ", root->id);
XLink &income = root->income;
if(income.head == NULL){
fprintf(file, "income[%d]: null ", income.tailNum);
}
else{
fprintf(file, "income[%d]: ", income.tailNum);
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
fprintf(file, "na ");
else
fprintf(file, "%d ", child->id);
} }
} }
XLink &outgo = root->outgo;
if(outgo.head == NULL){
fprintf(file, "outgo[%d]: null ", outgo.tailNum);
}
else{
fprintf(file, "outgo[%d]: ", income.tailNum);
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = outgo.tails[i];
if(parent == NULL)
fprintf(file, "na ");
else
fprintf(file, "%d ", parent->id);
}
}
fprintf(stderr, "\n");
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
ShowNetwork(file, child);
}
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -133,6 +133,18 @@ struct XLink ...@@ -133,6 +133,18 @@ struct XLink
/* replace a node with another, i.e., we redirect the links to the new node */ /* replace a node with another, i.e., we redirect the links to the new node */
static static
void Replace(const XTensor * oldOne, XTensor * newOne); void Replace(const XTensor * oldOne, XTensor * newOne);
/* copy links of a given node */
static
void CopyIncoming(const XTensor * reference, XTensor * target);
/* check the correctness of the network encoded in a root node (tensor) */
static
void CheckNetwork(XTensor * root);
/* show the network encoded in a root node (tensor) */
static
void ShowNetwork(FILE * file, XTensor * root);
}; };
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -222,6 +222,13 @@ XTensor::XTensor(const XTensor &reference) ...@@ -222,6 +222,13 @@ XTensor::XTensor(const XTensor &reference)
CopyValues(&reference, this); CopyValues(&reference, this);
} }
if(reference.isTmp)
XLink::Replace(&reference, this);
else{
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
XLink::CopyIncoming(&reference, this);
}
isInit = false; isInit = false;
isTmp = false; isTmp = false;
} }
...@@ -305,6 +312,8 @@ XTensor& XTensor::operator= (const XTensor& tensor) ...@@ -305,6 +312,8 @@ XTensor& XTensor::operator= (const XTensor& tensor)
isInit = true; isInit = true;
isTmp = false; isTmp = false;
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
/* create tensor links for the new tensor */ /* create tensor links for the new tensor */
XLink::Replace(&tensor, this); XLink::Replace(&tensor, this);
......
...@@ -104,7 +104,7 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift) ...@@ -104,7 +104,7 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
_ScaleAndShift(&a, &b, scale, shift); _ScaleAndShift(&a, &b, scale, shift);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SUM); XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale); XLink::AddParamToHead(&b, scale);
XLink::AddParamToHead(&b, shift); XLink::AddParamToHead(&b, shift);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论