Commit 59a7030a by xiaotong

net visit code

parent d49fb9b1
......@@ -23,6 +23,23 @@
namespace nts{
unsigned unsigned int netIDGlobal = 0;
MUTEX_HANDLE netMutex;
/* generate a network id */
unsigned int MakeNetID()
{
if(tensorIDGlobal == 0)
MUTEX_INIT(netMutex);
MUTEX_LOCK(netMutex);
netIDGlobal += 3;
unsigned int id = netIDGlobal;
MUTEX_UNLOCK(netMutex);
return id;
}
/* constructor */
XNet::XNet()
{
......@@ -34,6 +51,15 @@ XNet::~XNet()
{
}
/* clear the network */
void XNet::Clear()
{
nodes.Clear();
gradNodes.Clear();
outputs.Clear();
inputs.Clear();
}
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
......@@ -86,4 +112,32 @@ void XNet::Traverse(XList &roots)
nodes.Clear();
}
/*
depth-first search given a node (Tarjan's algorithm for topological ordering)
>> node - the node to visit (mark 0:unvisited, 1:visiting, 2:done)
>> orders - topological order of the nodes
>> code - code of the network
*/
void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
{
if(node == NULL)
return;
if(node->visitMark == code + 1){
ShowNTErrors("There is a circle in the network\n");
}
else if(node->visitMark <= code || node->visitMark > code + 2){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
continue;
TarjanVisit(child, orders, code);
}
node->visitMark = code + 2;
orders.Add(node);
}
}
}
\ No newline at end of file
......@@ -32,15 +32,27 @@ namespace nts{
/* management of tensor net (or graph) */
struct XNet
{
/* tensor nodes of the net (in order) */
/* tensor nodes of the network (in order) */
XList nodes;
/* tensor nodes to keep gradient for output (e.g., SGD)*/
XList gradNodes;
/* output nodes of the network */
XList outputs;
/* input nodes of the network */
XList inputs;
/* constructor */
XNet();
/* de-constructor */
~XNet();
/* clear the network */
void Clear();
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
......@@ -55,8 +67,16 @@ struct XNet
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XList &roots);
/* depth-first search given a node (Tarjan's algorithm for topological ordering) */
void TarjanVisit(XTensor * node, XList &orders, const unsigned int code);
};
/* we make a unique id for every tensor */
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
}
#endif
\ No newline at end of file
......@@ -237,6 +237,7 @@ void XTensor::Init()
memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
isInit = false;
isTmp = false;
visitMark = 0;
}
/* delete data arrays */
......
......@@ -138,6 +138,9 @@ public:
/* indicates whether the tensor is created temporarily */
bool isTmp;
/* mark for traversing the gragh */
int visitMark;
/*
the link used to form networks. Note that when we compute on tensors, we actually create a
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论