Commit 59a7030a by xiaotong

net visit code

parent d49fb9b1
@@ -23,6 +23,23 @@
namespace nts{
unsigned int netIDGlobal = 0;
MUTEX_HANDLE netMutex;
/* generate a network id */
unsigned int MakeNetID()
{
if(netIDGlobal == 0)
MUTEX_INIT(netMutex);
MUTEX_LOCK(netMutex);
netIDGlobal += 3;
unsigned int id = netIDGlobal;
MUTEX_UNLOCK(netMutex);
return id;
}
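
The stride of 3 used by MakeNetID() is tied to the visit marks introduced below: each traversal stamps nodes with code + 1 (visiting) and code + 2 (done), so consecutive network codes must be at least 3 apart for marks from different traversals to stay distinct. A minimal sketch of that invariant (illustrative only, not part of the commit; the mutex is omitted for brevity):

#include <cassert>

/* simplified stand-in for the global counter above (illustrative only) */
static unsigned int netIDGlobalSketch = 0;

/* same stride-of-3 policy as MakeNetID(), without the mutex */
unsigned int MakeNetIDSketch()
{
    netIDGlobalSketch += 3;
    return netIDGlobalSketch;
}

int main()
{
    unsigned int a = MakeNetIDSketch();   /* 3 */
    unsigned int b = MakeNetIDSketch();   /* 6 */

    /* a traversal using code "a" writes a + 1 and a + 2 into visitMark;
       neither value can collide with the next traversal's code "b" */
    assert(a + 2 < b);
    return 0;
}
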
/* constructor */
XNet::XNet()
{
@@ -34,6 +51,15 @@ XNet::~XNet()
{
}
/* clear the network */
void XNet::Clear()
{
nodes.Clear();
gradNodes.Clear();
outputs.Clear();
inputs.Clear();
}
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
@@ -86,4 +112,32 @@ void XNet::Traverse(XList &roots)
nodes.Clear();
}
/*
depth-first search given a node (Tarjan's algorithm for topological ordering)
>> node - the node to visit (mark 0:unvisited, 1:visiting, 2:done)
>> orders - topological order of the nodes
>> code - code of the network
*/
void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
{
if(node == NULL)
return;
if(node->visitMark == code + 1){
ShowNTErrors("There is a cycle in the network\n");
}
else if(node->visitMark <= code || node->visitMark > code + 2){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
continue;
TarjanVisit(child, orders, code);
}
node->visitMark = code + 2;
orders.Add(node);
}
}
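
For context, the body of Traverse() is collapsed in this diff. A hedged sketch of how a caller is expected to drive TarjanVisit() to obtain a topological order, assuming XList exposes count and GetItem() as in the rest of the code base (the helper name VisitAllRoots is illustrative, not part of the commit):

/* illustrative only: collect every node reachable from "roots" in topological
   order, children before parents, so gradients can later be propagated by
   walking "orders" in reverse */
void VisitAllRoots(XNet &net, XList &roots, XList &orders)
{
    /* one fresh code per traversal; see MakeNetID() above */
    unsigned int code = MakeNetID();

    for(int i = 0; i < roots.count; i++)
        net.TarjanVisit((XTensor*)roots.GetItem(i), orders, code);
}
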
}
\ No newline at end of file
@@ -32,15 +32,27 @@ namespace nts{
/* management of tensor net (or graph) */
struct XNet
{
/* tensor nodes of the network (in order) */
XList nodes;
/* tensor nodes to keep gradient for output (e.g., SGD)*/
XList gradNodes;
/* output nodes of the network */
XList outputs;
/* input nodes of the network */
XList inputs;
/* constructor */
XNet();
/* de-constructor */
~XNet();
/* clear the network */
void Clear();
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
@@ -55,8 +67,16 @@ struct XNet
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XList &roots);
/* depth-first search given a node (Tarjan's algorithm for topological ordering) */
void TarjanVisit(XTensor * node, XList &orders, const unsigned int code);
};
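
As a usage note (illustrative, not part of the commit): with the new bookkeeping lists and Clear(), a single training step would use an XNet object roughly as follows, where CROSSENTROPY stands for whichever LOSS_FUNCTION_NAME the caller actually wants:

/* illustrative only */
void TrainStep(XTensor &output, XTensor &gold)
{
    XNet net;

    /* backward propagation from the output node w.r.t. the chosen loss */
    net.Backward(output, gold, CROSSENTROPY);

    /* nodes whose gradients are kept are listed in net.gradNodes at this point */

    /* drop nodes/gradNodes/inputs/outputs before the object is reused */
    net.Clear();
}
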
/* we make a unique id for every network */
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
}
#endif
\ No newline at end of file
@@ -237,6 +237,7 @@ void XTensor::Init()
memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
isInit = false;
isTmp = false;
visitMark = 0;
}
/* delete data arrays */
...
@@ -139,6 +139,9 @@ public:
/* indicates whether the tensor is created temporarily */
bool isTmp;
/* mark for traversing the graph */
int visitMark;
/*
the link used to form networks. Note that when we compute on tensors, we actually create a
network where nodes are tensors and edges the connections among them. Each connection is
...