Commit a09b92e6 by xiaotong

remove gradMap and redefine XWorkerUpdate

parent 81483f00
......@@ -41,7 +41,6 @@ XLeader::XLeader()
{
id = -1;
paramMap = NULL;
gradMap = NULL;
modelNum = 0;
}
......@@ -385,11 +384,8 @@ void XLeader::DestroyParamMap()
for(int i = 0; i < serverModel.paramNum; i++){
if(paramMap != NULL)
delete[] paramMap[i];
if(gradMap != NULL)
delete[] gradMap[i];
}
delete[] paramMap;
delete[] gradMap;
modelNum = 0;
}
......@@ -411,13 +407,11 @@ void XLeader::MakeParamMap()
if(modelCount != modelNum){
DestroyParamMap();
paramMap = new XTensorKeeper*[serverModel.paramNum];
gradMap = new XTensorKeeper*[serverModel.paramNum];
}
for(int i = 0; i < serverModel.paramNum; i++){
if(modelCount != modelNum){
paramMap[i] = new XTensorKeeper[modelCount];
gradMap[i] = new XTensorKeeper[modelCount];
}
for (int j = 0, c = 0; j < jworkers.count; j++) {
......@@ -425,11 +419,9 @@ void XLeader::MakeParamMap()
if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
paramMap[i][c].tensor = model->params[i].tensor;
paramMap[i][c].grad = model->params[i].tensor->grad;
paramMap[i][c].flag = PARAM_STATE_NOT_READY;
paramMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
gradMap[i][c].tensor = model->params[i].tensor->grad;
gradMap[i][c].flag = PARAM_STATE_NOT_READY;
gradMap[i][c].trainFlag = PARAM_STATE_NOT_READY;
paramMap[i][c].trainFlag = PARAM_STATE_NOT_READY;;
c++;
}
else {
......@@ -600,48 +592,51 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
XTensorKeeper &paramServer = serverModel.params[j];
/* isGradFinished is true only if the model finishes the computation
(in another process) */
(in another thread) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.tensor->isGradFinished)
continue;
/* set the gradient tensor */
if (paramServer.grad != paramServer.tensor->grad)
paramServer.grad = paramServer.tensor->grad;
/* check if all the models (or part of them) are ready */
for (int n = 0, i = 0; n < jworkers.count; n++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[n];
for (int m = 0; m < worker->GetModelNum(); m++, i++) {
XTensorKeeper &paramSource = paramMap[j][i];
XTensorKeeper &gradSource = gradMap[j][i];
XTensorKeeper &paramWorker = paramMap[j][i];
/* isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
(in another thread) */
if (paramWorker.flag == PARAM_STATE_NOT_READY && paramWorker.tensor->isGradFinished) {
/* get the gradient */
gradSource.tensor = paramSource.tensor->grad;
paramWorker.grad = paramWorker.tensor->grad;
/* the job queue of updating parameter j */
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, gradSource.tensor, paramServer.tensor->grad);
collecter->AddJobCollectDataP2P(jobQueue, paramWorker.grad, paramServer.grad);
collecter->AddJobEnqueueFinished(jobQueue);
/* reset the flag */
paramSource.flag = PARAM_STATE_COLLECTED;
paramWorker.flag = PARAM_STATE_COLLECTED;
finished++;
finishedCount[j]++;
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == activeModelCount) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
/* update the parameters */
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
updater->AddJobUpdate(jobQueue, &paramServer, optimizer);
updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models*/
/* broadcast the new parameter to other models */
broadcaster->AddJobBroadcast(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
......
......@@ -101,10 +101,6 @@ protected:
gradient of the loss with respect to the parameters. */
XTensorKeeper ** paramMap;
/* map of parameter gradients. (x,y) indexes the
gradient of parameter x of worker y. */
XTensorKeeper ** gradMap;
/* number of model copies for paramMap */
int modelNum;
......
......@@ -87,9 +87,8 @@ void XOptimizer::Note(XModel * model)
update a parameter matrix
>> param - the parameter matrix
>> grad - the gradient
>> pid - the id of the parameter matrix
*/
void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
void XOptimizer::UpdateParam(XTensor * param, XTensor * grad)
{
/* the delta rule
\theta_new = \theta_old - \grad * \lrate */
......
......@@ -77,7 +77,7 @@ public:
/* update a parameter matrix */
virtual
void UpdateParam(XTensor * param, XTensor * grad, int pid);
void UpdateParam(XTensor * param, XTensor * grad);
/* get learning rate */
float GetLearningRate();
......
......@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensorKeeper::XTensorKeeper()
{
tensor = NULL;
grad = NULL;
flag = PARAM_STATE_NOT_READY;
trainFlag = PARAM_STATE_NOT_READY;
MUTEX_INIT(accessLock);
......
......@@ -55,6 +55,9 @@ public:
/* the parameter */
XTensor * tensor;
/* the gradient */
XTensor * grad;
/* the parameter state */
PARAM_STATE flag;
......
......@@ -55,26 +55,24 @@ XOptimizer * XWorkerUpdate::GetOptimizer()
/*
update a parameter of a model
>> model - the model that we want to update (on the server side)
>> pid - the parameter index
>> paramKeeper - the parameter keeper
>> optimizer - the optimizer
*/
void XWorkerUpdate::UpdateParameter(XModel * server, int pid,
XOptimizer * optimizer)
void XWorkerUpdate::UpdateParameter(XTensorKeeper * paramKeeper, XOptimizer * optimizer)
{
CheckNTErrors(server->params[pid].flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");
CheckNTErrors(paramKeeper->flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");
XTensor * param = server->params[pid].tensor;
XTensor * grad = param->grad;
XTensor * param = paramKeeper->tensor;
XTensor * grad = paramKeeper->grad;
CheckNTErrors(grad != NULL, "No gradient!");
/* update the parameter */
optimizer->UpdateParam(param, grad, pid);
optimizer->UpdateParam(param, grad);
/* set the flag */
server->params[pid].flag = PARAM_STATE_UPDATED;
paramKeeper->flag = PARAM_STATE_UPDATED;
}
/*
......@@ -85,37 +83,32 @@ void XWorkerUpdate::Update(XList * args)
{
int paramCount = 0;
CheckNTErrors(args != NULL && args->count >= 4, "Illegal argument list!");
CheckNTErrors(args != NULL && args->count == 3, "Illegal argument list!");
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(paramCount++);
XModel * server = (XModel*)args->GetItem(paramCount++);
int pid = args->GetInt(paramCount++);
XTensorKeeper * paramKeeper = (XTensorKeeper*)args->GetItem(paramCount++);
XOptimizer * optimizer = (XOptimizer*)args->GetItem(paramCount++);
if(updater != NULL)
updater->UpdateParameter(server, pid, optimizer);
updater->UpdateParameter(paramKeeper, optimizer);
}
/*
add a new job of model update (for a parameter)
>> jobQueue - the queue for sub-jobs executed in the job
>> model - the model that we want to update (on the server side)
>> pid - the parameter index
>> paramKeeper - the parameter keeper
>> optimizer - the optimizer
*/
bool XWorkerUpdate::AddJobUpdate(XQueue * jobQueue,
XModel * model, int pid,
XTensorKeeper * paramKeeper,
XOptimizer * optimizer)
{
CheckNTErrors(model != NULL, "No input model!");
CheckNTErrors(paramKeeper != NULL, "No input parameter keeper!");
CheckNTErrors(optimizer != NULL, "No optimizer!");
CheckNTErrors(pid >= 0 && pid < model->paramNum, "Illegal parameter index!");
XList args;
args.Add(this);
args.Add(model);
args.AddInt(pid);
args.Add(paramKeeper);
args.Add(optimizer);
XQueue& queueRun = jobQueue != NULL ? *jobQueue : queue;
......
......@@ -57,16 +57,14 @@ public:
XOptimizer * GetOptimizer();
/* update the parameter */
void UpdateParameter(XModel * server, int pid, XOptimizer * optimizer);
void UpdateParameter(XTensorKeeper * paramKeeper, XOptimizer * optimizer);
/* wrapper of UpdateParameter */
static
void Update(XList * args);
/* add a new job of model update (for a parameter) */
bool AddJobUpdate(XQueue * jobQueue, XModel * model, int pid, XOptimizer * optimizer);
bool AddJobUpdate(XQueue * jobQueue, XTensorKeeper * paramKeeper, XOptimizer * optimizer);
};
}
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论