Commit b3c2cf56 by xiaotong

Adding "name" into XTensor. Now we can search for a node in the network by its "name".

parent 9ac1028a
......@@ -458,4 +458,15 @@ void XNet::ShowNetwork(FILE * file, XTensor * node)
}
}
/*
search for a node in a top-down manner by its name
>> top - the top most node
<< return - the node we found
*/
XTensor * XNet::SearchNode(XTensor * top, const char * name)
{
return XLink::SearchNode(top, name);
}
}
\ No newline at end of file
......@@ -111,6 +111,10 @@ struct XNet
/* show network topology */
void ShowNetwork(FILE * file, XTensor * node);
/* search a node in a top-down manner by its name */
static
XTensor * SearchNode(XTensor * top, const char * name);
};
/* we make a unique id for every tensor */
......
......@@ -528,6 +528,8 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
CheckNTErrors(hit, "No proper node found in parent.income edge!");
}
}
strcpy(newOne->name, oldOne->name);
}
/*
......@@ -655,6 +657,29 @@ void XLink::ShowNode(FILE * file, XTensor * node)
fprintf(stderr, "\n");
}
/*
search for a node in a top-down manner by its name
>> top - the top most node
<< return - the node we found
*/
XTensor * XLink::SearchNode(XTensor * top, const char * name)
{
if(!strcmp(top->name, name))
return top;
XLink &incoming = top->income;
for(int i = 0; i < incoming.tailNum; i++){
XTensor * child = incoming.tails[i];
XTensor * hit = SearchNode(child, name);
if(hit != NULL)
return hit;
}
return NULL;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -185,6 +185,10 @@ struct XLink
/* show a node */
static
void ShowNode(FILE * file, XTensor * node);
/* search a node in a top-down manner by its name */
static
XTensor * SearchNode(XTensor * top, const char * name);
};
} // namespace nts(NiuTrans.Tensor)
......
......@@ -251,9 +251,16 @@ XTensor::~XTensor()
delete grad;
}
/* set the name of the tensor */
void XTensor::SetName(const char * myName)
{
strcpy(name, myName);
}
/* initialize member variables */
void XTensor::Init()
{
name[0] = '\0';
id = -1;
mem = NULL;
signature = 0;
......@@ -306,6 +313,7 @@ Note that we do not copy data array here
*/
void XTensor::ShallowCopy(const XTensor &tensor)
{
strcpy(name, tensor.name);
order = tensor.order;
memcpy(dimSize, tensor.dimSize, sizeof(int) * MAX_TENSOR_DIM_NUM);
memcpy(dimSizeRDI, tensor.dimSizeRDI, sizeof(int) * MAX_TENSOR_DIM_NUM);
......
......@@ -52,6 +52,7 @@ struct XLink;
#define MIN_TENSOR_MERGE_NUM 0
#define MIN_TENSOR_MERGE_LIST_NUM 1024
#define MIN_TENSOR_CAT_NUM 8
#define MAX_TENSOR_NAME_SIZE 32
/* computation flags */
#define UNSAFE_BUT_FAST_MEM
......@@ -61,6 +62,9 @@ struct XLink;
struct XTensor
{
public:
/* name */
char name[MAX_TENSOR_NAME_SIZE];
/* id */
int id;
......@@ -197,6 +201,9 @@ public:
/* de-constructor */
~XTensor();
/* set the name of the tensor */
void SetName(const char * myName);
/* initialize member variables */
void Init();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论