Commit 8a6d5d3b by xiaotong

bug fixes and updates of XWorkerBroadcast

parent 1afdcdba
...@@ -80,10 +80,10 @@ void TestTrain() ...@@ -80,10 +80,10 @@ void TestTrain()
config.Add("lrate", 0.1F); config.Add("lrate", 0.1F);
config.Add("nstep", 100000); config.Add("nstep", 100000);
config.Add("nepoch", 5); config.Add("nepoch", 5);
config.Add("jobdev0", -1); config.Add("jobdev0", 0);
config.Add("jobdev1", -1); //config.Add("jobdev1", -1);
config.Add("jobdev2", -1); //config.Add("jobdev2", -1);
config.Add("jobdev3", -1); //config.Add("jobdev3", -1);
//config.Add("jobdev4", -1); //config.Add("jobdev4", -1);
int serverDevID = config.GetInt("jobdev0", -1); int serverDevID = config.GetInt("jobdev0", -1);
......
...@@ -415,7 +415,7 @@ void XLeader::MakeParamMap() ...@@ -415,7 +415,7 @@ void XLeader::MakeParamMap()
} }
for (int j = 0, c = 0; j < jworkers.count; j++) { for (int j = 0, c = 0; j < jworkers.count; j++) {
XWorker * worker = (XWorker*)jworkers[i]; XWorker * worker = (XWorker*)jworkers[j];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) { if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel(); XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
paramMap[i][c].tensor = model->params[i].tensor; paramMap[i][c].tensor = model->params[i].tensor;
...@@ -522,7 +522,7 @@ int XLeader::RunModel(XConfig * config, DataDistributeBase * dataDistributor, in ...@@ -522,7 +522,7 @@ int XLeader::RunModel(XConfig * config, DataDistributeBase * dataDistributor, in
} }
/* /*
update the model update the model in a standard server-worker manner
>> config - the configuration >> config - the configuration
>> optimizer - the optimizer >> optimizer - the optimizer
>> active - flag for each job worker (1 = active, 0 = not active) >> active - flag for each job worker (1 = active, 0 = not active)
...@@ -555,7 +555,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -555,7 +555,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!"); CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* jobs in queue 2 (say jobQueue): collect the (gradient) data and other stuff. /* jobs in queue 2 (say jobQueue): collect the (gradient) data.
This is a reduce process. Then we add a job to to update the model. followed This is a reduce process. Then we add a job to to update the model. followed
by a job to broadcast the lastest parameters to workers. NOTE that we by a job to broadcast the lastest parameters to workers. NOTE that we
would update a worker to the latest model parameters, even if it is not would update a worker to the latest model parameters, even if it is not
...@@ -583,6 +583,8 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -583,6 +583,8 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
} }
} }
XList * paramList = new XList[serverModel.paramNum];
CheckNTErrors(modelCount == modelNum, "Wrong model number!"); CheckNTErrors(modelCount == modelNum, "Wrong model number!");
/* This is a simple implementation of the do-and-wait process */ /* This is a simple implementation of the do-and-wait process */
...@@ -620,6 +622,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -620,6 +622,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
collecter->AddJobCollectDataP2P(jobQueue, paramWorker.grad, paramServer.grad); collecter->AddJobCollectDataP2P(jobQueue, paramWorker.grad, paramServer.grad);
collecter->AddJobEnqueueFinished(jobQueue); collecter->AddJobEnqueueFinished(jobQueue);
/* We keep the worker parameter in a list. It would be used when we broadcast
the updated paramter to the workers, that is, this is a list of worker
parameters. */
paramList[j].Add(&paramWorker);
/* reset the flag */ /* reset the flag */
paramWorker.flag = PARAM_STATE_COLLECTED; paramWorker.flag = PARAM_STATE_COLLECTED;
finished++; finished++;
...@@ -637,7 +644,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -637,7 +644,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
updater->AddJobEnqueueFinished(jobQueue); updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models */ /* broadcast the new parameter to other models */
broadcaster->AddJobBroadcast(jobQueue, &serverModel, &membersAll, j); broadcaster->AddJobBroadcast(jobQueue, &paramServer, &paramList[j]);
broadcaster->AddJobEnqueueFinished(jobQueue); broadcaster->AddJobEnqueueFinished(jobQueue);
} }
} }
...@@ -658,6 +665,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -658,6 +665,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
delete[] finishedCount; delete[] finishedCount;
delete[] modelFlag; delete[] modelFlag;
delete[] paramList;
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -170,6 +170,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -170,6 +170,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
} }
delete[] ids; delete[] ids;
XPRINT(1, stderr, "[INFO] Training Finished[DONE]");
} }
/* show settings of training */ /* show settings of training */
......
...@@ -53,22 +53,21 @@ void XWorkerBroadcast::SetBroadcastMode(DATA_BROADCAST_TYPE myMode) ...@@ -53,22 +53,21 @@ void XWorkerBroadcast::SetBroadcastMode(DATA_BROADCAST_TYPE myMode)
/* /*
broadcast data for a parameter broadcast data for a parameter
>> source - the data (as a model) that we want to broadcast >> source - the data (as a model) that we want to broadcast
>> targetList - the target places that we recieve the data >> targetList - the target places where we recieve the data
>> pid - the parameter index
*/ */
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int pid) void XWorkerBroadcast::BroadcastData(XTensorKeeper * source, XList * targetList)
{ {
CheckNTErrors(source->params[pid].flag == PARAM_STATE_UPDATED, CheckNTErrors(source->flag == PARAM_STATE_UPDATED,
"The parameter is not ready for broadcasting"); "The parameter is not ready for broadcasting");
for (int i = 0; i < targetList->count; i++) { for (int i = 0; i < targetList->count; i++) {
XModel * target = (XModel*)targetList->GetItem(i); XTensorKeeper * target = (XTensorKeeper*)targetList->GetItem(i);
/* data transmit */ /* data transmit */
BroadcastP2P(source->params[pid].tensor, target->params[pid].tensor); BroadcastP2P(source->tensor, target->tensor);
/* update the flag */ /* update the flag */
target->params[pid].flag = PARAM_STATE_UPDATED; target->flag = PARAM_STATE_UPDATED;
} }
} }
...@@ -81,20 +80,17 @@ void XWorkerBroadcast::Broadcast(XList * args) ...@@ -81,20 +80,17 @@ void XWorkerBroadcast::Broadcast(XList * args)
int paramCount = 0; int paramCount = 0;
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++); XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++);
XModel * source = (XModel*)args->GetItem(paramCount++); XTensorKeeper * source = (XTensorKeeper*)args->GetItem(paramCount++);
/* target models */ /* target models */
int targetNum = args->GetItemInt(paramCount++); int targetNum = args->GetItemInt(paramCount++);
XList target; XList target;
for (int i = 0; i < targetNum; i++) { for (int i = 0; i < targetNum; i++) {
XModel * model = (XModel*)args->GetItem(paramCount++); XTensorKeeper * model = (XTensorKeeper*)args->GetItem(paramCount++);
target.Add(model); target.Add(model);
} }
/* parameter index */ broadcaster->BroadcastData(source, &target);
int p = args->GetInt(paramCount++);
broadcaster->BroadcastData(source, &target, p);
} }
/* /*
...@@ -116,21 +112,18 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target) ...@@ -116,21 +112,18 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target)
add a new job of broadcasting data (for a parameter) add a new job of broadcasting data (for a parameter)
>> jobQueue - the queue where we push jobs >> jobQueue - the queue where we push jobs
>> source - the data that we want to broadcast >> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data >> targetList - the target places where we recieve the data
>> pid - the parameter index
*/ */
bool XWorkerBroadcast::AddJobBroadcast(XQueue * jobQueue, XModel * source, XList * targetList, int pid) bool XWorkerBroadcast::AddJobBroadcast(XQueue * jobQueue, XTensorKeeper * source, XList * targetList)
{ {
CheckNTErrors(source != NULL, "no input source tensor!"); CheckNTErrors(source != NULL, "no input source tensor!");
CheckNTErrors(targetList != NULL, "no input target tensor list!"); CheckNTErrors(targetList != NULL, "no input target tensor list!");
CheckNTErrors(pid >= 0 && pid < source->paramNum, "illegal parameter index!");
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(source); args.Add(source);
args.AddInt(targetList->count); args.AddInt(targetList->count);
args.AddList(targetList); args.AddList(targetList);
args.AddInt(pid);
XQueue& queueRun = jobQueue != NULL ? *jobQueue : queue; XQueue& queueRun = jobQueue != NULL ? *jobQueue : queue;
......
...@@ -61,7 +61,7 @@ public: ...@@ -61,7 +61,7 @@ public:
void SetBroadcastMode(DATA_BROADCAST_TYPE myMode); void SetBroadcastMode(DATA_BROADCAST_TYPE myMode);
/* broadcast data for a parameter */ /* broadcast data for a parameter */
void BroadcastData(XModel * source, XList * targetList, int pid); void BroadcastData(XTensorKeeper * source, XList * targetList);
/* wrapper of BroadcastDataSingle */ /* wrapper of BroadcastDataSingle */
static static
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
void BroadcastP2P(XTensor * source, XTensor * target); void BroadcastP2P(XTensor * source, XTensor * target);
/* add a new job of broadcasting data (for a parameter) */ /* add a new job of broadcasting data (for a parameter) */
bool AddJobBroadcast(XQueue * jobQueue, XModel * source, XList * targetList, int pid); bool AddJobBroadcast(XQueue * jobQueue, XTensorKeeper * source, XList * targetList);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论