From 59a7030a8865a42377517e84744848b1aba34a03 Mon Sep 17 00:00:00 2001
From: xiaotong <xiaotong@mail.neu.edu.cn>
Date: Fri, 13 Jul 2018 10:47:33 +0800
Subject: [PATCH] net visit code

---
 source/network/XNet.cpp   | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 source/network/XNet.h     | 22 +++++++++++++++++++++-
 source/tensor/XTensor.cpp |  1 +
 source/tensor/XTensor.h   |  3 +++
 4 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/source/network/XNet.cpp b/source/network/XNet.cpp
index 01dea5d..7daf67f 100644
--- a/source/network/XNet.cpp
+++ b/source/network/XNet.cpp
@@ -23,6 +23,23 @@
 
 namespace nts{
 
+unsigned unsigned int netIDGlobal = 0;
+MUTEX_HANDLE netMutex;
+
+/* generate a network id */
+unsigned int MakeNetID()
+{
+    if(tensorIDGlobal == 0)
+        MUTEX_INIT(netMutex);
+
+    MUTEX_LOCK(netMutex);
+    netIDGlobal += 3;
+    unsigned int id = netIDGlobal;
+    MUTEX_UNLOCK(netMutex);
+
+    return id;
+}
+
 /* constructor */
 XNet::XNet()
 {
@@ -34,6 +51,15 @@ XNet::~XNet()
 {
 }
 
+/* clear the network */
+void XNet::Clear()
+{
+    nodes.Clear();
+    gradNodes.Clear();
+    outputs.Clear();
+    inputs.Clear();
+}
+
 /* 
 backward propagation to obtain gradient wrt. the loss/error function 
 >> root - root node (output) of the network
@@ -86,4 +112,32 @@ void XNet::Traverse(XList &roots)
     nodes.Clear();
 }
 
+/* 
+depth-first search given a node (Tarjan's algorithm for topological ordering)
+>> node - the node to visit (mark 0:unvisited, 1:visiting, 2:done)
+>> orders - topological order of the nodes
+>> code - code of the network
+*/
+void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
+{
+    if(node == NULL)
+        return;
+
+    if(node->visitMark == code + 1){
+        ShowNTErrors("There is a circle in the network\n");
+    }
+    else if(node->visitMark <= code || node->visitMark > code + 2){
+        node->visitMark = code + 1;
+        XLink &income = node->income;
+        for(int i = 0; i < income.tailNum; i++){
+            XTensor * child = income.tails[i];
+            if(child == NULL)
+                continue;
+            TarjanVisit(child, orders, code);
+        }
+        node->visitMark = code + 2;
+        orders.Add(node);
+    }
+}
+
 }
\ No newline at end of file
diff --git a/source/network/XNet.h b/source/network/XNet.h
index 7416626..043c585 100644
--- a/source/network/XNet.h
+++ b/source/network/XNet.h
@@ -32,15 +32,27 @@ namespace nts{
 /* management of tensor net (or graph) */
 struct XNet
 {
-    /* tensor nodes of the net (in order) */
+    /* tensor nodes of the network (in order) */
     XList nodes;
 
+    /* tensor nodes to keep gradient for output (e.g., SGD)*/
+    XList gradNodes;
+
+    /* output nodes of the network */
+    XList outputs;
+
+    /* input nodes of the network */
+    XList inputs;
+
     /* constructor */
     XNet();
 
     /* de-constructor */
     ~XNet();
 
+    /* clear the network */
+    void Clear();
+
     /* backward propagation to obtain gradient wrt. the loss/error function */
     void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
 
@@ -55,8 +67,16 @@ struct XNet
     /* traverse the net and find the topological order by 
        depth-first search (Tarjan's algorithm) */
     void Traverse(XList &roots);
+
+    /* depth-first search given a node (Tarjan's algorithm for topological ordering) */
+    void TarjanVisit(XTensor * node, XList &orders, const unsigned int code);
 };
 
+/* we make a unique id for every tensor */
+extern unsigned int netIDGlobal;
+extern MUTEX_HANDLE netMutex;
+extern unsigned int MakeNetID();
+
 }
 
 #endif
\ No newline at end of file
diff --git a/source/tensor/XTensor.cpp b/source/tensor/XTensor.cpp
index 8219932..d589165 100644
--- a/source/tensor/XTensor.cpp
+++ b/source/tensor/XTensor.cpp
@@ -237,6 +237,7 @@ void XTensor::Init()
     memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
     isInit = false;
     isTmp =  false;
+    visitMark = 0;
 }
 
 /* delete data arrays */
diff --git a/source/tensor/XTensor.h b/source/tensor/XTensor.h
index 2372687..678b0f1 100644
--- a/source/tensor/XTensor.h
+++ b/source/tensor/XTensor.h
@@ -138,6 +138,9 @@ public:
 
     /* indicates whether the tensor is created temporarily */
     bool isTmp;
+
+    /* mark for traversing the gragh */
+    int visitMark;
     
     /*
     the link used to form networks. Note that when we compute on tensors, we actually create a
--
libgit2 0.26.0