Commit d49fb9b1 by xiaotong

declare XNet class

parent f8a37184
...@@ -40,6 +40,10 @@ int main( int argc, const char ** argv ) ...@@ -40,6 +40,10 @@ int main( int argc, const char ** argv )
fprintf(stderr, "Run this program with \"-test\" for unit test!\n"); fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
} }
XNet net;
XTensor a;
net.Backward(a);
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -23,4 +23,67 @@ ...@@ -23,4 +23,67 @@
namespace nts{ namespace nts{
/* constructor */
XNet::XNet()
{
nodes.Clear();
}
/* de-constructor */
XNet::~XNet()
{
}
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
>> gold - gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
{
XList roots(1);
roots.Add(&root);
XList golds(1);
golds.Add(&gold);
Backward(roots, golds, loss);
}
/*
backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes
>> root - a list of root nodes (output) of the network
>> gold - a list of gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> root - root node (or output of the net)
*/
void XNet::Traverse(XTensor &root)
{
XList roots(1);
roots.Add(&root);
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> roots - a list of roots (or output nodes)
*/
void XNet::Traverse(XList &roots)
{
nodes.Clear();
}
} }
\ No newline at end of file
...@@ -30,9 +30,16 @@ ...@@ -30,9 +30,16 @@
namespace nts{ namespace nts{
/* management of tensor net (or graph) */ /* management of tensor net (or graph) */
class XNet struct XNet
{ {
public: /* tensor nodes of the net (in order) */
XList nodes;
/* constructor */
XNet();
/* de-constructor */
~XNet();
/* backward propagation to obtain gradient wrt. the loss/error function */ /* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS); void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
...@@ -40,6 +47,14 @@ public: ...@@ -40,6 +47,14 @@ public:
/* backward propagation to obtain gradient wrt. the loss/error function /* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */ with a number of root nodes */
void Backward(XList &roots, XList &golds = NULLList, LOSS_FUNCTION_NAME loss = NOLOSS); void Backward(XList &roots, XList &golds = NULLList, LOSS_FUNCTION_NAME loss = NOLOSS);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XTensor &root);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XList &roots);
}; };
} }
......
...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv ) ...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(123); //_CrtSetBreakAlloc(123);
/* a tiny test */ /* a tiny test */
if(false) if(true)
SmallTest(); SmallTest();
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
......
...@@ -37,6 +37,7 @@ XLink::XLink() ...@@ -37,6 +37,7 @@ XLink::XLink()
paramNum = 0; paramNum = 0;
type[0] = 0; type[0] = 0;
typeID = 0; typeID = 0;
caculator = NULL;
} }
/* deconstructor */ /* deconstructor */
...@@ -59,6 +60,8 @@ void XLink::Reset() ...@@ -59,6 +60,8 @@ void XLink::Reset()
tailNum = 0; tailNum = 0;
paramNum = 0; paramNum = 0;
type[0] = 0; type[0] = 0;
typeID = 0;
caculator = NULL;
} }
/* clear it */ /* clear it */
...@@ -68,6 +71,8 @@ void XLink::Clear() ...@@ -68,6 +71,8 @@ void XLink::Clear()
tailNum = 0; tailNum = 0;
paramNum = 0; paramNum = 0;
type[0] = 0; type[0] = 0;
typeID = 0;
caculator = NULL;
} }
/* reset tails */ /* reset tails */
......
...@@ -77,6 +77,9 @@ struct XLink ...@@ -77,6 +77,9 @@ struct XLink
/* type id */ /* type id */
int typeID; int typeID;
/* caculator (pointer to the class for computation) */
void * caculator;
/* constuctor */ /* constuctor */
XLink(); XLink();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论