Commit 412e53a8 by xiaotong

bug fixes for multi-threading

parent 87bb27ee
...@@ -107,13 +107,24 @@ void XLeader::SetServerModel(XConfig * config, XModel * model) ...@@ -107,13 +107,24 @@ void XLeader::SetServerModel(XConfig * config, XModel * model)
{ {
XList members; XList members;
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XModel * member = (XModel*)jworkers.GetItem(i); XModel * member = ((XWorkerJob*)jworkers[i])->GetModel();
members.Add(member); members.Add(member);
} }
SetServerModel(config, model, &members); SetServerModel(config, model, &members);
} }
/* initialize the models for running them */
void XLeader::InitForRun()
{
serverModel.InitForRun();
for (int i = 0; i < jworkers.count; i++) {
XModel * model = ((XWorkerJob*)jworkers[i])->GetModel();
model->InitForRun();
}
}
/* get loss */ /* get loss */
float XLeader::GetLoss() float XLeader::GetLoss()
{ {
...@@ -260,6 +271,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -260,6 +271,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
int activeJobCount = 0; int activeJobCount = 0;
int* active = new int[jworkers.count]; int* active = new int[jworkers.count];
InitForRun();
for (int i = 0; i < jworkers.count; i++) for (int i = 0; i < jworkers.count; i++)
active[i] = 0; active[i] = 0;
......
...@@ -109,6 +109,9 @@ public: ...@@ -109,6 +109,9 @@ public:
/* set the server model */ /* set the server model */
void SetServerModel(XConfig * config, XModel * model); void SetServerModel(XConfig * config, XModel * model);
/* initialize the models for running them */
void InitForRun();
/* get loss */ /* get loss */
float GetLoss(); float GetLoss();
......
...@@ -127,6 +127,16 @@ bool XModel::CheckParam() ...@@ -127,6 +127,16 @@ bool XModel::CheckParam()
return true; return true;
} }
/* initial model for running the it */
void XModel::InitForRun()
{
for (int i = 0; i < params.count; i++) {
XTensor * param = (XTensor*)params[i];
param->isGradFinished = false;
flags[i] = PARAM_STATE_NOT_READY;
}
}
/* refresh the model */ /* refresh the model */
void XModel::RefreshMe() void XModel::RefreshMe()
{ {
......
...@@ -95,6 +95,9 @@ public: ...@@ -95,6 +95,9 @@ public:
/* check if the parameters are well-defined for training */ /* check if the parameters are well-defined for training */
bool CheckParam(); bool CheckParam();
/* initial model for running the it */
void InitForRun();
/* refresh the model */ /* refresh the model */
void RefreshMe(); void RefreshMe();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论