Commit 149b3380 by xiaotong

updates

parent f866d06d
......@@ -182,6 +182,7 @@ void XLeader::MakeAll(XConfig * config, XModel * model)
{
/* NOTE(review): SetServerModel and ResetParamGrad are defined outside this
   view; presumably they register the server-side model and clear its
   gradients -- confirm against their definitions. */
SetServerModel(config, model);
ResetParamGrad();
/* build the parameter/gradient maps last: MakeParamMap sizes them with
   serverModel.paramNum */
MakeParamMap();
}
/*
......@@ -435,17 +436,36 @@ void XLeader::MakeParamMap()
{
    /* discard the maps of a previous call (if any) before rebuilding them */
    DestroyParamMap();

    /* count the member models held by the job workers. Only job workers
       that carry exactly one model are supported for now. */
    modelNum = 0;
    for (int i = 0; i < jworkers.count; i++) {
        XWorker * worker = (XWorker*)jworkers[i];
        if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
            modelNum += worker->GetModelNum();
            CheckNTErrors(worker->GetModelNum() == 1, "Wrong model number!");
        }
        else {
            ShowNTErrors("TODO: support a new XWorker type!");
        }
    }

    /* one row per server parameter, one column per member model */
    paramMap = new XTensorKeeper*[serverModel.paramNum];
    gradMap = new XTensorKeeper*[serverModel.paramNum];

    for (int i = 0; i < serverModel.paramNum; i++) {
        paramMap[i] = new XTensorKeeper[modelNum];
        gradMap[i] = new XTensorKeeper[modelNum];

        /* j walks the job workers, c is the column (model index) to fill */
        for (int j = 0, c = 0; j < jworkers.count; j++) {
            /* the worker must be looked up with the worker index j;
               i is the parameter index and may exceed jworkers.count */
            XWorker * worker = (XWorker*)jworkers[j];
            if (worker->GetWorkerType() == XWORKER_TYPE_JOB) {
                XModel * model = ((XWorkerJob*)jworkers[j])->GetModel();
                /* parameter i of this model, and the gradient attached to it */
                paramMap[i][c].tensor = model->params[i].tensor;
                gradMap[i][c].tensor = model->params[i].tensor->grad;
                c++;
            }
            else {
                ShowNTErrors("TODO: support a new XWorker type!");
            }
        }
    }
}
......@@ -511,9 +531,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
/* member models that are active in this run */
XList members(jworkers.count);
/* all member models */
XList membersAll(jworkers.count);
......@@ -523,8 +540,6 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob* worker = (XWorkerJob*)jworkers[i];
membersAll.Add(worker->GetModel());
if (active[i] == 1)
members.Add(worker->GetModel());
}
for (int i = 0; i < pworkers.count; i++) {
......@@ -532,6 +547,8 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
jobQueues.Add(worker->GetJobQueue());
}
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* jobs in queue 2 (say jobQueue): collect the (gradient) data and other stuff.
This is a reduce process. Then we add a job to update the model, followed
by a job to broadcast the latest parameters to workers. NOTE that we
......@@ -543,23 +560,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
for (int j = 0; j < serverModel.paramNum; j++)
serverModel.params[j].flag = PARAM_STATE_NOT_READY;
/* check */
for (int i = 0; i < membersAll.count; i++) {
XModel * source = (XModel*)membersAll.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
CheckNTErrors(source->paramNum == serverModel.paramNum, "Incompatiable models!");
}
CheckNTErrors(jobQueues.count == serverModel.paramNum, "Incompatiable model!");
/* counts how many member models are collected for each parameter */
int * finishedCount = new int[serverModel.paramNum];
memset(finishedCount, 0, sizeof(int) * serverModel.paramNum);
/* flag active models */
int modelCount = 0;
int activeModelCount = 0;
int * modelFlag = new int[modelNum];
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
for (int j = 0; j < worker->GetModelNum(); j++) {
modelFlag[modelCount++] = active[i];
if (active[i] != 0)
activeModelCount++;
}
}
CheckNTErrors(modelCount == modelNum, "Wrong model number!");
/* This is a simple implementation of the do-and-wait process */
while (1) {
for (int j = 0; j < serverModel.paramNum; j++) {
......@@ -572,9 +591,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i);
XTensorKeeper &paramSource = source->params[j];
for (int n = 0, i = 0; n < jworkers.count; n++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[n];
for (int m = 0; m < worker->GetModelNum(); m++, i++) {
XTensorKeeper &paramSource = paramMap[j][i];
XTensorKeeper &gradSource = gradMap[j][i];
/* isGradFinished is true only if the model finishes the computation
(in another process) */
......@@ -582,7 +603,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, paramSource.tensor->grad, paramServer.tensor->grad);
collecter->AddJobCollectDataP2P(jobQueue, gradSource.tensor, paramServer.tensor->grad);
collecter->AddJobEnqueueFinished(jobQueue);
/* reset the flag */
......@@ -593,7 +614,7 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == members.count) {
if (finishedCount[j] == activeModelCount) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
......@@ -606,21 +627,23 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > members.count) {
else if (finishedCount[j] > activeModelCount) {
ShowNTErrors("Something is wrong with finishedCount!");
}
}
}
}
}
/* finishes if all data tensors are processed */
if (finished == serverModel.paramNum * members.count)
if (finished == serverModel.paramNum * activeModelCount)
break;
XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
}
delete[] finishedCount;
delete[] modelFlag;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -34,6 +34,7 @@ namespace nts {
/* constructor */
XWorker::XWorker()
{
type = XWORKER_TYPE_UNKNOWN;
devID = -1;
id = -1;
state = XWORKER_UNSTARTED;
......@@ -46,6 +47,12 @@ XWorker::~XWorker()
Stop();
}
/* report the kind of this worker (one of the XWORKER_TYPE values) */
XWORKER_TYPE XWorker::GetWorkerType()
{
    return this->type;
}
/* set device id */
void XWorker::SetDeviceID(int myDevID)
{
......@@ -110,6 +117,21 @@ int XWorker::GetJobNum()
return queue.GetJobNum();
}
/*
number of models behind this worker. A local worker (one that runs on
the same machine as the leader) stands for exactly one model. A remote
worker (one that connects to a remote leader) is meant to stand for all
the models led by that remote leader; this case is not implemented yet,
so 1 is returned for every worker.
*/
int XWorker::GetModelNum()
{
    /* TODO: return the remote model number if the worker connects
       to another leader. */
    const int localModelNum = 1;

    return localModelNum;
}
/* whether the job queue is empty? */
bool XWorker::IsEmpty()
{
......
......@@ -31,6 +31,7 @@
#ifndef __XWORKER_H__
#define __XWORKER_H__
#include "XModel.h"
#include "../tensor/XQueue.h"
#include "../tensor/XUtility.h"
......@@ -42,12 +43,20 @@ state of a worker
2) started
3) finished
*/
/* state of a worker, in life-cycle order: unstarted, started, finished */
enum XWORKER_STATE { XWORKER_UNSTARTED, XWORKER_STARTED, XWORKER_FINISHED };
/*
type of a worker. Each worker class sets its own tag in its constructor.
*/
enum XWORKER_TYPE {
    XWORKER_TYPE_UNKNOWN,   /* default tag set by the XWorker base class */
    XWORKER_TYPE_JOB,       /* set by XWorkerJob */
    XWORKER_TYPE_COLLECT,   /* set by XWorkerCollect */
    XWORKER_TYPE_UPDATE,    /* set by XWorkerUpdate */
    XWORKER_TYPE_BROADCAST  /* set by XWorkerBroadcast */
};
/* the worker class */
class XWorker
{
protected:
/* type of the worker */
XWORKER_TYPE type;
/* id of the device where we run the worker (we suppose that
the worker is in-site). */
int devID;
......@@ -74,6 +83,9 @@ public:
/* de-constructor */
~XWorker();
/* get worker type */
XWORKER_TYPE GetWorkerType();
/* set device id */
void SetDeviceID(int myDevID);
......@@ -101,9 +113,12 @@ public:
/* stop the work */
void Stop();
/* get the number of remaining jobs */
/* get the number of the remaining jobs */
int GetJobNum();
/* get the number of the models for this worker */
int GetModelNum();
/* whether the job queue is empty? */
bool IsEmpty();
......
......@@ -36,6 +36,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor: tags the worker as a broadcast worker */
XWorkerBroadcast::XWorkerBroadcast()
{
    this->type = XWORKER_TYPE_BROADCAST;
}
/* de-constructor */
......
......@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor: data is collected in P2P mode by default, and the worker
   is tagged as a collect worker */
XWorkerCollect::XWorkerCollect()
{
    collectMode = DATA_COLLECT_P2P;
    this->type = XWORKER_TYPE_COLLECT;
}
......
......@@ -34,6 +34,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor: tags the worker as a job worker, then calls Clear() */
XWorkerJob::XWorkerJob()
{
    this->type = XWORKER_TYPE_JOB;

    /* the tag is set before Clear() runs, as in the original order */
    Clear();
}
......
......@@ -32,6 +32,7 @@ namespace nts { // namespace nts (NiuTrans.Tensor)
/* constructor: starts with no optimizer attached, and tags the worker
   as an update worker */
XWorkerUpdate::XWorkerUpdate()
{
    optimizer = NULL;
    this->type = XWORKER_TYPE_UPDATE;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论