Commit 7c1aeb5c by xiaotong

add a new pipeline for parameter update

parent f408c730
...@@ -66,6 +66,10 @@ void XLeader::Init() ...@@ -66,6 +66,10 @@ void XLeader::Init()
delete (XWorkerBroadcast*)bworkers.GetItem(i); delete (XWorkerBroadcast*)bworkers.GetItem(i);
bworkers.Clear(); bworkers.Clear();
for(int i = 0; i < pworkers.count; i++)
delete (XWorker*)pworkers.GetItem(i);
pworkers.Clear();
serverRecord.Clear(); serverRecord.Clear();
} }
...@@ -129,6 +133,7 @@ void XLeader::InitForRun() ...@@ -129,6 +133,7 @@ void XLeader::InitForRun()
workers.AddList(&cworkers); workers.AddList(&cworkers);
workers.AddList(&uworkers); workers.AddList(&uworkers);
workers.AddList(&bworkers); workers.AddList(&bworkers);
workers.AddList(&pworkers);
for (int i = 0; i < workers.count; i++) { for (int i = 0; i < workers.count; i++) {
XWorker* worker = (XWorker*)workers[i]; XWorker* worker = (XWorker*)workers[i];
...@@ -243,6 +248,11 @@ void XLeader::SetInstantRun(bool flag) ...@@ -243,6 +248,11 @@ void XLeader::SetInstantRun(bool flag)
XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i); XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i);
worker->SetInstantRun(flag); worker->SetInstantRun(flag);
} }
for (int i = 0; i < pworkers.count; i++) {
XWorker * worker = (XWorker*)pworkers.GetItem(i);
worker->SetInstantRun(flag);
}
} }
/* start the workers */ /* start the workers */
...@@ -257,17 +267,22 @@ void XLeader::Start() ...@@ -257,17 +267,22 @@ void XLeader::Start()
} }
for (int i = 0; i < cworkers.count; i++) { for (int i = 0; i < cworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)cworkers.GetItem(i); XWorkerCollect * worker = (XWorkerCollect*)cworkers.GetItem(i);
worker->Start(); worker->Start();
} }
for (int i = 0; i < uworkers.count; i++) { for (int i = 0; i < uworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i); XWorkerUpdate * worker = (XWorkerUpdate*)uworkers.GetItem(i);
worker->Start(); worker->Start();
} }
for (int i = 0; i < bworkers.count; i++) { for (int i = 0; i < bworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i); XWorkerBroadcast * worker = (XWorkerBroadcast*)bworkers.GetItem(i);
worker->Start();
}
for (int i = 0; i < pworkers.count; i++) {
XWorker * worker = (XWorker*)pworkers.GetItem(i);
worker->Start(); worker->Start();
} }
} }
...@@ -326,6 +341,18 @@ void XLeader::AddJobBroadcastWorker() ...@@ -326,6 +341,18 @@ void XLeader::AddJobBroadcastWorker()
} }
/* /*
add a parameter worker (or a pipeline)
>> n - number of parameters
*/
void XLeader::AddJobParamterWorker(int n)
{
    /* create n fresh workers and register each of them in pworkers;
       nothing is added when n is not positive */
    int added = 0;
    while (added < n) {
        pworkers.Add(new XWorker());
        ++added;
    }
}
/*
run the model (for one time). Basically this is a map-reduce process. run the model (for one time). Basically this is a map-reduce process.
>> config - the configuration >> config - the configuration
>> dataDistributor - data distributor >> dataDistributor - data distributor
...@@ -340,6 +367,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -340,6 +367,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
CheckNTErrors(cworkers.count > 0, "No cworkers!"); CheckNTErrors(cworkers.count > 0, "No cworkers!");
CheckNTErrors(uworkers.count > 0, "No uworkers!"); CheckNTErrors(uworkers.count > 0, "No uworkers!");
CheckNTErrors(bworkers.count > 0, "No bworkers!"); CheckNTErrors(bworkers.count > 0, "No bworkers!");
CheckNTErrors(pworkers.count > 0, "No pworkers!");
bool isDataOK = true; bool isDataOK = true;
bool isToUpdate = (optimizer != NULL); bool isToUpdate = (optimizer != NULL);
...@@ -366,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -366,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
isDataOK = false; isDataOK = false;
else { else {
/* job in queue 1: refresh the model */ /* job in queue 1: refresh the model */
worker->AddJobRefresh(jmodel); worker->AddJobRefresh(worker->GetJobQueue(), jmodel);
/* job in queue 1: run the model */ /* job in queue 1: run the model */
worker->AddJobNeuralNet(jmodel, worker->AddJobNeuralNet(worker->GetJobQueue(), jmodel,
worker->GetInput(), worker->GetOutput(), worker->GetInput(), worker->GetOutput(),
worker->GetGold(), worker->GetLoss()); worker->GetGold(), worker->GetLoss());
/* job in queue 1: make a record of the run */ /* job in queue 1: make a record of the run */
worker->AddJobRecord(&serverRecord); worker->AddJobRecord(worker->GetJobQueue(), &serverRecord);
/* job in queue 1: mark finished */ /* job in queue 1: mark finished */
worker->AddJobEnqueueFinished(); worker->AddJobEnqueueFinished();
...@@ -409,7 +437,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -409,7 +437,8 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
broadcast the lastest parameters to workers. NOTE that we would update broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved a worker to the laster model parameters, even if it is not involved
in this run. */ in this run. */
collecter->AddJobUpdateAll(&members, &membersAll, &serverModel, collecter->AddJobUpdateAll(collecter->GetJobQueue(),
&members, &membersAll, &serverModel,
optimizer, updater, broadcaster); optimizer, updater, broadcaster);
collecter->AddJobEnqueueFinished(); collecter->AddJobEnqueueFinished();
} }
......
...@@ -88,6 +88,12 @@ protected: ...@@ -88,6 +88,12 @@ protected:
/* data-broadcasting workers */ /* data-broadcasting workers */
XList bworkers; XList bworkers;
/* parameter workers (each for a parameter). cworkers,
uworkers, and bworkers would push their jobs into
parameter workers. So they are actually pipelines
of jobs. */
XList pworkers;
public: public:
/* constructor */ /* constructor */
XLeader(); XLeader();
...@@ -149,6 +155,9 @@ public: ...@@ -149,6 +155,9 @@ public:
/* add a data-broadcasting worker */ /* add a data-broadcasting worker */
void AddJobBroadcastWorker(); void AddJobBroadcastWorker();
/* add a parameter worker (or a pipeline) */
void AddJobParamterWorker(int n);
/* run the model (for one time) */ /* run the model (for one time) */
bool Run(XConfig * config, DataDistributeBase * dataDistributor, bool Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer); XModel * model, XOptimizer * optimizer);
......
...@@ -101,7 +101,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -101,7 +101,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int jobNum = 0; int jobNum = 0;
int accumulation = config->GetInt("accumulation", 1); int accumulation = config->GetInt("accumulation", 1);
int nwarmup = config->GetInt("nwarmup", 0); int nwarmup = config->GetInt("nwarmup", 0);
int lrate = optimizer->GetLearningRate(); float lrate = optimizer->GetLearningRate();
CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!"); CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!");
...@@ -118,6 +118,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -118,6 +118,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
leader.AddJobCollectWorker(); leader.AddJobCollectWorker();
leader.AddJobUpdateWorker(model, optimizer); leader.AddJobUpdateWorker(model, optimizer);
leader.AddJobBroadcastWorker(); leader.AddJobBroadcastWorker();
leader.AddJobParamterWorker(model->paramNum);
//leader.SetInstantRun(); //leader.SetInstantRun();
leader.SetServerModel(config, model); leader.SetServerModel(config, model);
leader.Start(); leader.Start();
......
...@@ -70,6 +70,12 @@ int XWorker::GetID() ...@@ -70,6 +70,12 @@ int XWorker::GetID()
return id; return id;
} }
/* get job queue */
XQueue * XWorker::GetJobQueue()
{
    /* hand out the address of `queue` so that callers can push jobs into it */
    XQueue * jobQueue = &queue;
    return jobQueue;
}
/* set the flag of instant run */ /* set the flag of instant run */
void XWorker::SetInstantRun(bool flag) void XWorker::SetInstantRun(bool flag)
{ {
......
...@@ -86,6 +86,9 @@ public: ...@@ -86,6 +86,9 @@ public:
/* get worker id */ /* get worker id */
int GetID(); int GetID();
/* get job queue */
XQueue * GetJobQueue();
/* set the flag of instant run */ /* set the flag of instant run */
void SetInstantRun(bool flag = true); void SetInstantRun(bool flag = true);
......
...@@ -125,7 +125,9 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -125,7 +125,9 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
if (finishedCount[j] == memberActive->count) { if (finishedCount[j] == memberActive->count) {
paramServer.flag = PARAM_STATE_COLLECTED; paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) { if (updater != NULL) {
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster); updater->AddJobUpdate(updater->GetJobQueue(),
server, memberAll, j,
optimizer, broadcaster);
updater->AddJobEnqueueFinished(); updater->AddJobEnqueueFinished();
} }
} }
...@@ -241,6 +243,7 @@ void XWorkerCollect::CollectAllReduce(XList * all) ...@@ -241,6 +243,7 @@ void XWorkerCollect::CollectAllReduce(XList * all)
/* /*
add a new job of collecting data, update the parameter and add a new job of collecting data, update the parameter and
broadcast the new parameter broadcast the new parameter
>> myQueue - the queue where we push the job
>> memberActive - member models that are active, i.e., have generated gradients >> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models >> memberAll - all member models
>> server - the server model >> server - the server model
...@@ -250,7 +253,7 @@ broadcast the new parameter ...@@ -250,7 +253,7 @@ broadcast the new parameter
models models
<< return - successful or not << return - successful or not
*/ */
bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XModel * server, bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(memberActive != NULL, "No input (active) member list!"); CheckNTErrors(memberActive != NULL, "No input (active) member list!");
...@@ -271,124 +274,12 @@ bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XM ...@@ -271,124 +274,12 @@ bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XM
args.Add(updater); args.Add(updater);
args.Add(broadcaster); args.Add(broadcaster);
if (isInstantRun) XQueue &q = myQueue != NULL ? *myQueue : queue;
XWorkerCollect::UpdateAll(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
add a new job of collecting data
>> sourceList - the list of models that we want collect data from
>> target - the destination of the collection
<< return - successful or not
*/
bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
{
CheckNTErrors(sourceList != NULL, "no input source model list!");
CheckNTErrors(target != NULL, "no input target model!");
XList args;
args.Add(this);
args.AddInt(sourceList->count);
args.AddList(sourceList);
args.AddInt(0);
args.Add(target);
args.Add(NULL);
args.Add(NULL);
args.Add(NULL);
if (isInstantRun) if (isInstantRun)
XWorkerCollect::UpdateAll(&args); XWorkerCollect::UpdateAll(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args); q.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
collect the data of the run (i.e., loss). This is a reducer.
>> sourceList - the list of record
>> target - the record that we keep the reduce result
>> sleepTime - waiting time in collecting data
*/
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, int sleepTime)
{
    /* one marker per source record: non-zero once the record has been merged */
    int* merged = new int[sourceList->count];
    for (int k = 0; k < sourceList->count; k++)
        merged[k] = 0;

    int mergedNum = 0;
    bool allMerged = false;

    /* poll the records until each of them has reached XWORKER_FINISHED */
    while (!allMerged) {
        for (int k = 0; k < sourceList->count; k++) {
            if (merged[k] == 0) {
                XNNRecord* record = (XNNRecord*)sourceList->GetItem(k);

                if (record->state == XWORKER_FINISHED) {
                    /* the target may itself be in the list; it is counted
                       as finished but never merged into itself */
                    if (target != record)
                        target->Update(*record);

                    merged[k] = 1;
                    mergedNum++;
                }
            }
        }

        /* sleep only when another polling round is needed */
        allMerged = (mergedNum == sourceList->count);
        if (!allMerged)
            XSleep(sleepTime);
    }

    delete[] merged;
}
/* wrapper of CollectOtherData */
void XWorkerCollect::CollectOther(XList* args)
{
    /* argument layout as read below:
       [0]           the collecter that runs the job
       [1]           the number (n) of source records
       [2 .. n + 1]  the source records
       [n + 2]       the target record */
    XWorkerCollect* runner = (XWorkerCollect*)args->GetItem(0);
    int recordNum = args->GetItemInt(1);

    /* copy the source records into a list of their own */
    XList records;
    int offset = 0;
    while (offset < recordNum) {
        records.Add((XNNRecord*)args->GetItem(2 + offset));
        offset++;
    }

    /* the target record follows the source records */
    XNNRecord* result = (XNNRecord*)args->GetItem(2 + recordNum);

    runner->CollectOtherData(&records, result, SLEEP_TIME_IN_COLLECTING_OTHER);
}
/*
add a new job of collecting data of the run (i.e., loss)
collect the data of the run (i.e., loss). This is a reducer.
>> sourceList - the list of record
>> target - the record that we keep the reduce result
*/
bool XWorkerCollect::AddJobCollectOther(XList* sourceList, XNNRecord* target)
{
CheckNTErrors(sourceList != NULL, "no input source record list!");
CheckNTErrors(target != NULL, "no input target record!");
XList args;
args.Add(this);
args.AddInt(sourceList->count);
args.AddList(sourceList);
args.Add(target);
if (isInstantRun)
XWorkerCollect::CollectOther(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
return true; return true;
} }
......
...@@ -88,21 +88,9 @@ public: ...@@ -88,21 +88,9 @@ public:
void CollectAllReduce(XList * all); void CollectAllReduce(XList * all);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */ /* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XList * memberActive, XList * memberAll, XModel * server, bool AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
/* add a new job of collecting data */
bool AddJobCollect(XList * sourceList, XModel * target);
/* collect the data of the run (i.e., loss). This is a reducer. */
void CollectOtherData(XList * sourceList, XNNRecord * target, int sleepTime);
/* wrapper of CollectOtherData */
static
void CollectOther(XList * args);
/* add a new job of collecting data of the run (i.e., loss) */
bool AddJobCollectOther(XList * sourceList, XNNRecord * target);
}; };
} }
......
...@@ -175,10 +175,11 @@ int XWorkerJob::GetPredictNum() ...@@ -175,10 +175,11 @@ int XWorkerJob::GetPredictNum()
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobRefresh(XModel * myModel) bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel)
{ {
//fprintf(stderr, "refresh 0\n"); //fprintf(stderr, "refresh 0\n");
...@@ -187,10 +188,12 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel) ...@@ -187,10 +188,12 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
XList args(1); XList args(1);
args.Add(myModel); args.Add(myModel);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Refresh(&args); XModel::Refresh(&args);
else else
queue.EnqueueJob((void*)(char*)XModel::Refresh, &args); q.EnqueueJob((void*)(char*)XModel::Refresh, &args);
//fprintf(stderr, "refresh 1\n"); //fprintf(stderr, "refresh 1\n");
...@@ -199,6 +202,7 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel) ...@@ -199,6 +202,7 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
/* /*
add a new job of neural network forward and backward computation (with the input) add a new job of neural network forward and backward computation (with the input)
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
...@@ -206,7 +210,7 @@ add a new job of neural network forward and backward computation (with the input ...@@ -206,7 +210,7 @@ add a new job of neural network forward and backward computation (with the input
>> losses - losses of the outputs respect to the gold standards >> losses - losses of the outputs respect to the gold standards
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(XModel * myModel, bool XWorkerJob::AddJobNeuralNet(XQueue * myQueue, XModel * myModel,
XList * inputs, XList * outputs, XList * golds, XList * losses) XList * inputs, XList * outputs, XList * golds, XList * losses)
{ {
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
...@@ -220,10 +224,12 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, ...@@ -220,10 +224,12 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
args.Add(golds); args.Add(golds);
args.Add(losses); args.Add(losses);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Run(&args); XModel::Run(&args);
else else
queue.EnqueueJob((void*)(char*)XModel::Run, &args); q.EnqueueJob((void*)(char*)XModel::Run, &args);
SetState(XWORKER_STARTED); SetState(XWORKER_STARTED);
...@@ -254,18 +260,21 @@ void XWorkerJob::RecordMeStatic(XList* args) ...@@ -254,18 +260,21 @@ void XWorkerJob::RecordMeStatic(XList* args)
/* /*
add a new job of recording the running of the nerual network add a new job of recording the running of the nerual network
>> >> myQueue - the queue where we push the job
>> serverRecord - the model record on the server side
*/ */
bool XWorkerJob::AddJobRecord(XNNRecord * serverRecord) bool XWorkerJob::AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord)
{ {
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(serverRecord); args.Add(serverRecord);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if (isInstantRun) if (isInstantRun)
XWorkerJob::RecordMeStatic(&args); XWorkerJob::RecordMeStatic(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args); q.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
return true; return true;
} }
......
...@@ -107,13 +107,13 @@ public: ...@@ -107,13 +107,13 @@ public:
int GetPredictNum(); int GetPredictNum();
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XModel * myModel); bool AddJobRefresh(XQueue * myQueue, XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses); bool AddJobNeuralNet(XQueue * myQueue, XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses);
/* add a new job of recording the running of the nerual network */ /* add a new job of recording the running of the nerual network */
bool AddJobRecord(XNNRecord * serverRecord); bool AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord);
private: private:
/* wrapper of RecordMe */ /* wrapper of RecordMe */
......
...@@ -83,46 +83,10 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid, ...@@ -83,46 +83,10 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
} }
/* /*
update the model
>> model - the model that we want to update
>> optimizer - the optimizer
>> sleepTime - waiting time in each update
*/
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime)
{
    /* number of parameters updated so far */
    int updatedNum = 0;
    bool allUpdated = false;

    /* keep scanning the parameters until each of them has been found in
       the PARAM_STATE_COLLECTED state and updated */
    while (!allUpdated) {
        for (int pid = 0; pid < model->paramNum; pid++) {
            /* only parameters marked as collected are updated in this pass */
            if (model->params[pid].flag != PARAM_STATE_COLLECTED)
                continue;

            XTensor * weight = model->params[pid].param;
            XTensor * grad = weight->grad;
            CheckNTErrors(grad != NULL, "No gradient!");

            /* apply the optimizer, then mark the parameter so that it is
               not updated again */
            optimizer->UpdateParam(weight, grad, pid);
            model->params[pid].flag = PARAM_STATE_UPDATED;
            updatedNum++;
        }

        /* sleep only when another scan is needed */
        allUpdated = (updatedNum == model->paramNum);
        if (!allUpdated)
            XSleep(sleepTime);
    }

    optimizer->Note(model);
}
/*
wrapper of UpdateParameter wrapper of UpdateParameter
>> args - arguments of the update >> args - arguments of the update
*/ */
void XWorkerUpdate::UpdateSingle(XList * args) void XWorkerUpdate::Update(XList * args)
{ {
CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!"); CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!");
...@@ -145,34 +109,15 @@ void XWorkerUpdate::UpdateSingle(XList * args) ...@@ -145,34 +109,15 @@ void XWorkerUpdate::UpdateSingle(XList * args)
} }
/* /*
wrapper of UpdateModel
>> args - arguments of the update
*/
void XWorkerUpdate::Update(XList * args)
{
    CheckNTErrors(args != NULL && args->count >= 3, "Illegal argument list!");

    /* argument layout: [0] the updater, [1] the model, [2] the optimizer */
    XWorkerUpdate * me = (XWorkerUpdate*)args->GetItem(0);
    XModel * targetModel = (XModel*)args->GetItem(1);
    XOptimizer * opt = (XOptimizer*)args->GetItem(2);

    /* nothing to do without an updater */
    if (me == NULL)
        return;

    me->UpdateModel(targetModel, opt, SLEEP_TIME_IN_MODEL_UPDATE);
}
/*
add a new job of model update (for a parameter) add a new job of model update (for a parameter)
>> myQueue - the queue where we push the job
>> model - the model that we want to update (on the server side) >> model - the model that we want to update (on the server side)
>> members - models that would share the updated parameters >> members - models that would share the updated parameters
>> pid - the parameter index >> pid - the parameter index
>> optimizer - the optimizer >> optimizer - the optimizer
>> broadcaster - the worker that would broadcast the new parameter to members >> broadcaster - the worker that would broadcast the new parameter to members
*/ */
bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid, bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(model != NULL, "No input model!"); CheckNTErrors(model != NULL, "No input model!");
...@@ -190,33 +135,12 @@ bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid, ...@@ -190,33 +135,12 @@ bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid,
args.Add(optimizer); args.Add(optimizer);
args.Add(broadcaster); args.Add(broadcaster);
if (isInstantRun) XQueue &q = myQueue != NULL ? *myQueue : queue;
XWorkerUpdate::UpdateSingle(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerUpdate::UpdateSingle, &args);
return true;
}
/*
add a new job of model update
>> model - the model that we want to update
>> optimizer - the optimizer
*/
bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
{
CheckNTErrors(model != NULL, "No input model!");
CheckNTErrors(optimizer != NULL, "No optimizer!");
XList args; if (isInstantRun)
args.Add(this);
args.Add(model);
args.Add(optimizer);
if(isInstantRun)
XWorkerUpdate::Update(&args); XWorkerUpdate::Update(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args); q.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
return true; return true;
} }
......
...@@ -60,23 +60,15 @@ public: ...@@ -60,23 +60,15 @@ public:
void UpdateParameter(XModel * server, XList * members, int pid, void UpdateParameter(XModel * server, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* update the model */
void UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime);
/* wrapper of UpdateParameter */ /* wrapper of UpdateParameter */
static static
void UpdateSingle(XList * args);
/* wrapper of UpdateModel */
static
void Update(XList * args); void Update(XList * args);
/* add a new job of model update (for a parameter) */ /* add a new job of model update (for a parameter) */
bool AddJobUpdateSingle(XModel * model, XList * members, int pid, bool AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* add a new job of model update */
bool AddJobUpdate(XModel * model, XOptimizer * optimizer);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论