Commit 81483f00 by xiaotong

bug fixes

parent 149b3380
@@ -379,51 +379,13 @@ void XLeader::AddJobParamterWorker(int n)
     }
 }
-/*
-run the model (for one time). Basically this is a map-reduce process.
->> config - the configuration
->> dataDistributor - data distributor
->> optimizer - the optimization method
-<< return - if we can fetch the new data
-*/
-bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptimizer * optimizer)
-{
-    CheckNTErrors(jworkers.count > 0, "No jworkers!");
-    CheckNTErrors(cworkers.count > 0, "No cworkers!");
-    CheckNTErrors(uworkers.count > 0, "No uworkers!");
-    CheckNTErrors(bworkers.count > 0, "No bworkers!");
-    CheckNTErrors(pworkers.count > 0, "No pworkers!");
-    bool isToUpdate = (optimizer != NULL);
-    int activeJobCount = 0;
-    int* active = new int[jworkers.count];
-    InitForRun();
-    /* run models on job workers */
-    activeJobCount = RunModel(config, dataDistributor, active);
-    /* update the model on the server side */
-    if (activeJobCount > 0 && isToUpdate)
-        RunUpdate(config, optimizer, active);
-    WaitForFinishing(active, isToUpdate);
-    for (int i = 0; i < jworkers.count; i++) {
-        XWorkerJob * worker = (XWorkerJob*)jworkers[i];
-        worker->Clear();
-    }
-    delete[] active;
-    return activeJobCount > 0;
-}
 /* destroy the parameter map (and gradient map) */
 void XLeader::DestroyParamMap()
 {
-    for(int i = 0; i < modelNum; i++){
+    for(int i = 0; i < serverModel.paramNum; i++){
+        if(paramMap != NULL)
            delete[] paramMap[i];
+        if(gradMap != NULL)
            delete[] gradMap[i];
     }
     delete[] paramMap;
@@ -434,13 +396,11 @@ void XLeader::DestroyParamMap()
 /* generate the map of parameters */
 void XLeader::MakeParamMap()
 {
-    DestroyParamMap();
-    modelNum = 0;
+    int modelCount = 0;
     for (int i = 0; i < jworkers.count; i++) {
         XWorker * worker = (XWorker*)jworkers[i];
         if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
-            modelNum += worker->GetModelNum();
+            modelCount += worker->GetModelNum();
             CheckNTErrors(worker->GetModelNum() == 1, "Wrong model number!");
         }
         else {
@@ -448,19 +408,28 @@ void XLeader::MakeParamMap()
         }
     }
+    if(modelCount != modelNum){
+        DestroyParamMap();
         paramMap = new XTensorKeeper*[serverModel.paramNum];
         gradMap = new XTensorKeeper*[serverModel.paramNum];
+    }
     for(int i = 0; i < serverModel.paramNum; i++){
-        paramMap[i] = new XTensorKeeper[modelNum];
-        gradMap[i] = new XTensorKeeper[modelNum];
+        if(modelCount != modelNum){
+            paramMap[i] = new XTensorKeeper[modelCount];
+            gradMap[i] = new XTensorKeeper[modelCount];
+        }
         for (int j = 0, c = 0; j < jworkers.count; j++) {
             XWorker * worker = (XWorker*)jworkers[i];
             if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
                 XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
                 paramMap[i][c].tensor = model->params[i].tensor;
+                paramMap[i][c].flag = PARAM_STATE_NOT_READY;
+                paramMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
                 gradMap[i][c].tensor = model->params[i].tensor->grad;
+                gradMap[i][c].flag = PARAM_STATE_NOT_READY;
+                gradMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
                 c++;
             }
             else {
@@ -468,6 +437,48 @@ void XLeader::MakeParamMap()
             }
         }
     }
+    modelNum = modelCount;
+}
+/*
+run the model (for one time). Basically this is a map-reduce process.
+>> config - the configuration
+>> dataDistributor - data distributor
+>> optimizer - the optimization method
+<< return - if we can fetch the new data
+*/
+bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptimizer * optimizer)
+{
+    CheckNTErrors(jworkers.count > 0, "No jworkers!");
+    CheckNTErrors(cworkers.count > 0, "No cworkers!");
+    CheckNTErrors(uworkers.count > 0, "No uworkers!");
+    CheckNTErrors(bworkers.count > 0, "No bworkers!");
+    CheckNTErrors(pworkers.count > 0, "No pworkers!");
+    bool isToUpdate = (optimizer != NULL);
+    int activeJobCount = 0;
+    int* active = new int[jworkers.count];
+    InitForRun();
+    /* run models on job workers */
+    activeJobCount = RunModel(config, dataDistributor, active);
+    /* update the model on the server side */
+    if (activeJobCount > 0 && isToUpdate)
+        RunUpdate(config, optimizer, active);
+    WaitForFinishing(active, isToUpdate);
+    for (int i = 0; i < jworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)jworkers[i];
+        worker->Clear();
+    }
+    delete[] active;
+    return activeJobCount > 0;
 }
 /*
@@ -531,6 +542,9 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
     XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
     XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
+    /* parameter map */
+    MakeParamMap();
     /* all member models */
     XList membersAll(jworkers.count);
@@ -600,6 +614,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
         /* isGradFinished is true only if the model finishes the computation
            (in another process) */
         if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
+            /* get the gradient */
+            gradSource.tensor = paramSource.tensor->grad;
+            /* the job queue of updating parameter j */
             XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
             /* data transmit */
...
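To make the intent of the MakeParamMap()/DestroyParamMap() changes above easier to follow, here is a minimal, self-contained C++ sketch of the same idea: paramMap[i][c] keeps the copy of parameter i held by member model c, the map is reallocated only when the number of member models changes, and the ready flags are reset on every call. TensorKeeperSketch, ParamMapSketch and the sizes used in main() are illustrative stand-ins, not the real XTensorKeeper/XModel API; gradMap in the commit follows the same layout and is omitted here for brevity.

#include <cstdio>

enum ParamState { PARAM_STATE_NOT_READY, PARAM_STATE_READY };

/* a simplified stand-in for XTensorKeeper */
struct TensorKeeperSketch {
    void * tensor;        /* would point to the member model's copy of the parameter */
    ParamState flag;      /* transfer state, reset to NOT_READY on every rebuild */
    ParamState trainFlag; /* training state, reset the same way */
};

struct ParamMapSketch {
    TensorKeeperSketch ** paramMap; /* paramMap[i][c]: parameter i of member model c */
    int paramNum;                   /* number of parameters kept by the server model */
    int modelNum;                   /* number of member models currently mapped */

    ParamMapSketch() : paramMap(NULL), paramNum(0), modelNum(0) {}
    ~ParamMapSketch() { Destroy(); }

    /* free the map, guarding against a map that was never built */
    void Destroy() {
        if (paramMap != NULL) {
            for (int i = 0; i < paramNum; i++)
                delete[] paramMap[i];
            delete[] paramMap;
        }
        paramMap = NULL;
    }

    /* rebuild the map only when the member-model count changed, as the commit does */
    void Make(int paramCount, int modelCount) {
        if (modelCount != modelNum) {
            Destroy();
            paramMap = new TensorKeeperSketch*[paramCount];
            for (int i = 0; i < paramCount; i++)
                paramMap[i] = new TensorKeeperSketch[modelCount];
        }
        /* the flags are reset on every call, whether or not the map was rebuilt */
        for (int i = 0; i < paramCount; i++) {
            for (int c = 0; c < modelCount; c++) {
                paramMap[i][c].tensor = NULL; /* would be model c's tensor for parameter i */
                paramMap[i][c].flag = PARAM_STATE_NOT_READY;
                paramMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
            }
        }
        paramNum = paramCount;
        modelNum = modelCount;
    }
};

int main() {
    ParamMapSketch map;
    map.Make(4, 2);   /* 4 parameters shared by 2 member models */
    map.Make(4, 2);   /* same model count: no reallocation, only the flags are reset */
    map.Make(4, 3);   /* model count changed: the old map is destroyed and rebuilt */
    printf("params=%d models=%d\n", map.paramNum, map.modelNum);
    return 0;
}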
@@ -35,6 +35,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 XWorkerJob::XWorkerJob()
 {
     type = XWORKER_TYPE_JOB;
+    model = NULL;
     Clear();
 }
...
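For readers new to the trainer, below is a hedged, stand-alone sketch of the map-reduce round that the XLeader::Run() comment describes: each job worker runs its model on a batch ("map"), the gradients of the active workers are reduced into a single update of the server-side model, and training stops once the data distributor has nothing left to serve. LeaderSketch, WorkerSketch and DataSketch are simplified stand-ins with toy arithmetic; the real XLeader dispatches these steps to worker threads and job queues rather than running them inline.

#include <cstdio>
#include <vector>

/* a toy data distributor: serves a fixed number of batches */
struct DataSketch {
    int remaining;
    bool Fetch() { return remaining-- > 0; }
};

/* a toy job worker: runs "its model" on one batch and produces a gradient */
struct WorkerSketch {
    float grad;
    bool RunModel(DataSketch & data) {
        if (!data.Fetch())
            return false;
        grad = 0.1f;   /* pretend forward/backward produced this gradient */
        return true;
    }
    void Clear() { grad = 0.0f; }
};

struct LeaderSketch {
    float serverParam;                   /* the server-side copy of the model */
    std::vector<WorkerSketch> jworkers;  /* the member models */

    /* one map-reduce round, mirroring the structure of XLeader::Run() */
    bool Run(DataSketch & data, float lr) {
        std::vector<int> active(jworkers.size(), 0); /* which member models ran this round */
        int activeJobCount = 0;
        for (size_t i = 0; i < jworkers.size(); i++) {   /* "map": run each worker on its batch */
            active[i] = jworkers[i].RunModel(data) ? 1 : 0;
            activeJobCount += active[i];
        }
        if (activeJobCount > 0) {                        /* "reduce": one update on the server side */
            float sum = 0.0f;
            for (size_t i = 0; i < jworkers.size(); i++)
                if (active[i])
                    sum += jworkers[i].grad;
            serverParam -= lr * sum / activeJobCount;
        }
        for (size_t i = 0; i < jworkers.size(); i++)     /* clear the workers for the next round */
            jworkers[i].Clear();
        return activeJobCount > 0;                       /* false once no worker could fetch data */
    }
};

int main() {
    LeaderSketch leader;
    leader.serverParam = 1.0f;
    leader.jworkers.resize(2);
    DataSketch data;
    data.remaining = 5;
    while (leader.Run(data, 0.01f))      /* train until the data runs out */
        printf("server parameter = %f\n", leader.serverParam);
    return 0;
}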