Commit 81483f00 by xiaotong

bug fixes

parent 149b3380
......@@ -378,6 +378,68 @@ void XLeader::AddJobParamterWorker(int n)
pworkers.Add(worker);
}
}
/* destroy the parameter map (and gradient map) */
void XLeader::DestroyParamMap()
{
for(int i = 0; i < serverModel.paramNum; i++){
if(paramMap != NULL)
delete[] paramMap[i];
if(gradMap != NULL)
delete[] gradMap[i];
}
delete[] paramMap;
delete[] gradMap;
modelNum = 0;
}
/* generate the map of parameters */
void XLeader::MakeParamMap()
{
int modelCount = 0;
for (int i = 0; i < jworkers.count; i++) {
XWorker * worker = (XWorker*)jworkers[i];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
modelCount += worker->GetModelNum();
CheckNTErrors(worker->GetModelNum() == 1, "Wrong model number!");
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
}
if(modelCount != modelNum){
DestroyParamMap();
paramMap = new XTensorKeeper*[serverModel.paramNum];
gradMap = new XTensorKeeper*[serverModel.paramNum];
}
for(int i = 0; i < serverModel.paramNum; i++){
if(modelCount != modelNum){
paramMap[i] = new XTensorKeeper[modelCount];
gradMap[i] = new XTensorKeeper[modelCount];
}
for (int j = 0, c = 0; j < jworkers.count; j++) {
XWorker * worker = (XWorker*)jworkers[i];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
paramMap[i][c].tensor = model->params[i].tensor;
paramMap[i][c].flag = PARAM_STATE_NOT_READY;
paramMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
gradMap[i][c].tensor = model->params[i].tensor->grad;
gradMap[i][c].flag = PARAM_STATE_NOT_READY;
gradMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
c++;
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
}
}
modelNum = modelCount;
}
/*
run the model (for one time). Basically this is a map-reduce process.
......@@ -418,57 +480,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptim
return activeJobCount > 0;
}
/* destroy the parameter map (and gradient map) */
void XLeader::DestroyParamMap()
{
for(int i = 0; i < modelNum; i++){
delete[] paramMap[i];
delete[] gradMap[i];
}
delete[] paramMap;
delete[] gradMap;
modelNum = 0;
}
/* generate the map of parameters */
void XLeader::MakeParamMap()
{
DestroyParamMap();
modelNum = 0;
for (int i = 0; i < jworkers.count; i++) {
XWorker * worker = (XWorker*)jworkers[i];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
modelNum += worker->GetModelNum();
CheckNTErrors(worker->GetModelNum() == 1, "Wrong model number!");
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
}
paramMap = new XTensorKeeper*[serverModel.paramNum];
gradMap = new XTensorKeeper*[serverModel.paramNum];
for(int i = 0; i < serverModel.paramNum; i++){
paramMap[i] = new XTensorKeeper[modelNum];
gradMap[i] = new XTensorKeeper[modelNum];
for (int j = 0, c = 0; j < jworkers.count; j++) {
XWorker * worker = (XWorker*)jworkers[i];
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
paramMap[i][c].tensor = model->params[i].tensor;
gradMap[i][c].tensor = model->params[i].tensor->grad;
c++;
}
else {
ShowNTErrors("TODO: support a new XWorker type!");
}
}
}
}
/*
run the model
......@@ -530,6 +541,9 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
/* parameter map */
MakeParamMap();
/* all member models */
XList membersAll(jworkers.count);
......@@ -600,6 +614,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
/* get the gradient */
gradSource.tensor = paramSource.tensor->grad;
/* the job queue of updating parameter j */
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
......
......@@ -35,6 +35,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XWorkerJob::XWorkerJob()
{
type = XWORKER_TYPE_JOB;
model = NULL;
Clear();
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论