Commit 7950bd3c by xiaotong

clean the code

parent 052a62b5
......@@ -485,106 +485,98 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
jobQueues.Add(worker->GetJobQueue());
}
if(1){
int finished = 0;
/* jobs in queue 2 (say jobQueue): collect the (gradient) data and other stuff.
This is a reduce process. Then we add a job to to update the model. followed
by a job to broadcast the lastest parameters to workers. NOTE that we
would update a worker to the latest model parameters, even if it is not
involved in this run. */
int finished = 0;
for (int j = 0; j < serverModel.paramNum; j++)
serverModel.params[j].flag = PARAM_STATE_NOT_READY;
for (int j = 0; j < serverModel.paramNum; j++)
serverModel.params[j].flag = PARAM_STATE_NOT_READY;
/* check */
for (int i = 0; i < membersAll.count; i++) {
XModel * source = (XModel*)membersAll.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
/* check */
for (int i = 0; i < membersAll.count; i++) {
XModel * source = (XModel*)membersAll.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* counts how many member models are collect for each parameters */
int * finishedCount = new int[serverModel.paramNum];
memset(finishedCount, 0, sizeof(int) * serverModel.paramNum);
/* counts how many member models are collect for each parameters */
int * finishedCount = new int[serverModel.paramNum];
memset(finishedCount, 0, sizeof(int) * serverModel.paramNum);
/* This is a simple implementation of the wait-and-collect process. But
there is a risk that some models are not available, that is, the
loop would never stop. A solution might be that we force the loop
to break after waiting for a short time. */
while (1) {
for (int j = 0; j < serverModel.paramNum; j++) {
/* This is a simple implementation of the wait-and-collect process. But
there is a risk that some models are not available, that is, the
loop would never stop. A solution might be that we force the loop
to break after waiting for a short time. */
while (1) {
for (int j = 0; j < serverModel.paramNum; j++) {
XParamKeeper &paramServer = serverModel.params[j];
XParamKeeper &paramServer = serverModel.params[j];
/* isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
continue;
/* tp[j]->isGradFinished is true only if the model finishes the computation
/* check if all the models (or part of them) are ready */
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
XParamKeeper &paramSource = source->params[j];
/* isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
XParamKeeper &paramSource = source->params[j];
/* sp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad);
collecter->AddJobEnqueueFinished();
/* reset the flag */
paramSource.flag = PARAM_STATE_COLLECTED;
finished++;
finishedCount[j]++;
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == members.count) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
/* update the parameters */
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models*/
broadcaster->AddJobBroadcastSingle(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > members.count) {
ShowNTErrors("Something is wrong with finishedCount!");
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad);
collecter->AddJobEnqueueFinished(jobQueue);
/* reset the flag */
paramSource.flag = PARAM_STATE_COLLECTED;
finished++;
finishedCount[j]++;
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == members.count) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
/* update the parameters */
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models*/
broadcaster->AddJobBroadcastSingle(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > members.count) {
ShowNTErrors("Something is wrong with finishedCount!");
}
}
}
/* the collection finishes if all data tensors are processed */
if (finished == serverModel.paramNum * members.count)
break;
XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
}
delete[] finishedCount;
/* the collection finishes if all data tensors are processed */
if (finished == serverModel.paramNum * members.count)
break;
XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
}
/* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. The collector will add a job in queue 3
to update the model. The updater will add a job job in queue 4 to
broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved
in this run. */
//collecter->AddJobUpdateAll(&jobQueues,
// &members, &membersAll, &serverModel,
// optimizer, updater, broadcaster);
//collecter->AddJobEnqueueFinished();
delete[] finishedCount;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -49,206 +49,6 @@ void XWorkerCollect::SetCollectMode(DATA_COLLECT_TYPE myMode)
}
/*
collect the gradient data, update the parameters, and broadcast the
new parameters to all models. NOTE that this method just collect graident
from member models. Then it calls an XWorkerUpdate to update the parameters.
The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter
to member models back.
>> jobQueues - queues that we process the jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
models
>> sleepTime - waiting time in collecting
*/
void XWorkerCollect::UpdateDataAll(XList * jobQueues, XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater,
XWorkerBroadcast * broadcaster, int sleepTime)
{
int finished = 0;
for (int j = 0; j < server->paramNum; j++)
server->params[j].flag = PARAM_STATE_NOT_READY;
/* check */
for (int i = 0; i < memberAll->count; i++) {
XModel * source = (XModel*)memberAll->GetItem(i);
CheckNTErrors(source->paramNum == server->paramNum, "Incompatiable models!");
}
for (int i = 0; i < memberActive->count; i++) {
XModel * source = (XModel*)memberActive->GetItem(i);
CheckNTErrors(source->paramNum == server->paramNum, "Incompatiable models!");
}
CheckNTErrors(jobQueues->count == server->paramNum, "Incompatiable model!");
/* counts how many member models are collect for each parameters */
int * finishedCount = new int[server->paramNum];
memset(finishedCount, 0, sizeof(int) * server->paramNum);
/* This is a simple implementation of the wait-and-collect process. But
there is a risk that some models are not available, that is, the
loop would never stop. A solution might be that we force the loop
to break after waiting for a short time. */
while (1) {
if (collectMode == DATA_COLLECT_P2P) {
for (int j = 0; j < server->paramNum; j++) {
XParamKeeper &paramServer = server->params[j];
/* tp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < memberActive->count; i++) {
XModel * source = (XModel*)memberActive->GetItem(i);
XParamKeeper &paramSource = source->params[j];
/* sp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
/* data transmit */
CollectP2P(paramSource.param->grad, paramServer.param->grad);
/* reset the flag */
paramSource.flag = PARAM_STATE_COLLECTED;
finished++;
finishedCount[j]++;
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == memberActive->count) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
XQueue* jobQueue = (XQueue*)jobQueues->GetItem(j);
/* update the parameters */
updater->AddJobUpdate(jobQueue, server, j, optimizer);
updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models*/
broadcaster->AddJobBroadcastSingle(jobQueue, server, memberAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > memberActive->count) {
ShowNTErrors("Something is wrong with finishedCount!");
}
}
}
}
}
else {
ShowNTErrors("Unsupported data collection mode!");
}
/* the collection finishes if all data tensors are processed */
if (finished == server->paramNum * memberActive->count)
break;
XSleep(sleepTime);
}
delete[] finishedCount;
}
/* wrapper of UpdateDataAll */
void XWorkerCollect::UpdateAll(XList * args)
{
int paramCount = 0;
XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(paramCount++);
int queueNum = args->GetInt(paramCount++);
XList jobQueues;
for (int i = 0; i < queueNum; i++) {
XQueue * queue = (XQueue*)args->GetItem(paramCount++);
jobQueues.Add(queue);
}
int activeNum = args->GetInt(paramCount++);
XList memberActive;
for (int i = 0; i < activeNum; i++) {
XModel * member = (XModel*)args->GetItem(paramCount++);
memberActive.Add(member);
}
int allNum = args->GetInt(paramCount++);
XList memberAll;
for (int i = 0; i < allNum; i++) {
XModel * member = (XModel*)args->GetItem(paramCount++);
memberAll.Add(member);
}
XModel * server = (XModel*)args->GetItem(paramCount++);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(paramCount++);
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(paramCount++);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(paramCount++);
collecter->UpdateDataAll(&jobQueues, &memberActive, &memberAll, server,
optimizer, updater, broadcaster,
SLEEP_TIME_IN_COLLECTING);
}
/*
add a new job of collecting data, update the parameter and
broadcast the new parameter
>> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
models
<< return - successful or not
*/
bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer,
XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
{
CheckNTErrors(memberActive != NULL, "No input (active) member list!");
CheckNTErrors(memberAll != NULL, "No input (all) member list!");
CheckNTErrors(server != NULL, "No input server model!");
CheckNTErrors(optimizer != NULL, "No input optimizer!");
CheckNTErrors(updater != NULL, "No input updater!");
CheckNTErrors(broadcaster != NULL, "No input broadcaster!");
XList args;
args.Add(this);
args.AddInt(jobQueues->count);
args.AddList(jobQueues);
args.AddInt(memberActive->count);
args.AddList(memberActive);
args.AddInt(memberAll->count);
args.AddList(memberAll);
args.Add(server);
args.Add(optimizer);
args.Add(updater);
args.Add(broadcaster);
if (isInstantRun)
XWorkerCollect::UpdateAll(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
P2P data collection
target += source
......
......@@ -65,23 +65,6 @@ public:
/* set the collection type */
void SetCollectMode(DATA_COLLECT_TYPE myMode);
/* collect the gradient data, update the parameters, and broadcast the
new parameters to all models. NOTE that this method just collects graidents
from member models. Then it calls an XWorkerUpdate to update the parameters.
The XWorkerUpdate also calls an XWorkerBroadcast to broadcast the new parameter
to member models back. */
void UpdateDataAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster,
int sleepTime);
/* wrapper of UpdateDataAll */
static
void UpdateAll(XList * args);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
/* P2P data collection */
void CollectP2P(XTensor * source, XTensor * target);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论