Commit 0483b910 by xiaotong

function renaming and redefination of XLeaderPS::MakeAll()

parent 2a9c6078
...@@ -317,9 +317,9 @@ void XLeader::Start() ...@@ -317,9 +317,9 @@ void XLeader::Start()
add a number of job workers (given their device ids) add a number of job workers (given their device ids)
>> model - the neural network >> model - the neural network
>> n - number of the models >> n - number of the models
>> ids - the array of device ids >> devIDs - the array of device ids
*/ */
void XLeader::AddJobWorker(XModel * model, int n, int * ids) void XLeader::AddJobWorker(XModel * model, int n, const int * devIDs)
{ {
/* we keep the input model */ /* we keep the input model */
if (n >= 1) { if (n >= 1) {
...@@ -331,7 +331,7 @@ void XLeader::AddJobWorker(XModel * model, int n, int * ids) ...@@ -331,7 +331,7 @@ void XLeader::AddJobWorker(XModel * model, int n, int * ids)
/* we clone the input model */ /* we clone the input model */
for (int i = 1; i < n; i++) { for (int i = 1; i < n; i++) {
XWorkerJob * worker = new XWorkerJob(); XWorkerJob * worker = new XWorkerJob();
worker->SetModel(model->Clone(ids[i])); worker->SetModel(model->Clone(devIDs[i]));
jworkers.Add(worker); jworkers.Add(worker);
} }
} }
...@@ -340,7 +340,7 @@ void XLeader::AddJobWorker(XModel * model, int n, int * ids) ...@@ -340,7 +340,7 @@ void XLeader::AddJobWorker(XModel * model, int n, int * ids)
add a data-collecting worker add a data-collecting worker
>> mode - the data-transfer mode of the worker >> mode - the data-transfer mode of the worker
*/ */
void XLeader::AddJobCollectWorker(DATA_COLLECT_TYPE mode) void XLeader::AddCollectWorker(DATA_COLLECT_TYPE mode)
{ {
XWorkerCollect * worker = new XWorkerCollect(); XWorkerCollect * worker = new XWorkerCollect();
worker->SetCollectMode(mode); worker->SetCollectMode(mode);
...@@ -350,17 +350,15 @@ void XLeader::AddJobCollectWorker(DATA_COLLECT_TYPE mode) ...@@ -350,17 +350,15 @@ void XLeader::AddJobCollectWorker(DATA_COLLECT_TYPE mode)
/* /*
add a model-update worker add a model-update worker
>> model - the model >> model - the model
>> optimizer - the optimizer
*/ */
void XLeader::AddJobUpdateWorker(XModel * model, XOptimizer * optimizer) void XLeader::AddUpdateWorker(XModel * model)
{ {
XWorkerUpdate * worker = new XWorkerUpdate(); XWorkerUpdate * worker = new XWorkerUpdate();
worker->SetOptimizer(optimizer);
uworkers.Add(worker); uworkers.Add(worker);
} }
/* add a data-broadcasting worker */ /* add a data-broadcasting worker */
void XLeader::AddJobBroadcastWorker() void XLeader::AddBroadcastWorker()
{ {
XWorkerBroadcast * worker = new XWorkerBroadcast(); XWorkerBroadcast * worker = new XWorkerBroadcast();
bworkers.Add(worker); bworkers.Add(worker);
...@@ -370,7 +368,7 @@ void XLeader::AddJobBroadcastWorker() ...@@ -370,7 +368,7 @@ void XLeader::AddJobBroadcastWorker()
add a parameter worker (or a pipeline) add a parameter worker (or a pipeline)
>> n - number of parameters >> n - number of parameters
*/ */
void XLeader::AddJobParamterWorker(int n) void XLeader::AddParamterWorker(int n)
{ {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
XWorker * worker = new XWorker(); XWorker * worker = new XWorker();
......
...@@ -160,19 +160,19 @@ public: ...@@ -160,19 +160,19 @@ public:
void SetInstantRun(bool flag = true); void SetInstantRun(bool flag = true);
/* add a number of job workers (given their device ids) */ /* add a number of job workers (given their device ids) */
void AddJobWorker(XModel * model, int n, int * ids); void AddJobWorker(XModel * model, int n, const int * devIDs);
/* add a data-collecting worker */ /* add a data-collecting worker */
void AddJobCollectWorker(DATA_COLLECT_TYPE mode = DATA_COLLECT_P2P); void AddCollectWorker(DATA_COLLECT_TYPE mode = DATA_COLLECT_P2P);
/* add a model-update worker */ /* add a model-update worker */
void AddJobUpdateWorker(XModel * model, XOptimizer * optimizer); void AddUpdateWorker(XModel * model);
/* add a data-broadcasting worker */ /* add a data-broadcasting worker */
void AddJobBroadcastWorker(); void AddBroadcastWorker();
/* add a parameter worker (or a pipeline) */ /* add a parameter worker (or a pipeline) */
void AddJobParamterWorker(int n); void AddParamterWorker(int n);
/* destroy the parameter map (and gradient map) */ /* destroy the parameter map (and gradient map) */
void DestroyParamMap(); void DestroyParamMap();
......
...@@ -44,6 +44,25 @@ XLeaderPS::~XLeaderPS() ...@@ -44,6 +44,25 @@ XLeaderPS::~XLeaderPS()
} }
/* /*
create workers
>> config - configuration
>> model - the model that we run
>> devIDs - device ids of the workers (the first id is for server)
>> jobWorkerNum - number of job workers
*/
void XLeaderPS::MakeAll(XConfig * config, XModel * model, const int * devIDs, const int jobWorkerNum)
{
Init();
AddJobWorker(model, jobWorkerNum, devIDs);
AddCollectWorker();
AddUpdateWorker(model);
AddBroadcastWorker();
AddParamterWorker(model->paramNum);
XLeader::MakeAll(config, model);
}
/*
run the model (for one time). Basically this is a map-reduce process. run the model (for one time). Basically this is a map-reduce process.
>> config - the configuration >> config - the configuration
>> dataDistributor - data distributor >> dataDistributor - data distributor
......
...@@ -47,6 +47,9 @@ public: ...@@ -47,6 +47,9 @@ public:
/* deconstructor */ /* deconstructor */
~XLeaderPS(); ~XLeaderPS();
/* create workers and other stuff used in training */
void MakeAll(XConfig * config, XModel * model, const int * devIDs, const int jobWorkerNum);
/* run the model and update it (for one time) */ /* run the model and update it (for one time) */
bool Run(XConfig* config, DataDistributeBase* dataDistributor, XOptimizer* optimizer); bool Run(XConfig* config, DataDistributeBase* dataDistributor, XOptimizer* optimizer);
......
...@@ -113,14 +113,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -113,14 +113,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* create the server and workers */ /* create the server and workers */
XLeaderPS leader; XLeaderPS leader;
leader.Init(); leader.MakeAll(config, model, ids, jobNum);
leader.AddJobWorker(model, jobNum, ids);
leader.AddJobCollectWorker();
leader.AddJobUpdateWorker(model, optimizer);
leader.AddJobBroadcastWorker();
leader.AddJobParamterWorker(model->paramNum);
//leader.SetInstantRun(); //leader.SetInstantRun();
leader.MakeAll(config, model);
leader.Start(); leader.Start();
/* learning rate scheduler */ /* learning rate scheduler */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论