Commit ac855f81 by xiaotong

set grad = 0 after each update of a model

parent 1c9973c9
...@@ -77,12 +77,12 @@ void TestTrain() ...@@ -77,12 +77,12 @@ void TestTrain()
XConfig config; XConfig config;
//config.Add("dev", -1); //config.Add("dev", -1);
config.Add("lrate", 0.001F); config.Add("lrate", 0.1F);
config.Add("nstep", 100000); config.Add("nstep", 100000);
config.Add("nepoch", 5); config.Add("nepoch", 5);
config.Add("jobdev0", 0); config.Add("jobdev0", -1);
config.Add("jobdev1", -1); config.Add("jobdev1", 0);
config.Add("jobdev2", -1); //config.Add("jobdev2", -1);
//config.Add("jobdev3", -1); //config.Add("jobdev3", -1);
//config.Add("jobdev4", -1); //config.Add("jobdev4", -1);
......
...@@ -136,11 +136,34 @@ void XLeader::InitForRun() ...@@ -136,11 +136,34 @@ void XLeader::InitForRun()
} }
} }
/* set grad = 0 */
void XLeader::ResetParamGrad()
{
for (int i = 0; i < serverModel.paramNum; i++) {
XTensor* param = serverModel.params[i].param;
if (param->grad != NULL) {
param->grad->SetZeroAll();
}
}
for (int j = 0; j < jworkers.count; j++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[j];
XModel * model = worker->GetModel();
for (int i = 0; i < model->paramNum; i++) {
XTensor* param = model->params[i].param;
if (param->grad != NULL) {
param->grad->SetZeroAll();
}
}
}
}
/* /*
wait for finished states (i.e., all workers finish their jobs) wait for finished states (i.e., all workers finish their jobs)
>> activeJobWorkers - indicates whether each job worker is active >> activeJobWorkers - indicates whether each job worker is active
>> isToUpdate - indicates whether the model is updated
*/ */
void XLeader::WaitForFinishing(const int* activeJobWorkers) void XLeader::WaitForFinishing(const int* activeJobWorkers, const int isToUpdate)
{ {
int activeCount = 0; int activeCount = 0;
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
...@@ -151,7 +174,7 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers) ...@@ -151,7 +174,7 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers)
} }
} }
if (activeCount > 0) { if (activeCount > 0 && isToUpdate) {
for (int i = 0; i < cworkers.count; i++) { for (int i = 0; i < cworkers.count; i++) {
XWorker* worker = (XWorker*)cworkers[i]; XWorker* worker = (XWorker*)cworkers[i];
worker->DequeueFinishedJob(); worker->DequeueFinishedJob();
...@@ -319,6 +342,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -319,6 +342,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
CheckNTErrors(bworkers.count > 0, "No bworkers!"); CheckNTErrors(bworkers.count > 0, "No bworkers!");
bool isDataOK = true; bool isDataOK = true;
bool isToUpdate = (optimizer != NULL);
int activeJobCount = 0; int activeJobCount = 0;
int* active = new int[jworkers.count]; int* active = new int[jworkers.count];
...@@ -360,7 +384,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -360,7 +384,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
} }
} }
if (activeJobCount > 0) { if (activeJobCount > 0 && isToUpdate) {
/* workers */ /* workers */
XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0); XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0); XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
...@@ -384,28 +408,18 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -384,28 +408,18 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
} }
} }
/* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. The collector will add a job in queue 3
to update the model. The updater will add a job job in queue 4 to
broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved
in this run. */
collecter->AddJobUpdateAll(&members, &membersAll, &serverModel, collecter->AddJobUpdateAll(&members, &membersAll, &serverModel,
optimizer, updater, broadcaster); optimizer, updater, broadcaster);
//collecter->AddJobCollectOther(&memberRecords, &serverRecord);
collecter->AddJobEnqueueFinished(); collecter->AddJobEnqueueFinished();
/* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. */
//collecter->AddJobCollect(&members, &serverModel);
//collecter->AddJobCollectOther(&memberRecords, &serverRecord);
/* job in queue 3: update the model */
//updater->AddJobUpdate(&serverModel, optimizer);
/* job in queue 4: broadcast the lastest parameters to workers. NOTE that
we would update a worker to the laster model parameters, even if it is
not involved in this run. */
//broadcaster->AddJobBroadcast(&serverModel, &membersAll);
//WaitForFinishing();
} }
WaitForFinishing(active); WaitForFinishing(active, isToUpdate);
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i]; XWorkerJob * worker = (XWorkerJob*)jworkers[i];
......
...@@ -112,8 +112,11 @@ public: ...@@ -112,8 +112,11 @@ public:
/* initialize the models for running them */ /* initialize the models for running them */
void InitForRun(); void InitForRun();
/* set grad = 0 */
void ResetParamGrad();
/* wait for finished states (i.e., all workers finish their jobs) */ /* wait for finished states (i.e., all workers finish their jobs) */
void WaitForFinishing(const int * activeJobWorkers); void WaitForFinishing(const int * activeJobWorkers, const int isToUpdate);
/* get loss */ /* get loss */
float GetLoss(); float GetLoss();
......
...@@ -96,7 +96,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -96,7 +96,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int epoch = 0; int epoch = 0;
int step = 0; int step = 0;
int stepAll = 0;
int jobNum = 0; int jobNum = 0;
int accumulation = config->GetInt("accumulation", 1);
CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!");
int * ids = new int[MAX_DEVICE_NUM_TRAINING]; int * ids = new int[MAX_DEVICE_NUM_TRAINING];
GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING); GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);
...@@ -126,18 +130,26 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -126,18 +130,26 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
dataDistributor->Start(); dataDistributor->Start();
while (ok) { while (ok) {
if (++stepAll % accumulation == 0) {
/* one step of udpate */ /* one step of udpate */
ok = leader.Run(config, dataDistributor, model, optimizer); ok = leader.Run(config, dataDistributor, model, optimizer);
float loss = leader.GetLoss() / leader.GetSampleNum(); float loss = leader.GetLoss() / leader.GetSampleNum();
if ((step + 1) % 100 == 0) if ((step + 1) % 100 == 0)
XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n", XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss); GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss);
if (++step >= optimizer->nstep) leader.ResetParamGrad();
break;
if (++step >= optimizer->nstep)
break;
}
else {
/* one step with no udpate */
ok = leader.Run(config, dataDistributor, model, NULL);
}
} }
dataDistributor->End(); dataDistributor->End();
...@@ -169,6 +181,8 @@ void XTrainer::ShowSettings(XConfig* config) ...@@ -169,6 +181,8 @@ void XTrainer::ShowSettings(XConfig* config)
} }
} }
XPRINT2(1, stderr, "%25s = %d\n", "accumulation", config->GetInt("accumulation", 1));
delete[] ids; delete[] ids;
} }
......
...@@ -127,7 +127,6 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod ...@@ -127,7 +127,6 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
if (updater != NULL) { if (updater != NULL) {
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster); updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster);
updater->AddJobEnqueueFinished(); updater->AddJobEnqueueFinished();
} }
} }
else if (finishedCount[j] > memberActive->count) { else if (finishedCount[j] > memberActive->count) {
...@@ -195,8 +194,18 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target) ...@@ -195,8 +194,18 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!"); CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
/* target += source */ /* target += source */
if(source != target) if (source != target) {
_Sum(source, target, source); XTensor * sourceOnSite = source;
if (source->devID != target->devID) {
sourceOnSite = new XTensor(target);
_CopyValues(source, sourceOnSite);
}
_Sum(target, sourceOnSite, target);
if (sourceOnSite != source)
delete sourceOnSite;
}
} }
/* /*
......
...@@ -142,7 +142,8 @@ void XWorkerUpdate::UpdateSingle(XList * args) ...@@ -142,7 +142,8 @@ void XWorkerUpdate::UpdateSingle(XList * args)
XOptimizer * optimizer = (XOptimizer*)args->GetItem(3 + memNum + 1); XOptimizer * optimizer = (XOptimizer*)args->GetItem(3 + memNum + 1);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(3 + memNum + 2); XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(3 + memNum + 2);
updater->UpdateParameter(server, &members, pid, optimizer, broadcaster); if(updater != NULL)
updater->UpdateParameter(server, &members, pid, optimizer, broadcaster);
} }
/* /*
...@@ -159,7 +160,8 @@ void XWorkerUpdate::Update(XList * args) ...@@ -159,7 +160,8 @@ void XWorkerUpdate::Update(XList * args)
XModel * model = (XModel*)args->GetItem(1); XModel * model = (XModel*)args->GetItem(1);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(2); XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);
updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE); if(updater != NULL)
updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
//fprintf(stderr, "update 1\n"); //fprintf(stderr, "update 1\n");
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论