Commit d49fb9b1 by xiaotong

declare XNet class

parent f8a37184
......@@ -40,6 +40,10 @@ int main( int argc, const char ** argv )
fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
}
XNet net;
XTensor a;
net.Backward(a);
//_CrtDumpMemoryLeaks();
return 0;
......
......@@ -23,4 +23,67 @@
namespace nts{
/* constructor */
XNet::XNet()
{
nodes.Clear();
}
/* de-constructor */
XNet::~XNet()
{
}
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
>> gold - gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
{
XList roots(1);
roots.Add(&root);
XList golds(1);
golds.Add(&gold);
Backward(roots, golds, loss);
}
/*
backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes
>> root - a list of root nodes (output) of the network
>> gold - a list of gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> root - root node (or output of the net)
*/
void XNet::Traverse(XTensor &root)
{
XList roots(1);
roots.Add(&root);
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> roots - a list of roots (or output nodes)
*/
void XNet::Traverse(XList &roots)
{
nodes.Clear();
}
}
\ No newline at end of file
......@@ -30,9 +30,16 @@
namespace nts{
/* management of tensor net (or graph) */
class XNet
struct XNet
{
public:
/* tensor nodes of the net (in order) */
XList nodes;
/* constructor */
XNet();
/* de-constructor */
~XNet();
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
......@@ -40,6 +47,14 @@ public:
/* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */
void Backward(XList &roots, XList &golds = NULLList, LOSS_FUNCTION_NAME loss = NOLOSS);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XTensor &root);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XList &roots);
};
}
......
......@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(123);
/* a tiny test */
if(false)
if(true)
SmallTest();
//_CrtDumpMemoryLeaks();
......
......@@ -37,6 +37,7 @@ XLink::XLink()
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* deconstructor */
......@@ -59,6 +60,8 @@ void XLink::Reset()
tailNum = 0;
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* clear it */
......@@ -68,6 +71,8 @@ void XLink::Clear()
tailNum = 0;
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* reset tails */
......
......@@ -77,6 +77,9 @@ struct XLink
/* type id */
int typeID;
/* caculator (pointer to the class for computation) */
void * caculator;
/* constuctor */
XLink();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论