Commit 149b3380 by xiaotong

updates

parent f866d06d
...@@ -182,6 +182,7 @@ void XLeader::MakeAll(XConfig * config, XModel * model) ...@@ -182,6 +182,7 @@ void XLeader::MakeAll(XConfig * config, XModel * model)
{ {
SetServerModel(config, model); SetServerModel(config, model);
ResetParamGrad(); ResetParamGrad();
MakeParamMap();
} }
/* /*
...@@ -435,17 +436,36 @@ void XLeader::MakeParamMap() ...@@ -435,17 +436,36 @@ void XLeader::MakeParamMap()
{ {
DestroyParamMap(); DestroyParamMap();
modelNum = jworkers.count; modelNum = 0;
for (int i = 0; i < jworkers.count; i++) {
XWorker * worker = (XWorker*)jworkers[i];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
modelNum += worker->GetModelNum();
CheckNTErrors(worker->GetModelNum() == 1, "Wrong model number!");
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
}
paramMap = new XTensorKeeper*[serverModel.paramNum]; paramMap = new XTensorKeeper*[serverModel.paramNum];
gradMap = new XTensorKeeper*[serverModel.paramNum]; gradMap = new XTensorKeeper*[serverModel.paramNum];
for(int i = 0; i < serverModel.paramNum; i++){ for(int i = 0; i < serverModel.paramNum; i++){
paramMap[i] = new XTensorKeeper[modelNum]; paramMap[i] = new XTensorKeeper[modelNum];
gradMap[i] = new XTensorKeeper[modelNum]; gradMap[i] = new XTensorKeeper[modelNum];
for(int j = 0; j < modelNum; j++){
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel(); for (int j = 0, c = 0; j < jworkers.count; j++) {
paramMap[i][j].tensor = model->params[j].tensor; XWorker * worker = (XWorker*)jworkers[i];
gradMap[i][j].tensor = model->params[j].tensor; if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
paramMap[i][c].tensor = model->params[i].tensor;
gradMap[i][c].tensor = model->params[i].tensor->grad;
c++;
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
} }
} }
} }
...@@ -511,9 +531,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -511,9 +531,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0); XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0); XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
/* member models that are active in this run */
XList members(jworkers.count);
/* all member models */ /* all member models */
XList membersAll(jworkers.count); XList membersAll(jworkers.count);
...@@ -523,8 +540,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -523,8 +540,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)jworkers[i]; XWorkerJob* worker = (XWorkerJob*)jworkers[i];
membersAll.Add(worker->GetModel()); membersAll.Add(worker->GetModel());
if (active[i] == 1)
members.Add(worker->GetModel());
} }
for (int i = 0; i < pworkers.count; i++) { for (int i = 0; i < pworkers.count; i++) {
...@@ -532,6 +547,8 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -532,6 +547,8 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
jobQueues.Add(worker->GetJobQueue()); jobQueues.Add(worker->GetJobQueue());
} }
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* jobs in queue 2 (say jobQueue): collect the (gradient) data and other stuff. /* jobs in queue 2 (say jobQueue): collect the (gradient) data and other stuff.
This is a reduce process. Then we add a job to update the model, followed This is a reduce process. Then we add a job to update the model, followed
by a job to broadcast the latest parameters to workers. NOTE that we by a job to broadcast the latest parameters to workers. NOTE that we
...@@ -543,23 +560,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -543,23 +560,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
for (int j = 0; j < serverModel.paramNum; j++) for (int j = 0; j < serverModel.paramNum; j++)
serverModel.params[j].flag = PARAM_STATE_NOT_READY; serverModel.params[j].flag = PARAM_STATE_NOT_READY;
/* check */
for (int i = 0; i < membersAll.count; i++) {
XModel * source = (XModel*)membersAll.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* counts how many member models are collected for each parameter */ /* counts how many member models are collected for each parameter */
int * finishedCount = new int[serverModel.paramNum]; int * finishedCount = new int[serverModel.paramNum];
memset(finishedCount, 0, sizeof(int) * serverModel.paramNum); memset(finishedCount, 0, sizeof(int) * serverModel.paramNum);
/* flag active models */
int modelCount = 0;
int activeModelCount = 0;
int * modelFlag = new int[modelNum];
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
for (int j = 0; j < worker->GetModelNum(); j++) {
modelFlag[modelCount++] = active[i];
if (active[i] != 0)
activeModelCount++;
}
}
CheckNTErrors(modelCount == modelNum, "Wrong model number!");
/* This is a simple implementation of the do-and-wait process */ /* This is a simple implementation of the do-and-wait process */
while (1) { while (1) {
for (int j = 0; j < serverModel.paramNum; j++) { for (int j = 0; j < serverModel.paramNum; j++) {
...@@ -572,55 +591,59 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -572,55 +591,59 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
continue; continue;
/* check if all the models (or part of them) are ready */ /* check if all the models (or part of them) are ready */
for (int i = 0; i < members.count; i++) { for (int n = 0, i = 0; n < jworkers.count; n++) {
XModel * source = (XModel*)members.GetItem(i); XWorkerJob * worker = (XWorkerJob*)jworkers[n];
XTensorKeeper &paramSource = source->params[j]; for (int m = 0; m < worker->GetModelNum(); m++, i++) {
XTensorKeeper &paramSource = paramMap[j][i];
/* isGradFinished is true only if the model finishes the computation XTensorKeeper &gradSource = gradMap[j][i];
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) { /* isGradFinished is true only if the model finishes the computation
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j); (in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
/* data transmit */ XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
collecter->AddJobCollectDataP2P(jobQueue, paramSource.tensor->grad, paramServer.tensor->grad);
collecter->AddJobEnqueueFinished(jobQueue); /* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, gradSource.tensor, paramServer.tensor->grad);
/* reset the flag */ collecter->AddJobEnqueueFinished(jobQueue);
paramSource.flag = PARAM_STATE_COLLECTED;
finished++; /* reset the flag */
finishedCount[j]++; paramSource.flag = PARAM_STATE_COLLECTED;
finished++;
/* we call model update (in another thread) and then finishedCount[j]++;
broadcast the new parameters to member models
(in another thread) */ /* we call model update (in another thread) and then
if (finishedCount[j] == members.count) { broadcast the new parameters to member models
paramServer.flag = PARAM_STATE_COLLECTED; (in another thread) */
if (updater != NULL) { if (finishedCount[j] == activeModelCount) {
paramServer.flag = PARAM_STATE_COLLECTED;
/* update the parameters */ if (updater != NULL) {
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
updater->AddJobEnqueueFinished(jobQueue); /* update the parameters */
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
/* broadcast the new parameter to other models*/ updater->AddJobEnqueueFinished(jobQueue);
broadcaster->AddJobBroadcast(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue); /* broadcast the new parameter to other models*/
broadcaster->AddJobBroadcast(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > activeModelCount) {
ShowNTErrors("Something is wrong with finishedCount!");
} }
}
else if (finishedCount[j] > members.count) {
ShowNTErrors("Something is wrong with finishedCount!");
} }
} }
} }
} }
/* finishes if all data tensors are processed */ /* finishes if all data tensors are processed */
if (finished == serverModel.paramNum * members.count) if (finished == serverModel.paramNum * activeModelCount)
break; break;
XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS); XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
} }
delete[] finishedCount; delete[] finishedCount;
delete[] modelFlag;
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -34,6 +34,7 @@ namespace nts { ...@@ -34,6 +34,7 @@ namespace nts {
/* constructor */ /* constructor */
XWorker::XWorker() XWorker::XWorker()
{ {
type = XWORKER_TYPE_UNKNOWN;
devID = -1; devID = -1;
id = -1; id = -1;
state = XWORKER_UNSTARTED; state = XWORKER_UNSTARTED;
...@@ -46,6 +47,12 @@ XWorker::~XWorker() ...@@ -46,6 +47,12 @@ XWorker::~XWorker()
Stop(); Stop();
} }
/* report the type tag kept in the "type" member of this worker */
XWORKER_TYPE XWorker::GetWorkerType()
{
    const XWORKER_TYPE myType = type;
    return myType;
}
/* set device id */ /* set device id */
void XWorker::SetDeviceID(int myDevID) void XWorker::SetDeviceID(int myDevID)
{ {
...@@ -110,6 +117,21 @@ int XWorker::GetJobNum() ...@@ -110,6 +117,21 @@ int XWorker::GetJobNum()
return queue.GetJobNum(); return queue.GetJobNum();
} }
/*
number of the models that stand behind this worker.
A worker that runs locally (on the same machine as the leader) stands for
a single model. A remote worker (one that connects to a remote leader) is
meant to report the number of the models led by that remote leader. The
remote case is not implemented yet, so the result is always 1.
>> return - the model number (1 for now)
*/
int XWorker::GetModelNum()
{
    /* TODO: return the remote model number if the worker connects
       to another leader. */
    const int localModelNum = 1;
    return localModelNum;
}
/* whether the job queue is empty? */ /* whether the job queue is empty? */
bool XWorker::IsEmpty() bool XWorker::IsEmpty()
{ {
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#ifndef __XWORKER_H__ #ifndef __XWORKER_H__
#define __XWORKER_H__ #define __XWORKER_H__
#include "XModel.h"
#include "../tensor/XQueue.h" #include "../tensor/XQueue.h"
#include "../tensor/XUtility.h" #include "../tensor/XUtility.h"
...@@ -42,12 +43,20 @@ state of a worker ...@@ -42,12 +43,20 @@ state of a worker
2) started 2) started
3) finished 3) finished
*/ */
enum XWORKER_STATE { XWORKER_UNSTARTED, XWORKER_STARTED, XWORKER_FINISHED }; enum XWORKER_STATE { XWORKER_UNSTARTED, XWORKER_STARTED, XWORKER_FINISHED };
/*
type of a worker. The value is kept in XWorker::type and is
returned by XWorker::GetWorkerType(). Enumerators take the
implicit values 0, 1, 2, ... in the order listed here.
*/
enum XWORKER_TYPE {
    XWORKER_TYPE_UNKNOWN,   /* default set by the XWorker constructor */
    XWORKER_TYPE_JOB,       /* set by the XWorkerJob constructor */
    XWORKER_TYPE_COLLECT,   /* set by the XWorkerCollect constructor */
    XWORKER_TYPE_UPDATE,    /* set by the XWorkerUpdate constructor */
    XWORKER_TYPE_BROADCAST  /* set by the XWorkerBroadcast constructor */
};
/* the worker class */ /* the worker class */
class XWorker class XWorker
{ {
protected: protected:
/* type of the worker */
XWORKER_TYPE type;
/* id of the device where we run the worker (we suppose that /* id of the device where we run the worker (we suppose that
the worker is insite. */ the worker is insite. */
int devID; int devID;
...@@ -74,6 +83,9 @@ public: ...@@ -74,6 +83,9 @@ public:
/* de-constructor */ /* de-constructor */
~XWorker(); ~XWorker();
/* get worker type */
XWORKER_TYPE GetWorkerType();
/* set device id */ /* set device id */
void SetDeviceID(int myDevID); void SetDeviceID(int myDevID);
...@@ -101,9 +113,12 @@ public: ...@@ -101,9 +113,12 @@ public:
/* stop the work */ /* stop the work */
void Stop(); void Stop();
/* get the number of remaining jobs */ /* get the number of the remaining jobs */
int GetJobNum(); int GetJobNum();
/* get the number of the models for this worker */
int GetModelNum();
/* whether the job queue is empty? */ /* whether the job queue is empty? */
bool IsEmpty(); bool IsEmpty();
......
...@@ -36,6 +36,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -36,6 +36,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */ /* constructor */
XWorkerBroadcast::XWorkerBroadcast() XWorkerBroadcast::XWorkerBroadcast()
{ {
type = XWORKER_TYPE_BROADCAST;
} }
/* de-constructor */ /* de-constructor */
......
...@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */ /* constructor */
XWorkerCollect::XWorkerCollect() XWorkerCollect::XWorkerCollect()
{ {
type = XWORKER_TYPE_COLLECT;
collectMode = DATA_COLLECT_P2P; collectMode = DATA_COLLECT_P2P;
} }
......
...@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */ /* constructor */
XWorkerJob::XWorkerJob() XWorkerJob::XWorkerJob()
{ {
type = XWORKER_TYPE_JOB;
Clear(); Clear();
} }
......
...@@ -32,6 +32,7 @@ namespace nts { // namespace nts (NiuTrans.Tensor) ...@@ -32,6 +32,7 @@ namespace nts { // namespace nts (NiuTrans.Tensor)
/* constructor */ /* constructor */
XWorkerUpdate::XWorkerUpdate() XWorkerUpdate::XWorkerUpdate()
{ {
type = XWORKER_TYPE_UPDATE;
optimizer = NULL; optimizer = NULL;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论