Commit 959545df by xiaotong

updates of worker pipeline

parent 7c1aeb5c
...@@ -394,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -394,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
isDataOK = false; isDataOK = false;
else { else {
/* job in queue 1: refresh the model */ /* job in queue 1: refresh the model */
worker->AddJobRefresh(worker->GetJobQueue(), jmodel); worker->AddJobRefresh(jmodel);
/* job in queue 1: run the model */ /* job in queue 1: run the model */
worker->AddJobNeuralNet(worker->GetJobQueue(), jmodel, worker->AddJobNeuralNet(jmodel,
worker->GetInput(), worker->GetOutput(), worker->GetInput(), worker->GetOutput(),
worker->GetGold(), worker->GetLoss()); worker->GetGold(), worker->GetLoss());
/* job in queue 1: make a record of the run */ /* job in queue 1: make a record of the run */
worker->AddJobRecord(worker->GetJobQueue(), &serverRecord); worker->AddJobRecord(&serverRecord);
/* job in queue 1: mark finished */ /* job in queue 1: mark finished */
worker->AddJobEnqueueFinished(); worker->AddJobEnqueueFinished();
...@@ -424,6 +424,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -424,6 +424,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* all member models */ /* all member models */
XList membersAll(jworkers.count); XList membersAll(jworkers.count);
/* job queues */
XList jobQueues;
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)jworkers[i]; XWorkerJob* worker = (XWorkerJob*)jworkers[i];
membersAll.Add(worker->GetModel()); membersAll.Add(worker->GetModel());
...@@ -431,13 +434,18 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -431,13 +434,18 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
members.Add(worker->GetModel()); members.Add(worker->GetModel());
} }
for (int i = 0; i < pworkers.count; i++) {
XWorker * worker = (XWorker*)pworkers[i];
jobQueues.Add(worker->GetJobQueue());
}
/* jobs in queue 2: collect the (gradient) data and other stuff. This /* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. The collector will add a job in queue 3 is a reduce process. The collector will add a job in queue 3
to update the model. The updater will add a job job in queue 4 to to update the model. The updater will add a job job in queue 4 to
broadcast the lastest parameters to workers. NOTE that we would update broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved a worker to the laster model parameters, even if it is not involved
in this run. */ in this run. */
collecter->AddJobUpdateAll(collecter->GetJobQueue(), collecter->AddJobUpdateAll(&jobQueues,
&members, &membersAll, &serverModel, &members, &membersAll, &serverModel,
optimizer, updater, broadcaster); optimizer, updater, broadcaster);
collecter->AddJobEnqueueFinished(); collecter->AddJobEnqueueFinished();
...@@ -455,57 +463,4 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -455,57 +463,4 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
return isDataOK; return isDataOK;
} }
/* wait until all workers finish their job */
void XLeader::WaitForFinishing(int sleepTime)
{
while (1) {
bool finished = true;
if (finished) {
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)jworkers[i];
if (worker->GetJobNum() > 0) {
finished = false;
break;
}
}
}
if (finished) {
for (int i = 0; i < cworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)cworkers[i];
if (worker->GetJobNum() > 0) {
finished = false;
break;
}
}
}
if (finished) {
for (int i = 0; i < uworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)uworkers[i];
if (worker->GetJobNum() > 0) {
finished = false;
break;
}
}
}
if (finished) {
for (int i = 0; i < bworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)bworkers[i];
if (worker->GetJobNum() > 0) {
finished = false;
break;
}
}
}
if (finished)
break;
XSleep(sleepTime);
}
}
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -161,9 +161,6 @@ public: ...@@ -161,9 +161,6 @@ public:
/* run the model (for one time) */ /* run the model (for one time) */
bool Run(XConfig * config, DataDistributeBase * dataDistributor, bool Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer); XModel * model, XOptimizer * optimizer);
/* wait until all workers finish their job */
void WaitForFinishing(int sleepTime = SLEEP_TIME_IN_WAITING_FOR_JOBS);
}; };
} }
......
...@@ -51,11 +51,14 @@ void XWorkerBroadcast::SetBroadcastMode(DATA_BROADCAST_TYPE myMode) ...@@ -51,11 +51,14 @@ void XWorkerBroadcast::SetBroadcastMode(DATA_BROADCAST_TYPE myMode)
/* /*
broadcast data for a parameter broadcast data for a parameter
>> jobQueue - the queue where we push jobs
>> source - the data (as a model) that we want to broadcast >> source - the data (as a model) that we want to broadcast
>> targetList - the target places that we recieve the data >> targetList - the target places that we recieve the data
>> pid - the parameter index >> pid - the parameter index
*/ */
void XWorkerBroadcast::BroadcastDataSingle(XModel * source, XList * targetList, int pid) void XWorkerBroadcast::BroadcastDataSingle(XQueue * jobQueue,
XModel * source, XList * targetList,
int pid)
{ {
CheckNTErrors(source->params[pid].flag == PARAM_STATE_UPDATED, CheckNTErrors(source->params[pid].flag == PARAM_STATE_UPDATED,
"The parameter is not ready for broadcasting"); "The parameter is not ready for broadcasting");
...@@ -72,89 +75,29 @@ void XWorkerBroadcast::BroadcastDataSingle(XModel * source, XList * targetList, ...@@ -72,89 +75,29 @@ void XWorkerBroadcast::BroadcastDataSingle(XModel * source, XList * targetList,
} }
/* /*
broadcast data for a model
>> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data
>> sleepTime - the waiting time in broadcasting
*/
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int sleepTime)
{
int finished = 0;
int * finishedFlag = new int[source->paramNum];
memset(finishedFlag, 0, sizeof(int) * source->paramNum);
/* check */
for (int i = 0; i < targetList->count; i++) {
XModel * target = (XModel*)targetList->GetItem(i);
CheckNTErrors(source->paramNum == target->paramNum, "Incompatiable models!");
}
/* the major body of broadcasting */
while (1) {
for (int i = 0; i < source->paramNum; i++) {
if (source->params[i].flag == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {
/* broadcasting */
BroadcastDataSingle(source, targetList, i);
/* counting */
finished += targetList->count;
finishedFlag[i] = 1;
}
}
if (finished == source->paramNum * targetList->count)
break;
XSleep(sleepTime);
}
delete[] finishedFlag;
}
/*
wrapper of BroadcastDataSingle wrapper of BroadcastDataSingle
>> args - the list of arguments >> args - the list of arguments
*/ */
void XWorkerBroadcast::BroadcastSingle(XList * args) void XWorkerBroadcast::BroadcastSingle(XList * args)
{ {
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0); int paramCount = 0;
XModel * source = (XModel*)args->GetItem(1);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++);
XQueue * jobQueue = (XQueue*)args->GetItem(paramCount++);
XModel * source = (XModel*)args->GetItem(paramCount++);
/* target models */ /* target models */
int targetNum = args->GetItemInt(2); int targetNum = args->GetItemInt(paramCount++);
XList target; XList target;
for (int i = 0; i < targetNum; i++) { for (int i = 0; i < targetNum; i++) {
XModel * model = (XModel*)args->GetItem(3 + i); XModel * model = (XModel*)args->GetItem(paramCount++);
target.Add(model); target.Add(model);
} }
/* parameter index */ /* parameter index */
int p = args->GetInt(3 + targetNum); int p = args->GetInt(paramCount++);
broadcaster->BroadcastDataSingle(source, &target, p);
}
/*
wrapper of BroadcastData
>> args - the list of arguments
*/
void XWorkerBroadcast::Broadcast(XList * args)
{
//fprintf(stderr, "broadcast 0\n");
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0);
XModel * source = (XModel*)args->GetItem(1);
/* target models */ broadcaster->BroadcastDataSingle(jobQueue, source, &target, p);
int targetNum = args->GetItemInt(2);
XList target;
for (int i = 0; i < targetNum; i++) {
XModel * model = (XModel*)args->GetItem(3 + i);
target.Add(model);
}
broadcaster->BroadcastData(source, &target, SLEEP_TIME_IN_BROADCASTING);
//fprintf(stderr, "broadcast 1\n");
} }
/* /*
...@@ -174,11 +117,12 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target) ...@@ -174,11 +117,12 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target)
/* /*
add a new job of broadcasting data (for a parameter) add a new job of broadcasting data (for a parameter)
>> jobQueue - the queue that we push jobs here
>> source - the data that we want to broadcast >> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data >> targetList - the target places that we recieve the data
>> pid - the parameter index >> pid - the parameter index
*/ */
bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList, int pid) bool XWorkerBroadcast::AddJobBroadcastSingle(XQueue * jobQueue, XModel * source, XList * targetList, int pid)
{ {
CheckNTErrors(source != NULL, "no input source tensor!"); CheckNTErrors(source != NULL, "no input source tensor!");
CheckNTErrors(targetList != NULL, "no input target tensor list!"); CheckNTErrors(targetList != NULL, "no input target tensor list!");
...@@ -186,6 +130,7 @@ bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList ...@@ -186,6 +130,7 @@ bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(jobQueue);
args.Add(source); args.Add(source);
args.AddInt(targetList->count); args.AddInt(targetList->count);
args.AddList(targetList); args.AddList(targetList);
...@@ -199,28 +144,4 @@ bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList ...@@ -199,28 +144,4 @@ bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList
return true; return true;
} }
/*
add a new job of broadcasting data (for a model)
>> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data
*/
bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
{
CheckNTErrors(source != NULL, "no input source tensor!");
CheckNTErrors(targetList != NULL, "no input target tensor list!");
XList args;
args.Add(this);
args.Add(source);
args.AddInt(targetList->count);
args.AddList(targetList);
if (isInstantRun)
XWorkerBroadcast::Broadcast(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
return true;
}
} }
...@@ -61,27 +61,17 @@ public: ...@@ -61,27 +61,17 @@ public:
void SetBroadcastMode(DATA_BROADCAST_TYPE myMode); void SetBroadcastMode(DATA_BROADCAST_TYPE myMode);
/* broadcast data for a parameter */ /* broadcast data for a parameter */
void BroadcastDataSingle(XModel * source, XList * targetList, int pid); void BroadcastDataSingle(XQueue * jobQueue, XModel * source, XList * targetList, int pid);
/* broadcast data for a model */
void BroadcastData(XModel * source, XList * targetList, int sleepTime);
/* wrapper of BroadcastDataSingle */ /* wrapper of BroadcastDataSingle */
static static
void BroadcastSingle(XList * args); void BroadcastSingle(XList * args);
/* wrapper of BroadcastData */
static
void Broadcast(XList * args);
/* P2P data broadcasting */ /* P2P data broadcasting */
void BroadcastP2P(XTensor * source, XTensor * target); void BroadcastP2P(XTensor * source, XTensor * target);
/* add a new job of broadcasting data (for a parameter) */ /* add a new job of broadcasting data (for a parameter) */
bool AddJobBroadcastSingle(XModel * source, XList * targetList, int pid); bool AddJobBroadcastSingle(XQueue * jobQueue, XModel * source, XList * targetList, int pid);
/* add a new job of broadcasting data (for a model) */
bool AddJobBroadcast(XModel * source, XList * targetList);
}; };
} }
......
...@@ -54,6 +54,7 @@ new parameters to all models. NOTE that this method just collect graident ...@@ -54,6 +54,7 @@ new parameters to all models. NOTE that this method just collect graident
from member models. Then it calls an XWorkerUpdate to update the parameters. from member models. Then it calls an XWorkerUpdate to update the parameters.
The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter
to member models back. to member models back.
>> jobQueues - queues that we process the jobs
>> memberActive - member models that are active, i.e., have generated gradients >> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models >> memberAll - all member models
>> server - the server model >> server - the server model
...@@ -63,7 +64,8 @@ to member models back. ...@@ -63,7 +64,8 @@ to member models back.
models models
>> sleepTime - waiting time in collecting >> sleepTime - waiting time in collecting
*/ */
void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server, void XWorkerCollect::UpdateDataAll(XList * jobQueues, XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XOptimizer * optimizer, XWorkerUpdate * updater,
XWorkerBroadcast * broadcaster, int sleepTime) XWorkerBroadcast * broadcaster, int sleepTime)
{ {
...@@ -82,6 +84,8 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -82,6 +84,8 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
XModel * source = (XModel*)memberActive->GetItem(i); XModel * source = (XModel*)memberActive->GetItem(i);
CheckNTErrors(source->paramNum == server->paramNum, "Incompatiable models!"); CheckNTErrors(source->paramNum == server->paramNum, "Incompatiable models!");
} }
CheckNTErrors(jobQueues->count == server->paramNum, "Incompatiable model!");
/* counts how many member models are collect for each parameters */ /* counts how many member models are collect for each parameters */
int * finishedCount = new int[server->paramNum]; int * finishedCount = new int[server->paramNum];
...@@ -125,7 +129,7 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -125,7 +129,7 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
if (finishedCount[j] == memberActive->count) { if (finishedCount[j] == memberActive->count) {
paramServer.flag = PARAM_STATE_COLLECTED; paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) { if (updater != NULL) {
updater->AddJobUpdate(updater->GetJobQueue(), updater->AddJobUpdate((XQueue*)jobQueues->GetItem(j),
server, memberAll, j, server, memberAll, j,
optimizer, broadcaster); optimizer, broadcaster);
updater->AddJobEnqueueFinished(); updater->AddJobEnqueueFinished();
...@@ -155,29 +159,39 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -155,29 +159,39 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
/* wrapper of UpdateDataAll */ /* wrapper of UpdateDataAll */
void XWorkerCollect::UpdateAll(XList * args) void XWorkerCollect::UpdateAll(XList * args)
{ {
XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0); int paramCount = 0;
int activeNum = args->GetInt(1);
XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(paramCount++);
int queueNum = args->GetInt(paramCount++);
XList jobQueues;
for (int i = 0; i < queueNum; i++) {
XQueue * queue = (XQueue*)args->GetItem(paramCount++);
jobQueues.Add(queue);
}
int activeNum = args->GetInt(paramCount++);
XList memberActive; XList memberActive;
for (int i = 0; i < activeNum; i++) { for (int i = 0; i < activeNum; i++) {
XModel * member = (XModel*)args->GetItem(2 + i); XModel * member = (XModel*)args->GetItem(paramCount++);
memberActive.Add(member); memberActive.Add(member);
} }
int allNum = args->GetInt(2 + activeNum); int allNum = args->GetInt(paramCount++);
XList memberAll; XList memberAll;
for (int i = 0; i < allNum; i++) { for (int i = 0; i < allNum; i++) {
XModel * member = (XModel*)args->GetItem(2 + activeNum + 1 + i); XModel * member = (XModel*)args->GetItem(paramCount++);
memberAll.Add(member); memberAll.Add(member);
} }
XModel * server = (XModel*)args->GetItem(2 + activeNum + 1 + allNum); XModel * server = (XModel*)args->GetItem(paramCount++);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(2 + activeNum + 1 + allNum + 1); XOptimizer * optimizer = (XOptimizer*)args->GetItem(paramCount++);
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(2 + activeNum + 1 + allNum + 2); XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(paramCount++);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(2 + activeNum + 1 + allNum + 3); XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++);
collecter->UpdateDataAll(&memberActive, &memberAll, server, collecter->UpdateDataAll(&jobQueues, &memberActive, &memberAll, server,
optimizer, updater, broadcaster, optimizer, updater, broadcaster,
SLEEP_TIME_IN_COLLECTING); SLEEP_TIME_IN_COLLECTING);
} }
...@@ -243,7 +257,7 @@ void XWorkerCollect::CollectAllReduce(XList * all) ...@@ -243,7 +257,7 @@ void XWorkerCollect::CollectAllReduce(XList * all)
/* /*
add a new job of collecting data, update the parameter and add a new job of collecting data, update the parameter and
broadcast the new parameter broadcast the new parameter
>> myQueue - the queue where we push the job >> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients >> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models >> memberAll - all member models
>> server - the server model >> server - the server model
...@@ -253,8 +267,11 @@ broadcast the new parameter ...@@ -253,8 +267,11 @@ broadcast the new parameter
models models
<< return - successful or not << return - successful or not
*/ */
bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server, bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster) XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer,
XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(memberActive != NULL, "No input (active) member list!"); CheckNTErrors(memberActive != NULL, "No input (active) member list!");
CheckNTErrors(memberAll != NULL, "No input (all) member list!"); CheckNTErrors(memberAll != NULL, "No input (all) member list!");
...@@ -265,6 +282,8 @@ bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XLi ...@@ -265,6 +282,8 @@ bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XLi
XList args; XList args;
args.Add(this); args.Add(this);
args.AddInt(jobQueues->count);
args.AddList(jobQueues);
args.AddInt(memberActive->count); args.AddInt(memberActive->count);
args.AddList(memberActive); args.AddList(memberActive);
args.AddInt(memberAll->count); args.AddInt(memberAll->count);
...@@ -274,12 +293,10 @@ bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XLi ...@@ -274,12 +293,10 @@ bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XLi
args.Add(updater); args.Add(updater);
args.Add(broadcaster); args.Add(broadcaster);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if (isInstantRun) if (isInstantRun)
XWorkerCollect::UpdateAll(&args); XWorkerCollect::UpdateAll(&args);
else else
q.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args); queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true; return true;
} }
......
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
from member models. Then it calls an XWorkerUpdate to update the parameters. from member models. Then it calls an XWorkerUpdate to update the parameters.
The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter
to member models back. */ to member models back. */
void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server, void UpdateDataAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster, XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster,
int sleepTime); int sleepTime);
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
void CollectAllReduce(XList * all); void CollectAllReduce(XList * all);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */ /* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server, bool AddJobUpdateAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
}; };
......
...@@ -175,11 +175,10 @@ int XWorkerJob::GetPredictNum() ...@@ -175,11 +175,10 @@ int XWorkerJob::GetPredictNum()
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel) bool XWorkerJob::AddJobRefresh(XModel * myModel)
{ {
//fprintf(stderr, "refresh 0\n"); //fprintf(stderr, "refresh 0\n");
...@@ -188,12 +187,10 @@ bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel) ...@@ -188,12 +187,10 @@ bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel)
XList args(1); XList args(1);
args.Add(myModel); args.Add(myModel);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Refresh(&args); XModel::Refresh(&args);
else else
q.EnqueueJob((void*)(char*)XModel::Refresh, &args); queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);
//fprintf(stderr, "refresh 1\n"); //fprintf(stderr, "refresh 1\n");
...@@ -202,7 +199,6 @@ bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel) ...@@ -202,7 +199,6 @@ bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel)
/* /*
add a new job of neural network forward and backward computation (with the input) add a new job of neural network forward and backward computation (with the input)
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
...@@ -210,8 +206,9 @@ add a new job of neural network forward and backward computation (with the input ...@@ -210,8 +206,9 @@ add a new job of neural network forward and backward computation (with the input
>> losses - losses of the outputs respect to the gold standards >> losses - losses of the outputs respect to the gold standards
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(XQueue * myQueue, XModel * myModel, bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
XList * inputs, XList * outputs, XList * golds, XList * losses) XList * inputs, XList * outputs,
XList * golds, XList * losses)
{ {
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
CheckNTErrors(inputs != NULL, "no inputs of the model!"); CheckNTErrors(inputs != NULL, "no inputs of the model!");
...@@ -224,12 +221,10 @@ bool XWorkerJob::AddJobNeuralNet(XQueue * myQueue, XModel * myModel, ...@@ -224,12 +221,10 @@ bool XWorkerJob::AddJobNeuralNet(XQueue * myQueue, XModel * myModel,
args.Add(golds); args.Add(golds);
args.Add(losses); args.Add(losses);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Run(&args); XModel::Run(&args);
else else
q.EnqueueJob((void*)(char*)XModel::Run, &args); queue.EnqueueJob((void*)(char*)XModel::Run, &args);
SetState(XWORKER_STARTED); SetState(XWORKER_STARTED);
...@@ -260,21 +255,18 @@ void XWorkerJob::RecordMeStatic(XList* args) ...@@ -260,21 +255,18 @@ void XWorkerJob::RecordMeStatic(XList* args)
/* /*
add a new job of recording the running of the nerual network add a new job of recording the running of the nerual network
>> myQueue - the queue where we push the job
>> serverRecord - the model record on the server side >> serverRecord - the model record on the server side
*/ */
bool XWorkerJob::AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord) bool XWorkerJob::AddJobRecord(XNNRecord * serverRecord)
{ {
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(serverRecord); args.Add(serverRecord);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if (isInstantRun) if (isInstantRun)
XWorkerJob::RecordMeStatic(&args); XWorkerJob::RecordMeStatic(&args);
else else
q.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args); queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
return true; return true;
} }
......
...@@ -107,13 +107,13 @@ public: ...@@ -107,13 +107,13 @@ public:
int GetPredictNum(); int GetPredictNum();
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XQueue * myQueue, XModel * myModel); bool AddJobRefresh(XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(XQueue * myQueue, XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses); bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses);
/* add a new job of recording the running of the nerual network */ /* add a new job of recording the running of the nerual network */
bool AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord); bool AddJobRecord(XNNRecord * serverRecord);
private: private:
/* wrapper of RecordMe */ /* wrapper of RecordMe */
......
...@@ -54,13 +54,14 @@ XOptimizer * XWorkerUpdate::GetOptimizer() ...@@ -54,13 +54,14 @@ XOptimizer * XWorkerUpdate::GetOptimizer()
/* /*
update a parameter of a model update a parameter of a model
>> jobQueue - the queue that place the jobs called here
>> model - the model that we want to update (on the server side) >> model - the model that we want to update (on the server side)
>> members - models that would share the updated parameters >> members - models that would share the updated parameters
>> pid - the parameter index >> pid - the parameter index
>> optimizer - the optimizer >> optimizer - the optimizer
>> broadcaster - the worker that would broadcast the new parameter to members >> broadcaster - the worker that would broadcast the new parameter to members
*/ */
void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid, void XWorkerUpdate::UpdateParameter(XQueue * jobQueue, XModel * server, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerBroadcast * broadcaster)
{ {
...@@ -78,7 +79,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid, ...@@ -78,7 +79,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
server->params[pid].flag = PARAM_STATE_UPDATED; server->params[pid].flag = PARAM_STATE_UPDATED;
/* broadcast the new parameter to other models (in anotehr worker/thread) */ /* broadcast the new parameter to other models (in anotehr worker/thread) */
broadcaster->AddJobBroadcastSingle(server, members, pid); broadcaster->AddJobBroadcastSingle(jobQueue, server, members, pid);
broadcaster->AddJobEnqueueFinished(); broadcaster->AddJobEnqueueFinished();
} }
...@@ -88,36 +89,40 @@ wrapper of UpdateParameter ...@@ -88,36 +89,40 @@ wrapper of UpdateParameter
*/ */
void XWorkerUpdate::Update(XList * args) void XWorkerUpdate::Update(XList * args)
{ {
int paramCount = 0;
CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!"); CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!");
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0); XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(paramCount++);
XModel * server = (XModel*)args->GetItem(1); XQueue * jobQueue = (XQueue*)args->GetItem(paramCount++);
int memNum = args->GetInt(2); XModel * server = (XModel*)args->GetItem(paramCount++);
int memNum = args->GetInt(paramCount++);
XList members; XList members;
for (int i = 0; i < memNum; i++) { for (int i = 0; i < memNum; i++) {
XModel * member = (XModel*)args->GetItem(3 + i); XModel * member = (XModel*)args->GetItem(paramCount++);
members.Add(member); members.Add(member);
} }
int pid = args->GetInt(3 + memNum); int pid = args->GetInt(paramCount++);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(3 + memNum + 1); XOptimizer * optimizer = (XOptimizer*)args->GetItem(paramCount++);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(3 + memNum + 2); XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++);
if(updater != NULL) if(updater != NULL)
updater->UpdateParameter(server, &members, pid, optimizer, broadcaster); updater->UpdateParameter(jobQueue, server, &members, pid, optimizer, broadcaster);
} }
/* /*
add a new job of model update (for a parameter) add a new job of model update (for a parameter)
>> myQueue - the queue where we push the job >> jobQueue - the queue for sub-jobs executed in the job
>> model - the model that we want to update (on the server side) >> model - the model that we want to update (on the server side)
>> members - models that would share the updated parameters >> members - models that would share the updated parameters
>> pid - the parameter index >> pid - the parameter index
>> optimizer - the optimizer >> optimizer - the optimizer
>> broadcaster - the worker that would broadcast the new parameter to members >> broadcaster - the worker that would broadcast the new parameter to members
*/ */
bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid, bool XWorkerUpdate::AddJobUpdate(XQueue * jobQueue,
XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(model != NULL, "No input model!"); CheckNTErrors(model != NULL, "No input model!");
...@@ -128,6 +133,7 @@ bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * membe ...@@ -128,6 +133,7 @@ bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * membe
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(jobQueue);
args.Add(model); args.Add(model);
args.AddInt(members->count); args.AddInt(members->count);
args.AddList(members); args.AddList(members);
...@@ -135,12 +141,10 @@ bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * membe ...@@ -135,12 +141,10 @@ bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * membe
args.Add(optimizer); args.Add(optimizer);
args.Add(broadcaster); args.Add(broadcaster);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if (isInstantRun) if (isInstantRun)
XWorkerUpdate::Update(&args); XWorkerUpdate::Update(&args);
else else
q.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args); queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
return true; return true;
} }
......
...@@ -57,7 +57,7 @@ public: ...@@ -57,7 +57,7 @@ public:
XOptimizer * GetOptimizer(); XOptimizer * GetOptimizer();
/* update the parameter */ /* update the parameter */
void UpdateParameter(XModel * server, XList * members, int pid, void UpdateParameter(XQueue * jobQueue, XModel * server, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
/* add a new job of model update (for a parameter) */ /* add a new job of model update (for a parameter) */
bool AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid, bool AddJobUpdate(XQueue * jobQueue, XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
}; };
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论