From 59a7030a8865a42377517e84744848b1aba34a03 Mon Sep 17 00:00:00 2001 From: xiaotong <xiaotong@mail.neu.edu.cn> Date: Fri, 13 Jul 2018 10:47:33 +0800 Subject: [PATCH] net visit code --- source/network/XNet.cpp | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ source/network/XNet.h | 22 +++++++++++++++++++++- source/tensor/XTensor.cpp | 1 + source/tensor/XTensor.h | 3 +++ 4 files changed, 79 insertions(+), 1 deletion(-) diff --git a/source/network/XNet.cpp b/source/network/XNet.cpp index 01dea5d..7daf67f 100644 --- a/source/network/XNet.cpp +++ b/source/network/XNet.cpp @@ -23,6 +23,23 @@ namespace nts{ +unsigned unsigned int netIDGlobal = 0; +MUTEX_HANDLE netMutex; + +/* generate a network id */ +unsigned int MakeNetID() +{ + if(tensorIDGlobal == 0) + MUTEX_INIT(netMutex); + + MUTEX_LOCK(netMutex); + netIDGlobal += 3; + unsigned int id = netIDGlobal; + MUTEX_UNLOCK(netMutex); + + return id; +} + /* constructor */ XNet::XNet() { @@ -34,6 +51,15 @@ XNet::~XNet() { } +/* clear the network */ +void XNet::Clear() +{ + nodes.Clear(); + gradNodes.Clear(); + outputs.Clear(); + inputs.Clear(); +} + /* backward propagation to obtain gradient wrt. the loss/error function >> root - root node (output) of the network @@ -86,4 +112,32 @@ void XNet::Traverse(XList &roots) nodes.Clear(); } +/* +depth-first search given a node (Tarjan's algorithm for topological ordering) +>> node - the node to visit (mark 0:unvisited, 1:visiting, 2:done) +>> orders - topological order of the nodes +>> code - code of the network +*/ +void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code) +{ + if(node == NULL) + return; + + if(node->visitMark == code + 1){ + ShowNTErrors("There is a circle in the network\n"); + } + else if(node->visitMark <= code || node->visitMark > code + 2){ + node->visitMark = code + 1; + XLink &income = node->income; + for(int i = 0; i < income.tailNum; i++){ + XTensor * child = income.tails[i]; + if(child == NULL) + continue; + TarjanVisit(child, orders, code); + } + node->visitMark = code + 2; + orders.Add(node); + } +} + } \ No newline at end of file diff --git a/source/network/XNet.h b/source/network/XNet.h index 7416626..043c585 100644 --- a/source/network/XNet.h +++ b/source/network/XNet.h @@ -32,15 +32,27 @@ namespace nts{ /* management of tensor net (or graph) */ struct XNet { - /* tensor nodes of the net (in order) */ + /* tensor nodes of the network (in order) */ XList nodes; + /* tensor nodes to keep gradient for output (e.g., SGD)*/ + XList gradNodes; + + /* output nodes of the network */ + XList outputs; + + /* input nodes of the network */ + XList inputs; + /* constructor */ XNet(); /* de-constructor */ ~XNet(); + /* clear the network */ + void Clear(); + /* backward propagation to obtain gradient wrt. the loss/error function */ void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS); @@ -55,8 +67,16 @@ struct XNet /* traverse the net and find the topological order by depth-first search (Tarjan's algorithm) */ void Traverse(XList &roots); + + /* depth-first search given a node (Tarjan's algorithm for topological ordering) */ + void TarjanVisit(XTensor * node, XList &orders, const unsigned int code); }; +/* we make a unique id for every tensor */ +extern unsigned int netIDGlobal; +extern MUTEX_HANDLE netMutex; +extern unsigned int MakeNetID(); + } #endif \ No newline at end of file diff --git a/source/tensor/XTensor.cpp b/source/tensor/XTensor.cpp index 8219932..d589165 100644 --- a/source/tensor/XTensor.cpp +++ b/source/tensor/XTensor.cpp @@ -237,6 +237,7 @@ void XTensor::Init() memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM); isInit = false; isTmp = false; + visitMark = 0; } /* delete data arrays */ diff --git a/source/tensor/XTensor.h b/source/tensor/XTensor.h index 2372687..678b0f1 100644 --- a/source/tensor/XTensor.h +++ b/source/tensor/XTensor.h @@ -138,6 +138,9 @@ public: /* indicates whether the tensor is created temporarily */ bool isTmp; + + /* mark for traversing the gragh */ + int visitMark; /* the link used to form networks. Note that when we compute on tensors, we actually create a -- libgit2 0.26.0