Commit 7e102f8a by xiaotong

impove the implementation of the wating function. Now we replace rolling as a…

impove the implementation of the wating function. Now we replace rolling as a queue to check if the jobs are finished.
parent 4759cbc2
...@@ -118,11 +118,57 @@ void XLeader::SetServerModel(XConfig * config, XModel * model) ...@@ -118,11 +118,57 @@ void XLeader::SetServerModel(XConfig * config, XModel * model)
void XLeader::InitForRun() void XLeader::InitForRun()
{ {
serverModel.InitForRun(); serverModel.InitForRun();
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XModel * model = ((XWorkerJob*)jworkers[i])->GetModel(); XModel* model = ((XWorkerJob*)jworkers[i])->GetModel();
model->InitForRun(); model->InitForRun();
} }
XList workers;
workers.AddList(&jworkers);
workers.AddList(&cworkers);
workers.AddList(&uworkers);
workers.AddList(&bworkers);
for (int i = 0; i < workers.count; i++) {
XWorker* worker = (XWorker*)workers[i];
CheckNTErrors(worker->IsEmpty(), "Something is wrong with the finishedQueue!");
}
}
/*
wait for finished states (i.e., all workers finish their jobs)
>> activeJobWorkers - indicates whether each job worker is active
*/
void XLeader::WaitForFinishing(const int* activeJobWorkers)
{
int activeCount = 0;
for (int i = 0; i < jworkers.count; i++) {
if (activeJobWorkers[i] > 0) {
XWorker* worker = (XWorker*)jworkers[i];
worker->DequeueFinishedJob();
activeCount++;
}
}
if (activeCount > 0) {
for (int i = 0; i < cworkers.count; i++) {
XWorker* worker = (XWorker*)cworkers[i];
worker->DequeueFinishedJob();
}
for (int i = 0; i < uworkers.count; i++) {
XWorker* worker = (XWorker*)uworkers[i];
for (int j = 0; j < serverModel.paramNum; j++)
worker->DequeueFinishedJob();
}
for (int i = 0; i < bworkers.count; i++) {
XWorker* worker = (XWorker*)bworkers[i];
for (int j = 0; j < serverModel.paramNum; j++)
worker->DequeueFinishedJob();
}
}
} }
/* get loss */ /* get loss */
...@@ -278,8 +324,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -278,8 +324,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
InitForRun(); InitForRun();
//LockWorkers();
for (int i = 0; i < jworkers.count; i++) for (int i = 0; i < jworkers.count; i++)
active[i] = 0; active[i] = 0;
...@@ -308,6 +352,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -308,6 +352,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* job in queue 1: make a record of the run */ /* job in queue 1: make a record of the run */
worker->AddJobRecord(&serverRecord); worker->AddJobRecord(&serverRecord);
/* job in queue 1: mark finished */
worker->AddJobEnqueueFinished();
active[i] = 1; active[i] = 1;
activeJobCount++; activeJobCount++;
} }
...@@ -339,7 +386,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -339,7 +386,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
collecter->AddJobUpdateAll(&members, &membersAll, &serverModel, collecter->AddJobUpdateAll(&members, &membersAll, &serverModel,
optimizer, updater, broadcaster); optimizer, updater, broadcaster);
collecter->AddJobCollectOther(&memberRecords, &serverRecord); //collecter->AddJobCollectOther(&memberRecords, &serverRecord);
collecter->AddJobEnqueueFinished();
/* jobs in queue 2: collect the (gradient) data and other stuff. This /* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. */ is a reduce process. */
...@@ -354,9 +402,11 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -354,9 +402,11 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
not involved in this run. */ not involved in this run. */
//broadcaster->AddJobBroadcast(&serverModel, &membersAll); //broadcaster->AddJobBroadcast(&serverModel, &membersAll);
WaitForFinishing(); //WaitForFinishing();
} }
WaitForFinishing(active);
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i]; XWorkerJob * worker = (XWorkerJob*)jworkers[i];
worker->Clear(); worker->Clear();
......
...@@ -112,11 +112,8 @@ public: ...@@ -112,11 +112,8 @@ public:
/* initialize the models for running them */ /* initialize the models for running them */
void InitForRun(); void InitForRun();
/* mark the workers as LOCKED */ /* wait for finished states (i.e., all workers finish their jobs) */
void LockWorkers(); void WaitForFinishing(const int * activeJobWorkers);
/* wait for unlocked workers (i.e., all workers finish their jobs) */
void WaitForUnlockedWorkers(const int * activeJobWorkers);
/* get loss */ /* get loss */
float GetLoss(); float GetLoss();
......
...@@ -110,4 +110,55 @@ bool XWorker::IsEmpty() ...@@ -110,4 +110,55 @@ bool XWorker::IsEmpty()
return queue.IsEmpty(); return queue.IsEmpty();
} }
/* enqueue a counting job of a finished job */
void XWorker::EnqueueFinishedJob()
{
finishedQueue.Enqueue(NULL);
}
/* dequeue a counting job of a finished job */
void XWorker::DequeueFinishedJob()
{
finishedQueue.Dequeue();
}
/* wrapper of EnqueueFinished() */
void XWorker::EnqueueFinished(XList* args)
{
XWorker* worker = (XWorker*)args->GetItem(0);
worker->EnqueueFinishedJob();
}
/* wrapper of DequeueFinished() */
void XWorker::DequeueFinished(XList* args)
{
XWorker* worker = (XWorker*)args->GetItem(0);
worker->DequeueFinishedJob();
}
/* add a job of enqueuing a counting a finished job */
void XWorker::AddJobEnqueueFinished()
{
XList args;
args.Add(this);
if (isInstantRun)
XWorker::EnqueueFinished(&args);
else
queue.EnqueueJob((void*)(char*)XWorker::EnqueueFinished, &args);
}
/* add a job of dequeuing a counting a finished job */
void XWorker::AddJobDequeueFinished()
{
XList args;
args.Add(this);
if (isInstantRun)
XWorker::DequeueFinished(&args);
else
queue.EnqueueJob((void*)(char*)XWorker::DequeueFinished, &args);
}
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -55,7 +55,7 @@ protected: ...@@ -55,7 +55,7 @@ protected:
/* id of the worker */ /* id of the worker */
int id; int id;
/* the queue */ /* the queue of jobs */
XQueue queue; XQueue queue;
/* state of the worker */ /* state of the worker */
...@@ -63,6 +63,9 @@ protected: ...@@ -63,6 +63,9 @@ protected:
/* fire the flag of instant run */ /* fire the flag of instant run */
bool isInstantRun; bool isInstantRun;
/* the queue of counting finished jobs */
XQueue finishedQueue;
public: public:
/* constructor */ /* constructor */
...@@ -100,6 +103,26 @@ public: ...@@ -100,6 +103,26 @@ public:
/* whether the job queue is empty? */ /* whether the job queue is empty? */
bool IsEmpty(); bool IsEmpty();
/* enqueue a counting job of a finished job */
void EnqueueFinishedJob();
/* dequeue a counting job of a finished job */
void DequeueFinishedJob();
/* wrapper of EnqueueFinished() */
static
void EnqueueFinished(XList* args);
/* wrapper of DequeueFinished() */
static
void DequeueFinished(XList* args);
/* add a job of enqueuing a counting a finished job */
void AddJobEnqueueFinished();
/* add a job of dequeuing a counting a finished job */
void AddJobDequeueFinished();
}; };
} }
......
...@@ -223,49 +223,4 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList) ...@@ -223,49 +223,4 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
return true; return true;
} }
/*
mark the state of the parameter to FINISHED
>> source - the model that we are updating
>> pid - the parameter index
*/
void XWorkerBroadcast::FinishUpdateSingle(XModel * source, int pid)
{
source->params[pid].trainFlag = PARAM_STATE_UPDATED;
MUTEX_UNLOCK(source->params[pid].trainLock);
}
/* wrapper of FinishUpdateSingle */
void XWorkerBroadcast::FinishSingle(XList * args)
{
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0);
XModel * source = (XModel*)args->GetItem(1);
int pid = args->GetInt(2);
broadcaster->FinishUpdateSingle(source, pid);
}
/*
add a new job of finishing the update
>> source - the model that we are updating
>> pid - the parameter index
*/
bool XWorkerBroadcast::AddJobFinish(XModel * source, int pid)
{
CheckNTErrors(source != NULL, "no input source tensor!");
CheckNTErrors(pid >= 0 && pid < source->paramNum, "illegal parameter index!");
XList args;
args.Add(this);
args.Add(source);
args.AddInt(pid);
if (isInstantRun)
XWorkerBroadcast::FinishSingle(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerBroadcast::FinishSingle, &args);
return true;
}
} }
...@@ -82,16 +82,6 @@ public: ...@@ -82,16 +82,6 @@ public:
/* add a new job of broadcasting data (for a model) */ /* add a new job of broadcasting data (for a model) */
bool AddJobBroadcast(XModel * source, XList * targetList); bool AddJobBroadcast(XModel * source, XList * targetList);
/* mark the state of the parameter to FINISHED */
void FinishUpdateSingle(XModel * source, int pid);
/* wrapper of FinishUpdateSingle */
static
void FinishSingle(XList * args);
/* add a new job of finishing the update */
bool AddJobFinish(XModel * source, int pid);
}; };
} }
......
...@@ -126,6 +126,8 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -126,6 +126,8 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
paramServer.flag = PARAM_STATE_COLLECTED; paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) { if (updater != NULL) {
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster); updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster);
updater->AddJobEnqueueFinished();
} }
} }
else if (finishedCount[j] > memberActive->count) { else if (finishedCount[j] > memberActive->count) {
...@@ -135,63 +137,6 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -135,63 +137,6 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
} }
} }
} }
else if (collectMode == DATA_COLLECT_REDUCESUM) {
for (int j = 0; j < server->paramNum; j++) {
bool ready = true;
XParamKeeper &paramServer = server->params[j];
/* tp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < memberActive->count; i++) {
XModel * source = (XModel*)memberActive->GetItem(i);
XParamKeeper &paramSource = source->params[j];
/* sp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_COLLECTED ||
paramSource.flag == PARAM_STATE_UPDATED ||
!paramSource.param->isGradFinished)
{
ready = false;
break;
}
else if (paramSource.flag == PARAM_STATE_NOT_READY) {
paramSource.flag = PARAM_STATE_READY;
}
}
if (ready) {
XList tensorList(memberActive->count);
for (int i = 0; i < memberActive->count; i++) {
XModel * source = (XModel*)memberActive->GetItem(i);
tensorList.Add(source->params[j].param->grad);
}
/* data transmit */
CollectReduceSum(&tensorList, server->params[j].param->grad);
/* reset the flags */
for (int i = 0; i < memberActive->count; i++) {
XModel * source = (XModel*)memberActive->GetItem(i);
source->params[j].flag = PARAM_STATE_COLLECTED;
}
server->params[j].flag = PARAM_STATE_COLLECTED;
finished += memberActive->count;
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster);
}
}
}
else { else {
ShowNTErrors("Unsupported data collection mode!"); ShowNTErrors("Unsupported data collection mode!");
} }
......
...@@ -79,6 +79,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid, ...@@ -79,6 +79,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
/* broadcast the new parameter to other models (in anotehr worker/thread) */ /* broadcast the new parameter to other models (in anotehr worker/thread) */
broadcaster->AddJobBroadcastSingle(server, members, pid); broadcaster->AddJobBroadcastSingle(server, members, pid);
broadcaster->AddJobEnqueueFinished();
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论