Commit ac855f81 by xiaotong

set grad = 0 after each update of a model

parent 1c9973c9
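Summary of the change: with gradient accumulation, several forward/backward runs add their gradients into the same grad tensors, the optimizer applies them once, and the gradients must then be reset to zero, otherwise the next accumulation window would also pick up the stale gradients of the previous one. Below is a minimal standalone sketch of that pattern; it uses a plain C++ struct instead of the XTensor/XLeader API, and the values (accumulation = 4, lrate = 0.1F, the fake gradient of 0.5F) are illustrative assumptions only.

#include <cstdio>
#include <vector>

/* a toy "parameter" with a value and an accumulated gradient */
struct Param {
    float value;
    float grad;
};

int main()
{
    const int accumulation = 4;    /* micro-steps per parameter update */
    const float lrate = 0.1F;      /* mirrors config.Add("lrate", 0.1F) */
    std::vector<Param> params(3, Param{1.0F, 0.0F});

    for (int stepAll = 1; stepAll <= 12; stepAll++) {
        /* stand-in for one forward/backward run: gradients pile up */
        for (Param & p : params)
            p.grad += 0.5F;

        if (stepAll % accumulation == 0) {
            /* one step of update, then grad = 0 so the next window starts clean */
            for (Param & p : params) {
                p.value -= lrate * p.grad;
                p.grad = 0.0F;
            }
            printf("update at stepAll = %d, param[0] = %.3f\n", stepAll, params[0].value);
        }
    }
    return 0;
}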
......@@ -77,12 +77,12 @@ void TestTrain()
XConfig config;
//config.Add("dev", -1);
config.Add("lrate", 0.001F);
config.Add("lrate", 0.1F);
config.Add("nstep", 100000);
config.Add("nepoch", 5);
config.Add("jobdev0", 0);
config.Add("jobdev1", -1);
config.Add("jobdev2", -1);
config.Add("jobdev0", -1);
config.Add("jobdev1", 0);
//config.Add("jobdev2", -1);
//config.Add("jobdev3", -1);
//config.Add("jobdev4", -1);
......
......@@ -136,11 +136,34 @@ void XLeader::InitForRun()
}
}
+/* set grad = 0 */
+void XLeader::ResetParamGrad()
+{
+    for (int i = 0; i < serverModel.paramNum; i++) {
+        XTensor* param = serverModel.params[i].param;
+        if (param->grad != NULL) {
+            param->grad->SetZeroAll();
+        }
+    }
+    for (int j = 0; j < jworkers.count; j++) {
+        XWorkerJob * worker = (XWorkerJob*)jworkers[j];
+        XModel * model = worker->GetModel();
+        for (int i = 0; i < model->paramNum; i++) {
+            XTensor* param = model->params[i].param;
+            if (param->grad != NULL) {
+                param->grad->SetZeroAll();
+            }
+        }
+    }
+}
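Note that ResetParamGrad() clears gradients in two places: the server (center) model held by the leader and the local model replica of every job worker, since each replica owns its own grad tensors. The following compilable sketch shows the same two-level clearing with toy types; ToyParam and ToyModel are illustrative stand-ins, not part of the codebase.

#include <algorithm>
#include <vector>

/* toy stand-ins: a parameter with its gradient buffer, and a model */
struct ToyParam { std::vector<float> grad; };
struct ToyModel { std::vector<ToyParam> params; };

/* clear the gradients of the center model and of every worker replica */
void ResetParamGrad(ToyModel & serverModel, std::vector<ToyModel> & workerModels)
{
    for (ToyParam & p : serverModel.params)
        std::fill(p.grad.begin(), p.grad.end(), 0.0F);

    for (ToyModel & m : workerModels) {
        for (ToyParam & p : m.params)
            std::fill(p.grad.begin(), p.grad.end(), 0.0F);
    }
}

int main()
{
    ToyModel server{{ToyParam{{1.0F, 2.0F}}}};
    std::vector<ToyModel> workers(2, server);
    ResetParamGrad(server, workers);
    return 0;
}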
/*
wait for finished states (i.e., all workers finish their jobs)
>> activeJobWorkers - indicates whether each job worker is active
+>> isToUpdate - indicates whether the model is updated
*/
-void XLeader::WaitForFinishing(const int* activeJobWorkers)
+void XLeader::WaitForFinishing(const int* activeJobWorkers, const int isToUpdate)
{
int activeCount = 0;
for (int i = 0; i < jworkers.count; i++) {
......@@ -151,7 +174,7 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers)
}
}
-if (activeCount > 0) {
+if (activeCount > 0 && isToUpdate) {
for (int i = 0; i < cworkers.count; i++) {
XWorker* worker = (XWorker*)cworkers[i];
worker->DequeueFinishedJob();
......@@ -319,6 +342,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
CheckNTErrors(bworkers.count > 0, "No bworkers!");
bool isDataOK = true;
+bool isToUpdate = (optimizer != NULL);
int activeJobCount = 0;
int* active = new int[jworkers.count];
......@@ -360,7 +384,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
}
}
-if (activeJobCount > 0) {
+if (activeJobCount > 0 && isToUpdate) {
/* workers */
XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
......@@ -384,28 +408,18 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
}
}
/* jobs in queue 2: collect the (gradient) data and other stuff. This
   is a reduce process. The collector will add a job in queue 3
   to update the model. The updater will add a job in queue 4 to
   broadcast the latest parameters to the workers. NOTE that we would
   update a worker to the latest model parameters, even if it is not
   involved in this run. */
collecter->AddJobUpdateAll(&members, &membersAll, &serverModel,
-    optimizer, updater, broadcaster);
-//collecter->AddJobCollectOther(&memberRecords, &serverRecord);
+                           optimizer, updater, broadcaster);
collecter->AddJobEnqueueFinished();
-/* jobs in queue 2: collect the (gradient) data and other stuff. This
-   is a reduce process. */
-//collecter->AddJobCollect(&members, &serverModel);
-//collecter->AddJobCollectOther(&memberRecords, &serverRecord);
-/* job in queue 3: update the model */
-//updater->AddJobUpdate(&serverModel, optimizer);
-/* job in queue 4: broadcast the latest parameters to workers. NOTE that
-   we would update a worker to the latest model parameters, even if it is
-   not involved in this run. */
-//broadcaster->AddJobBroadcast(&serverModel, &membersAll);
-//WaitForFinishing();
}

-WaitForFinishing(active);
+WaitForFinishing(active, isToUpdate);
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
......
......@@ -112,8 +112,11 @@ public:
/* initialize the models for running them */
void InitForRun();
+/* set grad = 0 */
+void ResetParamGrad();

/* wait for finished states (i.e., all workers finish their jobs) */
-void WaitForFinishing(const int * activeJobWorkers);
+void WaitForFinishing(const int * activeJobWorkers, const int isToUpdate);
/* get loss */
float GetLoss();
......
......@@ -96,7 +96,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int epoch = 0;
int step = 0;
int stepAll = 0;
int jobNum = 0;
+int accumulation = config->GetInt("accumulation", 1);
+CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!");
int * ids = new int[MAX_DEVICE_NUM_TRAINING];
GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);
......@@ -126,18 +130,26 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
dataDistributor->Start();
while (ok) {
+if (++stepAll % accumulation == 0) {
-/* one step of update */
-ok = leader.Run(config, dataDistributor, model, optimizer);
+    /* one step of update */
+    ok = leader.Run(config, dataDistributor, model, optimizer);
-float loss = leader.GetLoss() / leader.GetSampleNum();
+    float loss = leader.GetLoss() / leader.GetSampleNum();
-if ((step + 1) % 100 == 0)
-    XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
+    if ((step + 1) % 100 == 0)
+        XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
            GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss);
-if (++step >= optimizer->nstep)
-    break;
+    leader.ResetParamGrad();
+    if (++step >= optimizer->nstep)
+        break;
+}
+else {
+    /* one step with no update */
+    ok = leader.Run(config, dataDistributor, model, NULL);
+}
}
dataDistributor->End();
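In the new loop, stepAll counts every call into leader.Run(), while step only counts real parameter updates; micro-steps that do not fall on an accumulation boundary pass NULL as the optimizer, so the leader skips the collect/update/broadcast stage and the gradients simply keep accumulating. A small standalone sketch of the counter logic follows; the values of accumulation and nstep are made-up assumptions for illustration.

#include <cstdio>

int main()
{
    const int accumulation = 4;   /* assumed value of the "accumulation" option */
    const int nstep = 5;          /* stop after this many parameter updates */

    int step = 0;                 /* counts parameter updates only */
    int stepAll = 0;              /* counts every micro-step */
    bool ok = true;

    while (ok) {
        if (++stepAll % accumulation == 0) {
            /* update step: apply the gradients of the last `accumulation` runs,
               then reset them to zero (leader.ResetParamGrad() in the real code) */
            if (++step >= nstep)
                break;
        }
        else {
            /* non-update step: forward/backward only, gradients keep accumulating */
        }
    }

    printf("%d micro-steps were needed for %d updates\n", stepAll, step);
    return 0;
}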
......@@ -169,6 +181,8 @@ void XTrainer::ShowSettings(XConfig* config)
}
}
+XPRINT2(1, stderr, "%25s = %d\n", "accumulation", config->GetInt("accumulation", 1));
delete[] ids;
}
......
......@@ -127,7 +127,6 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XMod
if (updater != NULL) {
updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster);
updater->AddJobEnqueueFinished();
}
}
else if (finishedCount[j] > memberActive->count) {
......@@ -195,8 +194,18 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
/* target += source */
-if(source != target)
-    _Sum(source, target, source);
+if (source != target) {
+    XTensor * sourceOnSite = source;
+    if (source->devID != target->devID) {
+        sourceOnSite = new XTensor(target);
+        _CopyValues(source, sourceOnSite);
+    }
+    _Sum(target, sourceOnSite, target);
+    if (sourceOnSite != source)
+        delete sourceOnSite;
+}
}
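The rewritten CollectP2P adds gradients across devices: when source and target live on different devices, the source is first copied onto the target's device, and the sum is then done in place on the target (target += source), matching the comment above the call. The sketch below mocks that flow with plain vectors; Tensor and CopyToDevice are illustrative stand-ins for XTensor and _CopyValues, not real API.

#include <cstdio>
#include <vector>

/* a toy tensor that remembers which "device" its data lives on */
struct Tensor {
    int devID;
    std::vector<float> data;
};

/* pretend cross-device copy; stands in for _CopyValues() */
Tensor CopyToDevice(const Tensor & src, int devID)
{
    Tensor t;
    t.devID = devID;
    t.data = src.data;
    return t;
}

/* target += source, moving source onto the target's device first if needed */
void CollectP2P(const Tensor & source, Tensor & target)
{
    const Tensor * sourceOnSite = &source;
    Tensor copy;
    if (source.devID != target.devID) {
        copy = CopyToDevice(source, target.devID);
        sourceOnSite = &copy;
    }
    for (size_t i = 0; i < target.data.size(); i++)
        target.data[i] += sourceOnSite->data[i];
}

int main()
{
    Tensor a{0, {1.0F, 2.0F}};   /* gradient on device 0 */
    Tensor b{1, {3.0F, 4.0F}};   /* gradient on device 1 */
    CollectP2P(a, b);
    printf("b = {%.1f, %.1f} on device %d\n", b.data[0], b.data[1], b.devID);
    return 0;
}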
/*
......
......@@ -142,7 +142,8 @@ void XWorkerUpdate::UpdateSingle(XList * args)
XOptimizer * optimizer = (XOptimizer*)args->GetItem(3 + memNum + 1);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(3 + memNum + 2);
-updater->UpdateParameter(server, &members, pid, optimizer, broadcaster);
+if(updater != NULL)
+    updater->UpdateParameter(server, &members, pid, optimizer, broadcaster);
}
/*
......@@ -159,7 +160,8 @@ void XWorkerUpdate::Update(XList * args)
XModel * model = (XModel*)args->GetItem(1);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);
-updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
+if(updater != NULL)
+    updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
//fprintf(stderr, "update 1\n");
}
......