Commit d69372b3 by xiaotong

add a new class XParamKeeper

parent 90599e05
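The hunks below replace XModel's two parallel containers (a TensorList of raw XTensor pointers plus a separate PARAM_STATE array) with a single array of XParamKeeper objects, each bundling a parameter with its state and two mutexes. A minimal sketch of the access-pattern change, using only names that appear in the diff:

    /* before: two containers kept in sync by hand */
    XTensor * param = (XTensor*)model->params[i];    /* TensorList params */
    PARAM_STATE state = model->flags[i];             /* PARAM_STATE * flags */

    /* after: one keeper per parameter */
    XTensor * param = model->params[i].param;        /* XParamKeeper * params */
    PARAM_STATE state = model->params[i].flag;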
@@ -90,7 +90,7 @@ void TestTrain()
     TTModel model;
     model.Init(config, -1);
-    tmpTT = (XTensor*)model.params[0];
+    tmpTT = model.params[0].param;
     XOptimizer optimizer;
     optimizer.Init(config);
@@ -91,8 +91,8 @@ Set the server model. It distributes the server-side parameters on different devices
 void XLeader::SetServerModel(XConfig * config, XModel * model, XList * memberModels)
 {
     serverModel.Clear();
-    for (int i = 0; i < model->params.count; i++) {
-        XTensor * param = (XTensor*)model->params[i];
+    for (int i = 0; i < model->paramNum; i++) {
+        XTensor * param = model->params[i].param;
         serverModel.AddParam(param);
     }
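For context, SetServerModel is typically called once during setup; a hypothetical call site (the leader and member-list setup here is an assumption for illustration, not code from this commit):

    XConfig config;                                    /* assumed initialized elsewhere */
    XModel model;                                      /* the model whose parameters are shared */
    XList members;                                     /* member models on the worker devices */
    XLeader leader;
    leader.SetServerModel(&config, &model, &members);  /* server keeps one keeper per parameter */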
@@ -33,17 +33,34 @@
 /* the nts (NiuTrans.Tensor) namespace */
 namespace nts {

+/* constructor */
+XParamKeeper::XParamKeeper()
+{
+    param = NULL;
+    flag = PARAM_STATE_NOT_READY;
+    MUTEX_INIT(accessLock);
+    MUTEX_INIT(trainLock);
+}
+
+/* de-constructor */
+XParamKeeper::~XParamKeeper()
+{
+    MUTEX_DELE(accessLock);
+    MUTEX_DELE(trainLock);
+}
+
 /* constructor */
 XModel::XModel()
 {
-    flags = NULL;
+    params = NULL;
+    paramNum = 0;
     MUTEX_INIT(modelMutex);
 }

 /* de-constructor */
 XModel::~XModel()
 {
-    delete[] flags;
+    Clear();
     MUTEX_DELE(modelMutex);
 }
@@ -51,7 +68,8 @@ XModel::~XModel()
 /* clear the model */
 void XModel::Clear()
 {
-    params.Clear();
+    delete[] params;
+    params = NULL; /* avoid a dangling pointer (and a double delete[] in AddParam) */
+    paramNum = 0;
 }

 /*
@@ -104,22 +122,27 @@ add a parameter tensor
 void XModel::AddParam(XTensor* param)
 {
     param->SetVarFlag();
-    params.Add(param);
-
-    PARAM_STATE * newFlags = new PARAM_STATE[params.count];
-    memcpy(newFlags, flags, sizeof(PARAM_STATE) * (params.count - 1));
-    newFlags[params.count - 1] = PARAM_STATE_NOT_READY;
-
-    delete[] flags;
-    flags = newFlags;
+    XParamKeeper * newParams = new XParamKeeper[paramNum + 1];
+    for (int i = 0; i < paramNum; i++) {
+        newParams[i].param = params[i].param;
+        newParams[i].flag = params[i].flag;
+    }
+    newParams[paramNum].param = param;
+    newParams[paramNum].flag = PARAM_STATE_NOT_READY;
+
+    delete[] params;
+    params = newParams;
+    paramNum++;
 }

 /* check if the parameters are well-defined for training */
 bool XModel::CheckParam()
 {
-    for (int i = 0; i < params.count; i++) {
-        XTensor * param = (XTensor*)params[i];
+    for (int i = 0; i < paramNum; i++) {
+        XTensor * param = params[i].param;
         if (!param->isGrad)
             return false;
     }
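A note on AddParam above: the keeper array cannot be grown with a plain memcpy the way the old PARAM_STATE array was, because each XParamKeeper owns mutex handles. The loop therefore copies only the param and flag fields and lets the freshly constructed keepers keep their own initialized mutexes. A self-contained sketch of this grow-by-one pattern (a plain int stands in for XTensor; this mirrors the idea, not the library's code):

    struct Keeper {
        int * param = nullptr;    /* stands in for XTensor * */
        int flag = 0;             /* stands in for PARAM_STATE */
        /* imagine non-copyable members (mutex handles) here */
    };

    /* O(n) copy per append; fine for the handful of parameters a model keeps */
    void AddKeeper(Keeper *& arr, int & num, int * payload)
    {
        Keeper * bigger = new Keeper[num + 1];
        for (int i = 0; i < num; i++) {    /* copy fields, not whole objects */
            bigger[i].param = arr[i].param;
            bigger[i].flag = arr[i].flag;
        }
        bigger[num].param = payload;
        bigger[num].flag = 0;              /* analogue of PARAM_STATE_NOT_READY */
        delete[] arr;
        arr = bigger;
        num++;
    }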
@@ -130,25 +153,18 @@ bool XModel::CheckParam()
 /* initialize the model for running it */
 void XModel::InitForRun()
 {
-    for (int i = 0; i < params.count; i++) {
-        XTensor * param = (XTensor*)params[i];
-        param->isGradFinished = false;
-        flags[i] = PARAM_STATE_NOT_READY;
+    for (int i = 0; i < paramNum; i++) {
+        params[i].param->isGradFinished = false;
+        params[i].flag = PARAM_STATE_NOT_READY;
     }
 }

 /* refresh the model */
 void XModel::RefreshMe()
 {
-    for (int i = 0; i < params.count; i++) {
-        XTensor * param = params.GetItem(i);
-        param->isGradFinished = false;
-    }
-
-    delete[] flags;
-    flags = new PARAM_STATE[params.count];
-    for (int i = 0; i < params.count; i++) {
-        flags[i] = PARAM_STATE_NOT_READY;
+    for (int i = 0; i < paramNum; i++) {
+        params[i].param->isGradFinished = false;
+        params[i].flag = PARAM_STATE_NOT_READY;
     }
 }
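After this rewrite, InitForRun and RefreshMe have identical bodies: both walk the keepers, clear isGradFinished, and reset the flag to PARAM_STATE_NOT_READY. If that equivalence is intended to last, the two could delegate to one shared helper; a hypothetical sketch (ResetParams is not a name from this commit):

    void XModel::ResetParams()
    {
        for (int i = 0; i < paramNum; i++) {
            params[i].param->isGradFinished = false;
            params[i].flag = PARAM_STATE_NOT_READY;
        }
    }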
@@ -50,6 +50,31 @@ enum PARAM_STATE { PARAM_STATE_NOT_READY,
                    PARAM_STATE_COLLECTED,
                    PARAM_STATE_UPDATED };

+/* parameter keeper */
+class XParamKeeper
+{
+public:
+    /* the parameter */
+    XTensor * param;
+
+    /* the parameter state */
+    PARAM_STATE flag;
+
+    /* a mutex for locking and unlocking the parameter */
+    MUTEX_HANDLE accessLock;
+
+    /* a mutex for the overall training */
+    MUTEX_HANDLE trainLock;
+
+public:
+    /* constructor */
+    XParamKeeper();
+
+    /* de-constructor */
+    ~XParamKeeper();
+};
+
 /* a model template for training */
 class XModel
 {
@@ -58,11 +83,11 @@ protected:
     MUTEX_HANDLE modelMutex;

 public:
-    /* the list of model parameters (pointers to the parameter tensor) */
-    TensorList params;
-
-    /* flags of the parameters */
-    PARAM_STATE * flags;
+    /* the list of model parameters */
+    XParamKeeper * params;
+
+    /* the number of parameters */
+    int paramNum;

 public:
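On the two new mutexes: this commit initializes and destroys accessLock and trainLock but never locks them, so the intended protocol is not visible here. Assuming the library pairs MUTEX_INIT/MUTEX_DELE with MUTEX_LOCK/MUTEX_UNLOCK macros (an assumption; only the former two appear in this diff), later code would presumably guard per-parameter access like:

    XParamKeeper & keeper = model->params[i];
    MUTEX_LOCK(keeper.accessLock);      /* hypothetical usage, not from this commit */
    keeper.flag = PARAM_STATE_COLLECTED;
    MUTEX_UNLOCK(keeper.accessLock);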
@@ -57,20 +57,17 @@ broadcast data for a parameter
 */
 void XWorkerBroadcast::BroadcastDataSingle(XModel * source, XList * targetList, int pid)
 {
-    CheckNTErrors(source->flags[pid] == PARAM_STATE_UPDATED,
+    CheckNTErrors(source->params[pid].flag == PARAM_STATE_UPDATED,
                   "The parameter is not ready for broadcasting");

-    TensorList & sp = source->params;
-
     for (int i = 0; i < targetList->count; i++) {
         XModel * target = (XModel*)targetList->GetItem(i);
-        TensorList & tp = target->params;

         /* data transmit */
-        BroadcastP2P(sp.GetItem(pid), tp.GetItem(pid));
+        BroadcastP2P(source->params[pid].param, target->params[pid].param);

         /* update the flag */
-        target->flags[pid] = PARAM_STATE_UPDATED;
+        target->params[pid].flag = PARAM_STATE_UPDATED;
     }
 }
@@ -82,21 +79,20 @@ broadcast data for a model
 */
 void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long sleepTime)
 {
-    TensorList & sp = source->params;
     int finished = 0;
-    int * finishedFlag = new int[sp.count];
-    memset(finishedFlag, 0, sizeof(int) * sp.count);
+    int * finishedFlag = new int[source->paramNum];
+    memset(finishedFlag, 0, sizeof(int) * source->paramNum);

     /* check */
     for (int i = 0; i < targetList->count; i++) {
-        TensorList & tp = ((XModel*)targetList->GetItem(i))->params;
-        CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
+        XModel * target = (XModel*)targetList->GetItem(i);
+        CheckNTErrors(source->paramNum == target->paramNum, "Incompatible models!");
     }

     /* the major body of broadcasting */
     while (1) {
-        for (int i = 0; i < sp.count; i++) {
-            if (source->flags[i] == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {
+        for (int i = 0; i < source->paramNum; i++) {
+            if (source->params[i].flag == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {

                 /* broadcasting */
                 BroadcastDataSingle(source, targetList, i);
@@ -107,7 +103,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long sleepTime)
             }
         }

-        if (finished == sp.count * targetList->count)
+        if (finished == source->paramNum * targetList->count)
             break;

         XSleep(sleepTime);
@@ -186,7 +182,7 @@ bool XWorkerBroadcast::AddJobBroadcastSingle(XModel * source, XList * targetList
 {
     CheckNTErrors(source != NULL, "no input source tensor!");
     CheckNTErrors(targetList != NULL, "no input target tensor list!");
-    CheckNTErrors(pid >= 0 && pid < source->params.count, "illegal parameter index!");
+    CheckNTErrors(pid >= 0 && pid < source->paramNum, "illegal parameter index!");

     XList args;
     args.Add(this);
@@ -67,26 +67,25 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
                                    XOptimizer * optimizer, XWorkerUpdate * updater,
                                    XWorkerBroadcast * broadcaster, long sleepTime)
 {
-    TensorList & tp = server->params;
     int finished = 0;

-    for (int j = 0; j < tp.count; j++)
-        server->flags[j] = PARAM_STATE_NOT_READY;
+    for (int j = 0; j < server->paramNum; j++)
+        server->params[j].flag = PARAM_STATE_NOT_READY;

     /* check */
     for (int i = 0; i < memberAll->count; i++) {
-        TensorList & sp = ((XModel*)memberAll->GetItem(i))->params;
-        CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
+        XModel * source = (XModel*)memberAll->GetItem(i);
+        CheckNTErrors(source->paramNum == server->paramNum, "Incompatible models!");
     }

     for (int i = 0; i < memberActive->count; i++) {
-        TensorList & sp = ((XModel*)memberActive->GetItem(i))->params;
-        CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
+        XModel * source = (XModel*)memberActive->GetItem(i);
+        CheckNTErrors(source->paramNum == server->paramNum, "Incompatible models!");
     }

     /* count how many member models have been collected for each parameter */
-    int * finishedCount = new int[tp.count];
-    memset(finishedCount, 0, sizeof(int) * tp.count);
+    int * finishedCount = new int[server->paramNum];
+    memset(finishedCount, 0, sizeof(int) * server->paramNum);

     /* This is a simple implementation of the wait-and-collect process. But
        there is a risk that some models are not available, that is, the
@@ -94,26 +93,29 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
        to break after waiting for a short time. */
     while (1) {
         if (collectMode == DATA_COLLECT_P2P) {
-            for (int j = 0; j < tp.count; j++) {
+            for (int j = 0; j < server->paramNum; j++) {
+
+                XParamKeeper &paramServer = server->params[j];

-                /* tp[j]->isGradFinished is true only if the model finishes the computation
-                   (in another process) */
-                if (server->flags[j] != PARAM_STATE_NOT_READY || !tp[j]->isGradFinished)
+                /* paramServer.param->isGradFinished is true only if the model finishes the computation
+                   (in another process) */
+                if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
                     continue;

                 /* check if all the models (or part of them) are ready */
                 for (int i = 0; i < memberActive->count; i++) {
                     XModel * source = (XModel*)memberActive->GetItem(i);
-                    TensorList & sp = source->params;
+                    XParamKeeper &paramSource = source->params[j];

-                    /* sp[j]->isGradFinished is true only if the model finishes the computation
-                       (in another process) */
-                    if (source->flags[j] == PARAM_STATE_NOT_READY && sp[j]->isGradFinished) {
+                    /* paramSource.param->isGradFinished is true only if the model finishes the computation
+                       (in another process) */
+                    if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {

                         /* data transmit */
-                        CollectP2P(sp[j]->grad, tp[j]->grad);
+                        CollectP2P(paramSource.param->grad, paramServer.param->grad);

                         /* reset the flag */
-                        source->flags[j] = PARAM_STATE_COLLECTED;
+                        paramSource.flag = PARAM_STATE_COLLECTED;

                         finished++;
                         finishedCount[j]++;
@@ -121,7 +123,7 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
                        broadcast the new parameters to member models
                        (in another thread) */
                     if (finishedCount[j] == memberActive->count) {
-                        server->flags[j] = PARAM_STATE_COLLECTED;
+                        paramServer.flag = PARAM_STATE_COLLECTED;
                         if(updater != NULL)
                             updater->AddJobUpdateSingle(server, memberAll, j, optimizer, broadcaster);
                     }
@@ -133,31 +135,32 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
             }
         }
         else if (collectMode == DATA_COLLECT_REDUCESUM) {
-            for (int j = 0; j < tp.count; j++) {
+            for (int j = 0; j < server->paramNum; j++) {

                 bool ready = true;
+                XParamKeeper &paramServer = server->params[j];

-                /* tp[j]->isGradFinished is true only if the model finishes the computation
-                   (in another process) */
-                if (server->flags[j] != PARAM_STATE_NOT_READY || !tp[j]->isGradFinished)
+                /* paramServer.param->isGradFinished is true only if the model finishes the computation
+                   (in another process) */
+                if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
                     continue;

                 /* check if all the models (or part of them) are ready */
                 for (int i = 0; i < memberActive->count; i++) {
                     XModel * source = (XModel*)memberActive->GetItem(i);
-                    TensorList & sp = source->params;
+                    XParamKeeper &paramSource = source->params[j];

-                    /* sp[j]->isGradFinished is true only if the model finishes the computation
-                       (in another process) */
-                    if (source->flags[j] == PARAM_STATE_COLLECTED ||
-                        source->flags[j] == PARAM_STATE_UPDATED ||
-                        !sp[j]->isGradFinished)
+                    /* paramSource.param->isGradFinished is true only if the model finishes the computation
+                       (in another process) */
+                    if (paramSource.flag == PARAM_STATE_COLLECTED ||
+                        paramSource.flag == PARAM_STATE_UPDATED ||
+                        !paramSource.param->isGradFinished)
                     {
                         ready = false;
                         break;
                     }
-                    else if (source->flags[j] == PARAM_STATE_NOT_READY) {
-                        source->flags[j] = PARAM_STATE_READY;
+                    else if (paramSource.flag == PARAM_STATE_NOT_READY) {
+                        paramSource.flag = PARAM_STATE_READY;
                     }
                 }
@@ -166,20 +169,19 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
                 for (int i = 0; i < memberActive->count; i++) {
                     XModel * source = (XModel*)memberActive->GetItem(i);
-                    TensorList & sp = source->params;
-                    tensorList.Add(sp.GetItem(j)->grad);
+                    tensorList.Add(source->params[j].param->grad);
                 }

                 /* data transmit */
-                CollectReduceSum(&tensorList, tp.GetItem(j)->grad);
+                CollectReduceSum(&tensorList, server->params[j].param->grad);

                 /* reset the flags */
                 for (int i = 0; i < memberActive->count; i++) {
                     XModel * source = (XModel*)memberActive->GetItem(i);
-                    source->flags[j] = PARAM_STATE_COLLECTED;
+                    source->params[j].flag = PARAM_STATE_COLLECTED;
                 }

-                server->flags[j] = PARAM_STATE_COLLECTED;
+                server->params[j].flag = PARAM_STATE_COLLECTED;
                 finished += memberActive->count;

                 /* we call model update (in another thread) and then
@@ -194,7 +196,7 @@ void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
         }

         /* the collection finishes if all data tensors are processed */
-        if (finished == tp.count * memberActive->count)
+        if (finished == server->paramNum * memberActive->count)
             break;

         XSleep(sleepTime);
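Taken together, the collect/update/broadcast workers above drive every parameter through one fixed state sequence per training step; the states come from the PARAM_STATE enum shown earlier in this diff:

    /* per-parameter lifecycle, as implemented by the workers above */
    PARAM_STATE_NOT_READY    /* set by XModel::InitForRun / RefreshMe            */
    PARAM_STATE_READY        /* REDUCESUM path: a member's gradient is finished  */
    PARAM_STATE_COLLECTED    /* gradient transmitted to the server model         */
    PARAM_STATE_UPDATED      /* optimizer applied; result broadcast to members   */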
@@ -333,17 +335,6 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
     CheckNTErrors(sourceList != NULL, "no input source model list!");
     CheckNTErrors(target != NULL, "no input target model!");

-    /*XList args;
-    args.Add(this);
-    args.AddInt(sourceList->count);
-    args.AddList(sourceList);
-    args.Add(target);
-
-    if (isInstantRun)
-        XWorkerCollect::Collect(&args);
-    else
-        queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);*/

     XList args;
     args.Add(this);
     args.AddInt(sourceList->count);
@@ -63,12 +63,10 @@ update a parameter of a model
 void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
                                     XOptimizer * optimizer, XWorkerBroadcast * broadcaster)
 {
-    TensorList & params = server->params;
-    PARAM_STATE * flags = server->flags;
-
-    CheckNTErrors(flags[pid] == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");
+    CheckNTErrors(server->params[pid].flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");

-    XTensor * param = params.GetItem(pid);
+    XTensor * param = server->params[pid].param;
     XTensor * grad = param->grad;

     CheckNTErrors(grad != NULL, "No gradient!");
@@ -77,7 +75,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, XList * members, int pid,
     optimizer->UpdateParam(param, grad, pid);

     /* set the flag */
-    flags[pid] = PARAM_STATE_UPDATED;
+    server->params[pid].flag = PARAM_STATE_UPDATED;

     /* broadcast the new parameter to other models (in another worker/thread) */
     broadcaster->AddJobBroadcastSingle(server, members, pid);
@@ -92,15 +90,13 @@ update the model
 void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime)
 {
     int finished = 0;
-    TensorList & params = model->params;
-    PARAM_STATE * flags = model->flags;

     optimizer->Prepare(model);

     while (1) {
-        for (int i = 0; i < params.count; i++) {
-            if (flags[i] == PARAM_STATE_COLLECTED) {
-                XTensor * param = params.GetItem(i);
+        for (int i = 0; i < model->paramNum; i++) {
+            if (model->params[i].flag == PARAM_STATE_COLLECTED) {
+                XTensor * param = model->params[i].param;
                 XTensor * grad = param->grad;

                 CheckNTErrors(grad != NULL, "No gradient!");
@@ -109,12 +105,12 @@ void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime)
                 optimizer->UpdateParam(param, grad, i);

                 /* set the flag */
-                flags[i] = PARAM_STATE_UPDATED;
+                model->params[i].flag = PARAM_STATE_UPDATED;
                 finished++;
             }
         }

-        if (finished == params.count)
+        if (finished == model->paramNum)
             break;

         XSleep(sleepTime);
@@ -182,7 +178,7 @@ bool XWorkerUpdate::AddJobUpdateSingle(XModel * model, XList * members, int pid,
     CheckNTErrors(members != NULL, "No member model list!");
     CheckNTErrors(optimizer != NULL, "No optimizer!");
     CheckNTErrors(broadcaster != NULL, "No broadcaster!");
-    CheckNTErrors(pid >= 0 && pid < model->params.count, "Illegal parameter index!");
+    CheckNTErrors(pid >= 0 && pid < model->paramNum, "Illegal parameter index!");

     XList args;
     args.Add(this);