Commit b3c2cf56 by xiaotong

Adding "name" into XTensor. Now we can search for a node in the network by its "name".

parent 9ac1028a
...@@ -458,4 +458,15 @@ void XNet::ShowNetwork(FILE * file, XTensor * node) ...@@ -458,4 +458,15 @@ void XNet::ShowNetwork(FILE * file, XTensor * node)
} }
} }
/*
search for a node in a top-down manner by its name
>> top - the top most node
<< return - the node we found
*/
XTensor * XNet::SearchNode(XTensor * top, const char * name)
{
return XLink::SearchNode(top, name);
}
} }
\ No newline at end of file
...@@ -111,6 +111,10 @@ struct XNet ...@@ -111,6 +111,10 @@ struct XNet
/* show network topology */ /* show network topology */
void ShowNetwork(FILE * file, XTensor * node); void ShowNetwork(FILE * file, XTensor * node);
/* search a node in a top-down manner by its name */
static
XTensor * SearchNode(XTensor * top, const char * name);
}; };
/* we make a unique id for every tensor */ /* we make a unique id for every tensor */
......
...@@ -528,6 +528,8 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -528,6 +528,8 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
CheckNTErrors(hit, "No proper node found in parent.income edge!"); CheckNTErrors(hit, "No proper node found in parent.income edge!");
} }
} }
strcpy(newOne->name, oldOne->name);
} }
/* /*
...@@ -655,6 +657,29 @@ void XLink::ShowNode(FILE * file, XTensor * node) ...@@ -655,6 +657,29 @@ void XLink::ShowNode(FILE * file, XTensor * node)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
/*
search for a node in a top-down manner by its name
>> top - the top most node
<< return - the node we found
*/
XTensor * XLink::SearchNode(XTensor * top, const char * name)
{
if(!strcmp(top->name, name))
return top;
XLink &incoming = top->income;
for(int i = 0; i < incoming.tailNum; i++){
XTensor * child = incoming.tails[i];
XTensor * hit = SearchNode(child, name);
if(hit != NULL)
return hit;
}
return NULL;
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -185,6 +185,10 @@ struct XLink ...@@ -185,6 +185,10 @@ struct XLink
/* show a node */ /* show a node */
static static
void ShowNode(FILE * file, XTensor * node); void ShowNode(FILE * file, XTensor * node);
/* search a node in a top-down manner by its name */
static
XTensor * SearchNode(XTensor * top, const char * name);
}; };
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -251,9 +251,16 @@ XTensor::~XTensor() ...@@ -251,9 +251,16 @@ XTensor::~XTensor()
delete grad; delete grad;
} }
/* set the name of the tensor */
void XTensor::SetName(const char * myName)
{
strcpy(name, myName);
}
/* initialize member variables */ /* initialize member variables */
void XTensor::Init() void XTensor::Init()
{ {
name[0] = '\0';
id = -1; id = -1;
mem = NULL; mem = NULL;
signature = 0; signature = 0;
...@@ -306,6 +313,7 @@ Note that we do not copy data array here ...@@ -306,6 +313,7 @@ Note that we do not copy data array here
*/ */
void XTensor::ShallowCopy(const XTensor &tensor) void XTensor::ShallowCopy(const XTensor &tensor)
{ {
strcpy(name, tensor.name);
order = tensor.order; order = tensor.order;
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM); memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM); memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
......
...@@ -52,6 +52,7 @@ struct XLink; ...@@ -52,6 +52,7 @@ struct XLink;
#define MIN_TENSOR_MERGE_NUM 0 #define MIN_TENSOR_MERGE_NUM 0
#define MIN_TENSOR_MERGE_LIST_NUM 1024 #define MIN_TENSOR_MERGE_LIST_NUM 1024
#define MIN_TENSOR_CAT_NUM 8 #define MIN_TENSOR_CAT_NUM 8
#define MAX_TENSOR_NAME_SIZE 32
/* computation flags */ /* computation flags */
#define UNSAFE_BUT_FAST_MEM #define UNSAFE_BUT_FAST_MEM
...@@ -61,6 +62,9 @@ struct XLink; ...@@ -61,6 +62,9 @@ struct XLink;
struct XTensor struct XTensor
{ {
public: public:
/* name */
char name[MAX_TENSOR_NAME_SIZE];
/* id */ /* id */
int id; int id;
...@@ -197,6 +201,9 @@ public: ...@@ -197,6 +201,9 @@ public:
/* de-constructor */ /* de-constructor */
~XTensor(); ~XTensor();
/* set the name of the tensor */
void SetName(const char * myName);
/* initialize member variables */ /* initialize member variables */
void Init(); void Init();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论