Commit 68e146a6 by xiaotong

fix the bug of XLink::params

parent 4b7f7a18
...@@ -145,6 +145,8 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss) ...@@ -145,6 +145,8 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
XTensor * node = (XTensor*)nodes.Get(i); XTensor * node = (XTensor*)nodes.Get(i);
if(node->visitMark == NODE_FINISHED) if(node->visitMark == NODE_FINISHED)
continue; continue;
BackwardNode(node);
} }
} }
......
...@@ -114,6 +114,8 @@ void XLink::ClearOutgoing(XTensor * node) ...@@ -114,6 +114,8 @@ void XLink::ClearOutgoing(XTensor * node)
outgo.ClearTail(); outgo.ClearTail();
outgo.typeID = 0; outgo.typeID = 0;
outgo.type[0] = 0; outgo.type[0] = 0;
delete[] outgo.params;
outgo.params = NULL;
} }
/* /*
...@@ -152,6 +154,8 @@ void XLink::ClearIncoming(XTensor * node) ...@@ -152,6 +154,8 @@ void XLink::ClearIncoming(XTensor * node)
income.ClearTail(); income.ClearTail();
income.typeID = 0; income.typeID = 0;
income.type[0] = 0; income.type[0] = 0;
delete[] income.params;
income.params = NULL;
} }
/* /*
...@@ -210,7 +214,7 @@ add a parameter ...@@ -210,7 +214,7 @@ add a parameter
void XLink::AddParam(DTYPE param) void XLink::AddParam(DTYPE param)
{ {
void * ps = params; void * ps = params;
params = new char[paramNum + 1]; params = new char[(paramNum + 1) * paramSize];
memcpy(params, ps, paramNum * paramSize); memcpy(params, ps, paramNum * paramSize);
DTYPE * p = (DTYPE*)((char*)params + paramNum * paramSize); DTYPE * p = (DTYPE*)((char*)params + paramNum * paramSize);
*p = param; *p = param;
...@@ -226,7 +230,7 @@ add a parameter ...@@ -226,7 +230,7 @@ add a parameter
void XLink::AddParam(void * param, int size) void XLink::AddParam(void * param, int size)
{ {
void * ps = params; void * ps = params;
params = new char[paramNum + 1]; params = new char[(paramNum + 1) * paramSize];
memcpy(params, ps, paramNum * paramSize); memcpy(params, ps, paramNum * paramSize);
char * p = (char*)params + paramNum * paramSize; char * p = (char*)params + paramNum * paramSize;
memcpy(p, param, size); memcpy(p, param, size);
...@@ -240,6 +244,7 @@ get a paramter in default type ...@@ -240,6 +244,7 @@ get a paramter in default type
*/ */
DTYPE XLink::GetParam(int i) DTYPE XLink::GetParam(int i)
{ {
CheckNTErrors(params != NULL, "parameter array cannot be empty!");
char * p = (char*)params + i * paramSize; char * p = (char*)params + i * paramSize;
return *(DTYPE*)p; return *(DTYPE*)p;
} }
...@@ -250,6 +255,7 @@ get a paramter in integer ...@@ -250,6 +255,7 @@ get a paramter in integer
*/ */
int XLink::GetParamInt(int i) int XLink::GetParamInt(int i)
{ {
CheckNTErrors(params != NULL, "parameter array cannot be empty!");
char * p = (char*)params + i * paramSize; char * p = (char*)params + i * paramSize;
return *(int*)p; return *(int*)p;
} }
...@@ -314,8 +320,7 @@ add parameters ...@@ -314,8 +320,7 @@ add parameters
*/ */
void XLink::AddParamToHead(XTensor * h, DTYPE param) void XLink::AddParamToHead(XTensor * h, DTYPE param)
{ {
if(h != NULL) CheckNTErrors(h != NULL, "head tensor cannot be empty!");
return;
h->income.AddParam(param); h->income.AddParam(param);
} }
...@@ -326,8 +331,7 @@ add an integer parameter ...@@ -326,8 +331,7 @@ add an integer parameter
*/ */
void XLink::AddParamToHeadInt(XTensor * h, int param) void XLink::AddParamToHeadInt(XTensor * h, int param)
{ {
if(h != NULL) CheckNTErrors(h != NULL, "head tensor cannot be empty!");
return;
h->income.AddParam(&param, sizeof(int)); h->income.AddParam(&param, sizeof(int));
} }
...@@ -338,8 +342,7 @@ add a MATRIX_TRANS_TYPE parameter ...@@ -338,8 +342,7 @@ add a MATRIX_TRANS_TYPE parameter
*/ */
void XLink::AddParamToHeadTrans(XTensor * h, MATRIX_TRANS_TYPE param) void XLink::AddParamToHeadTrans(XTensor * h, MATRIX_TRANS_TYPE param)
{ {
if(h != NULL) CheckNTErrors(h != NULL, "head tensor cannot be empty!");
return;
h->income.AddParam(&param, sizeof(MATRIX_TRANS_TYPE)); h->income.AddParam(&param, sizeof(MATRIX_TRANS_TYPE));
} }
...@@ -396,6 +399,11 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne) ...@@ -396,6 +399,11 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
newIncome.tailNum = oldOne->income.tailNum; newIncome.tailNum = oldOne->income.tailNum;
memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum); memcpy(newIncome.tails, oldOne->income.tails, sizeof(XTensor*) * newIncome.tailNum);
int paraArraySize = oldOne->income.paramNum * oldOne->income.paramSize;
newIncome.params = new char[paraArraySize];
memcpy(newIncome.params, oldOne->income.params, paraArraySize);
newIncome.paramNum = oldOne->income.paramNum;
/* update the link to each child node */ /* update the link to each child node */
for(int i = 0; i < newIncome.tailNum; i++){ for(int i = 0; i < newIncome.tailNum; i++){
XTensor * child = newIncome.tails[i]; XTensor * child = newIncome.tails[i];
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论