Commit 7c1aeb5c by xiaotong

add a new pipeline for parameter update

parent f408c730
...@@ -66,6 +66,10 @@ void XLeader::Init() ...@@ -66,6 +66,10 @@ void XLeader::Init()
delete (XWorkerBroadcast*)bworkers.GetItem(i); delete (XWorkerBroadcast*)bworkers.GetItem(i);
bworkers.Clear(); bworkers.Clear();
for(int i = 0; i < pworkers.count; i++)
delete (XWorker*)pworkers.GetItem(i);
pworkers.Clear();
serverRecord.Clear(); serverRecord.Clear();
} }
...@@ -129,6 +133,7 @@ void XLeader::InitForRun() ...@@ -129,6 +133,7 @@ void XLeader::InitForRun()
workers.AddList(&cworkers); workers.AddList(&cworkers);
workers.AddList(&uworkers); workers.AddList(&uworkers);
workers.AddList(&bworkers); workers.AddList(&bworkers);
workers.AddList(&pworkers);
for (int i = 0; i < workers.count; i++) { for (int i = 0; i < workers.count; i++) {
XWorker* worker = (XWorker*)workers[i]; XWorker* worker = (XWorker*)workers[i];
...@@ -243,6 +248,11 @@ void XLeader::SetInstantRun(bool flag) ...@@ -243,6 +248,11 @@ void XLeader::SetInstantRun(bool flag)
XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i); XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i);
worker->SetInstantRun(flag); worker->SetInstantRun(flag);
} }
for (int i = 0; i < pworkers.count; i++) {
XWorker * worker = (XWorker*)pworkers.GetItem(i);
worker->SetInstantRun(flag);
}
} }
/* start the workers */ /* start the workers */
...@@ -257,17 +267,22 @@ void XLeader::Start() ...@@ -257,17 +267,22 @@ void XLeader::Start()
} }
for (int i = 0; i < cworkers.count; i++) { for (int i = 0; i < cworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)cworkers.GetItem(i); XWorkerCollect * worker = (XWorkerCollect*)cworkers.GetItem(i);
worker->Start(); worker->Start();
} }
for (int i = 0; i < uworkers.count; i++) { for (int i = 0; i < uworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i); XWorkerUpdate * worker = (XWorkerUpdate*)uworkers.GetItem(i);
worker->Start(); worker->Start();
} }
for (int i = 0; i < bworkers.count; i++) { for (int i = 0; i < bworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i); XWorkerBroadcast * worker = (XWorkerBroadcast*)bworkers.GetItem(i);
worker->Start();
}
for (int i = 0; i < pworkers.count; i++) {
XWorker * worker = (XWorker*)pworkers.GetItem(i);
worker->Start(); worker->Start();
} }
} }
...@@ -326,6 +341,18 @@ void XLeader::AddJobBroadcastWorker() ...@@ -326,6 +341,18 @@ void XLeader::AddJobBroadcastWorker()
} }
/* /*
add a parameter worker (or a pipeline)
>> n - number of parameters
*/
void XLeader::AddJobParamterWorker(int n)
{
for (int i = 0; i < n; i++) {
XWorker * worker = new XWorker();
pworkers.Add(worker);
}
}
/*
run the model (for one time). Basically this is a map-reduce process. run the model (for one time). Basically this is a map-reduce process.
>> config - the configuration >> config - the configuration
>> dataDistributor - data distributor >> dataDistributor - data distributor
...@@ -340,6 +367,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -340,6 +367,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
CheckNTErrors(cworkers.count > 0, "No cworkers!"); CheckNTErrors(cworkers.count > 0, "No cworkers!");
CheckNTErrors(uworkers.count > 0, "No uworkers!"); CheckNTErrors(uworkers.count > 0, "No uworkers!");
CheckNTErrors(bworkers.count > 0, "No bworkers!"); CheckNTErrors(bworkers.count > 0, "No bworkers!");
CheckNTErrors(pworkers.count > 0, "No pworkers!");
bool isDataOK = true; bool isDataOK = true;
bool isToUpdate = (optimizer != NULL); bool isToUpdate = (optimizer != NULL);
...@@ -366,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -366,15 +394,15 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
isDataOK = false; isDataOK = false;
else { else {
/* job in queue 1: refresh the model */ /* job in queue 1: refresh the model */
worker->AddJobRefresh(jmodel); worker->AddJobRefresh(worker->GetJobQueue(), jmodel);
/* job in queue 1: run the model */ /* job in queue 1: run the model */
worker->AddJobNeuralNet(jmodel, worker->AddJobNeuralNet(worker->GetJobQueue(), jmodel,
worker->GetInput(), worker->GetOutput(), worker->GetInput(), worker->GetOutput(),
worker->GetGold(), worker->GetLoss()); worker->GetGold(), worker->GetLoss());
/* job in queue 1: make a record of the run */ /* job in queue 1: make a record of the run */
worker->AddJobRecord(&serverRecord); worker->AddJobRecord(worker->GetJobQueue(), &serverRecord);
/* job in queue 1: mark finished */ /* job in queue 1: mark finished */
worker->AddJobEnqueueFinished(); worker->AddJobEnqueueFinished();
...@@ -409,8 +437,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -409,8 +437,9 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
broadcast the lastest parameters to workers. NOTE that we would update broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved a worker to the laster model parameters, even if it is not involved
in this run. */ in this run. */
collecter->AddJobUpdateAll(&members, &membersAll, &serverModel, collecter->AddJobUpdateAll(collecter->GetJobQueue(),
optimizer, updater, broadcaster); &members, &membersAll, &serverModel,
optimizer, updater, broadcaster);
collecter->AddJobEnqueueFinished(); collecter->AddJobEnqueueFinished();
} }
......
...@@ -88,6 +88,12 @@ protected: ...@@ -88,6 +88,12 @@ protected:
/* data-broadcasting workers */ /* data-broadcasting workers */
XList bworkers; XList bworkers;
/* parameter workers (each for a paramter). cworkers,
uworkers, and bworkers would push their jobs into
parameter workers. So they are actually pipelines
of jobs. */
XList pworkers;
public: public:
/* constructor */ /* constructor */
XLeader(); XLeader();
...@@ -149,6 +155,9 @@ public: ...@@ -149,6 +155,9 @@ public:
/* add a data-broadcasting worker */ /* add a data-broadcasting worker */
void AddJobBroadcastWorker(); void AddJobBroadcastWorker();
/* add a parameter worker (or a pipeline) */
void AddJobParamterWorker(int n);
/* run the model (for one time) */ /* run the model (for one time) */
bool Run(XConfig * config, DataDistributeBase * dataDistributor, bool Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer); XModel * model, XOptimizer * optimizer);
......
...@@ -101,7 +101,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -101,7 +101,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int jobNum = 0; int jobNum = 0;
int accumulation = config->GetInt("accumulation", 1); int accumulation = config->GetInt("accumulation", 1);
int nwarmup = config->GetInt("nwarmup", 0); int nwarmup = config->GetInt("nwarmup", 0);
int lrate = optimizer->GetLearningRate(); float lrate = optimizer->GetLearningRate();
CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!"); CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!");
...@@ -118,6 +118,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -118,6 +118,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
leader.AddJobCollectWorker(); leader.AddJobCollectWorker();
leader.AddJobUpdateWorker(model, optimizer); leader.AddJobUpdateWorker(model, optimizer);
leader.AddJobBroadcastWorker(); leader.AddJobBroadcastWorker();
leader.AddJobParamterWorker(model->paramNum);
//leader.SetInstantRun(); //leader.SetInstantRun();
leader.SetServerModel(config, model); leader.SetServerModel(config, model);
leader.Start(); leader.Start();
...@@ -149,7 +150,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -149,7 +150,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
if ((step + 1) % 100 == 0) if ((step + 1) % 100 == 0)
XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n", XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss); GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss);
leader.ResetParamGrad(); leader.ResetParamGrad();
......
...@@ -70,6 +70,12 @@ int XWorker::GetID() ...@@ -70,6 +70,12 @@ int XWorker::GetID()
return id; return id;
} }
/* get job queue */
XQueue * XWorker::GetJobQueue()
{
return &queue;
}
/* set the flag of instant run */ /* set the flag of instant run */
void XWorker::SetInstantRun(bool flag) void XWorker::SetInstantRun(bool flag)
{ {
......
...@@ -86,6 +86,9 @@ public: ...@@ -86,6 +86,9 @@ public:
/* get worker id */ /* get worker id */
int GetID(); int GetID();
/* get job queue */
XQueue * GetJobQueue();
/* set the flag of instant run */ /* set the flag of instant run */
void SetInstantRun(bool flag = true); void SetInstantRun(bool flag = true);
......
...@@ -125,7 +125,9 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -125,7 +125,9 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
if (finishedCount[j] == memberActive->count) { if (finishedCount[j] == memberActive->count) {
paramServer.flag = PARAM_STATE_COLLECTED; paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) { if (updater != NULL) {
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster); updater->AddJobUpdate(updater->GetJobQueue(),
server, memberAll, j,
optimizer, broadcaster);
updater->AddJobEnqueueFinished(); updater->AddJobEnqueueFinished();
} }
} }
...@@ -241,6 +243,7 @@ void XWorkerCollect::CollectAllReduce(XList * all) ...@@ -241,6 +243,7 @@ void XWorkerCollect::CollectAllReduce(XList * all)
/* /*
add a new job of collecting data, update the parameter and add a new job of collecting data, update the parameter and
broadcast the new parameter broadcast the new parameter
>> myQueue - the queue where we push the job
>> memberActive - member models that are active, i.e., have generated gradients >> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models >> memberAll - all member models
>> server - the server model >> server - the server model
...@@ -250,7 +253,7 @@ broadcast the new parameter ...@@ -250,7 +253,7 @@ broadcast the new parameter
models models
<< return - successful or not << return - successful or not
*/ */
bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XModel * server, bool XWorkerCollect::AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(memberActive != NULL, "No input (active) member list!"); CheckNTErrors(memberActive != NULL, "No input (active) member list!");
...@@ -271,124 +274,12 @@ bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XM ...@@ -271,124 +274,12 @@ bool XWorkerCollect::AddJobUpdateAll(XList * memberActive, XList * memberAll, XM
args.Add(updater); args.Add(updater);
args.Add(broadcaster); args.Add(broadcaster);
if (isInstantRun) XQueue &q = myQueue != NULL ? *myQueue : queue;
XWorkerCollect::UpdateAll(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
add a new job of collecting data
>> sourceList - the list of models that we want collect data from
>> target - the destination of the collection
<< return - successful or not
*/
bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
{
CheckNTErrors(sourceList != NULL, "no input source model list!");
CheckNTErrors(target != NULL, "no input target model!");
XList args;
args.Add(this);
args.AddInt(sourceList->count);
args.AddList(sourceList);
args.AddInt(0);
args.Add(target);
args.Add(NULL);
args.Add(NULL);
args.Add(NULL);
if (isInstantRun) if (isInstantRun)
XWorkerCollect::UpdateAll(&args); XWorkerCollect::UpdateAll(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args); q.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
collect the data of the run (i.e., loss). This is a reducer.
>> sourceList - the list of record
>> target - the record that we keep the reduce result
>> sleepTime - waiting time in collecting data
*/
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, int sleepTime)
{
int finished = 0;
int* flags = new int[sourceList->count];
for (int i = 0; i < sourceList->count; i++)
flags[i] = 0;
while (1) {
for (int i = 0; i < sourceList->count; i++) {
if (flags[i] != 0)
continue;
XNNRecord* source = (XNNRecord*)sourceList->GetItem(i);
if (source->state == XWORKER_FINISHED) {
if(target != source)
target->Update(*source);
flags[i] = 1;
finished++;
}
}
if (finished == sourceList->count)
break;
XSleep(sleepTime);
}
delete[] flags;
}
/* wrapper of CollectOtherData */
void XWorkerCollect::CollectOther(XList* args)
{
//fprintf(stderr, "collect data other 0\n");
XWorkerCollect* collecter = (XWorkerCollect*)args->GetItem(0);
int sourceNum = args->GetItemInt(1);
/* the source records */
XList source;
for (int i = 0; i < sourceNum; i++) {
XNNRecord * record = (XNNRecord*)args->GetItem(2 + i);
source.Add(record);
}
/* the target record */
XNNRecord* target = (XNNRecord*)args->GetItem(2 + sourceNum);
collecter->CollectOtherData(&source, target, SLEEP_TIME_IN_COLLECTING_OTHER);
//fprintf(stderr, "collect data other 1\n");
}
/*
add a new job of collecting data of the run (i.e., loss)
collect the data of the run (i.e., loss). This is a reducer.
>> sourceList - the list of record
>> target - the record that we keep the reduce result
*/
bool XWorkerCollect::AddJobCollectOther(XList* sourceList, XNNRecord* target)
{
CheckNTErrors(sourceList != NULL, "no input source record list!");
CheckNTErrors(target != NULL, "no input target record!");
XList args;
args.Add(this);
args.AddInt(sourceList->count);
args.AddList(sourceList);
args.Add(target);
if (isInstantRun)
XWorkerCollect::CollectOther(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
return true; return true;
} }
......
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
from member models. Then it calls an XWorkerUpdate to update the parameters. from member models. Then it calls an XWorkerUpdate to update the parameters.
The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter
to member models back. */ to member models back. */
void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server, void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster, XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster,
int sleepTime); int sleepTime);
...@@ -88,21 +88,9 @@ public: ...@@ -88,21 +88,9 @@ public:
void CollectAllReduce(XList * all); void CollectAllReduce(XList * all);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */ /* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XList * memberActive, XList * memberAll, XModel * server, bool AddJobUpdateAll(XQueue * myQueue, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
/* add a new job of collecting data */
bool AddJobCollect(XList * sourceList, XModel * target);
/* collect the data of the run (i.e., loss). This is a reducer. */
void CollectOtherData(XList * sourceList, XNNRecord * target, int sleepTime);
/* wrapper of CollectOtherData */
static
void CollectOther(XList * args);
/* add a new job of collecting data of the run (i.e., loss) */
bool AddJobCollectOther(XList * sourceList, XNNRecord * target);
}; };
} }
......
...@@ -175,10 +175,11 @@ int XWorkerJob::GetPredictNum() ...@@ -175,10 +175,11 @@ int XWorkerJob::GetPredictNum()
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobRefresh(XModel * myModel) bool XWorkerJob::AddJobRefresh(XQueue * myQueue, XModel * myModel)
{ {
//fprintf(stderr, "refresh 0\n"); //fprintf(stderr, "refresh 0\n");
...@@ -187,10 +188,12 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel) ...@@ -187,10 +188,12 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
XList args(1); XList args(1);
args.Add(myModel); args.Add(myModel);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Refresh(&args); XModel::Refresh(&args);
else else
queue.EnqueueJob((void*)(char*)XModel::Refresh, &args); q.EnqueueJob((void*)(char*)XModel::Refresh, &args);
//fprintf(stderr, "refresh 1\n"); //fprintf(stderr, "refresh 1\n");
...@@ -199,6 +202,7 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel) ...@@ -199,6 +202,7 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
/* /*
add a new job of neural network forward and backward computation (with the input) add a new job of neural network forward and backward computation (with the input)
>> myQueue - the queue where we push the job
>> myModel - the model >> myModel - the model
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
...@@ -206,7 +210,7 @@ add a new job of neural network forward and backward computation (with the input ...@@ -206,7 +210,7 @@ add a new job of neural network forward and backward computation (with the input
>> losses - losses of the outputs respect to the gold standards >> losses - losses of the outputs respect to the gold standards
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(XModel * myModel, bool XWorkerJob::AddJobNeuralNet(XQueue * myQueue, XModel * myModel,
XList * inputs, XList * outputs, XList * golds, XList * losses) XList * inputs, XList * outputs, XList * golds, XList * losses)
{ {
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
...@@ -220,10 +224,12 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, ...@@ -220,10 +224,12 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
args.Add(golds); args.Add(golds);
args.Add(losses); args.Add(losses);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if(isInstantRun) if(isInstantRun)
XModel::Run(&args); XModel::Run(&args);
else else
queue.EnqueueJob((void*)(char*)XModel::Run, &args); q.EnqueueJob((void*)(char*)XModel::Run, &args);
SetState(XWORKER_STARTED); SetState(XWORKER_STARTED);
...@@ -254,18 +260,21 @@ void XWorkerJob::RecordMeStatic(XList* args) ...@@ -254,18 +260,21 @@ void XWorkerJob::RecordMeStatic(XList* args)
/* /*
add a new job of recording the running of the nerual network add a new job of recording the running of the nerual network
>> >> myQueue - the queue where we push the job
>> serverRecord - the model record on the server side
*/ */
bool XWorkerJob::AddJobRecord(XNNRecord * serverRecord) bool XWorkerJob::AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord)
{ {
XList args; XList args;
args.Add(this); args.Add(this);
args.Add(serverRecord); args.Add(serverRecord);
XQueue &q = myQueue != NULL ? *myQueue : queue;
if (isInstantRun) if (isInstantRun)
XWorkerJob::RecordMeStatic(&args); XWorkerJob::RecordMeStatic(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args); q.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
return true; return true;
} }
......
...@@ -107,13 +107,13 @@ public: ...@@ -107,13 +107,13 @@ public:
int GetPredictNum(); int GetPredictNum();
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XModel * myModel); bool AddJobRefresh(XQueue * myQueue, XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses); bool AddJobNeuralNet(XQueue * myQueue, XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses);
/* add a new job of recording the running of the nerual network */ /* add a new job of recording the running of the nerual network */
bool AddJobRecord(XNNRecord * serverRecord); bool AddJobRecord(XQueue * myQueue, XNNRecord * serverRecord);
private: private:
/* wrapper of RecordMe */ /* wrapper of RecordMe */
......
...@@ -83,46 +83,10 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid, ...@@ -83,46 +83,10 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
} }
/* /*
update the model
>> model - the model that we want to update
>> optimizer - the optimizer
>> sleepTime - waiting time in each update
*/
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime)
{
int finished = 0;
while (1) {
for (int i = 0; i < model->paramNum; i++) {
if (model->params[i].flag == PARAM_STATE_COLLECTED) {
XTensor * param = model->params[i].param;
XTensor * grad = param->grad;
CheckNTErrors(grad != NULL, "No gradient!");
/* update the parameter */
optimizer->UpdateParam(param, grad, i);
/* set the flag */
model->params[i].flag = PARAM_STATE_UPDATED;
finished++;
}
}
if (finished == model->paramNum)
break;
XSleep(sleepTime);
}
optimizer->Note(model);
}
/*
wrapper of UpdateParameter wrapper of UpdateParameter
>> args - arguments of the update >> args - arguments of the update
*/ */
void XWorkerUpdate::UpdateSingle(XList * args) void XWorkerUpdate::Update(XList * args)
{ {
CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!"); CheckNTErrors(args != NULL && args->count >= 6, "Illegal argument list!");
...@@ -145,35 +109,16 @@ void XWorkerUpdate::UpdateSingle(XList * args) ...@@ -145,35 +109,16 @@ void XWorkerUpdate::UpdateSingle(XList * args)
} }
/* /*
wrapper of UpdateModel
>> args - arguments of the update
*/
void XWorkerUpdate::Update(XList * args)
{
//fprintf(stderr, "update 0\n");
CheckNTErrors(args != NULL && args->count >= 3, "Illegal argument list!");
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0);
XModel * model = (XModel*)args->GetItem(1);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);
if(updater != NULL)
updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
//fprintf(stderr, "update 1\n");
}
/*
add a new job of model update (for a parameter) add a new job of model update (for a parameter)
>> myQueue - the queue where we push the job
>> model - the model that we want to update (on the server side) >> model - the model that we want to update (on the server side)
>> members - models that would share the updated parameters >> members - models that would share the updated parameters
>> pid - the parameter index >> pid - the parameter index
>> optimizer - the optimizer >> optimizer - the optimizer
>> broadcaster - the worker that would broadcast the new parameter to members >> broadcaster - the worker that would broadcast the new parameter to members
*/ */
bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid, bool XWorkerUpdate::AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster) XOptimizer * optimizer, XWorkerBroadcast * broadcaster)
{ {
CheckNTErrors(model != NULL, "No input model!"); CheckNTErrors(model != NULL, "No input model!");
CheckNTErrors(members != NULL, "No member model list!"); CheckNTErrors(members != NULL, "No member model list!");
...@@ -190,34 +135,13 @@ bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid, ...@@ -190,34 +135,13 @@ bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid,
args.Add(optimizer); args.Add(optimizer);
args.Add(broadcaster); args.Add(broadcaster);
if (isInstantRun) XQueue &q = myQueue != NULL ? *myQueue : queue;
XWorkerUpdate::UpdateSingle(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerUpdate::UpdateSingle, &args);
return true;
}
/* if (isInstantRun)
add a new job of model update
>> model - the model that we want to update
>> optimizer - the optimizer
*/
bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
{
CheckNTErrors(model != NULL, "No input model!");
CheckNTErrors(optimizer != NULL, "No optimizer!");
XList args;
args.Add(this);
args.Add(model);
args.Add(optimizer);
if(isInstantRun)
XWorkerUpdate::Update(&args); XWorkerUpdate::Update(&args);
else else
queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args); q.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
return true; return true;
} }
......
...@@ -60,23 +60,15 @@ public: ...@@ -60,23 +60,15 @@ public:
void UpdateParameter(XModel * server, XList * members, int pid, void UpdateParameter(XModel * server, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* update the model */
void UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime);
/* wrapper of UpdateParameter */ /* wrapper of UpdateParameter */
static static
void UpdateSingle(XList * args);
/* wrapper of UpdateModel */
static
void Update(XList * args); void Update(XList * args);
/* add a new job of model update (for a parameter) */
bool AddJobUpdateSingle(XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* add a new job of model update */ /* add a new job of model update (for a parameter) */
bool AddJobUpdate(XModel * model, XOptimizer * optimizer); bool AddJobUpdate(XQueue * myQueue, XModel * model, XList * members, int pid,
XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论