Commit ee577b38 by xiaotong

rename XParamKeeper as XTensorKeeper

parent 69478776
...@@ -96,7 +96,7 @@ void XLeader::SetServerModel(XConfig * config, XModel * model, XList * memberMod ...@@ -96,7 +96,7 @@ void XLeader::SetServerModel(XConfig * config, XModel * model, XList * memberMod
{ {
serverModel.Clear(); serverModel.Clear();
for (int i = 0; i < model->paramNum; i++) { for (int i = 0; i < model->paramNum; i++) {
XTensor * param = model->params[i].param; XTensor * param = model->params[i].tensor;
serverModel.AddParam(param); serverModel.AddParam(param);
} }
...@@ -151,7 +151,7 @@ void XLeader::InitForRun() ...@@ -151,7 +151,7 @@ void XLeader::InitForRun()
void XLeader::ResetParamGrad() void XLeader::ResetParamGrad()
{ {
for (int i = 0; i < serverModel.paramNum; i++) { for (int i = 0; i < serverModel.paramNum; i++) {
XTensor* param = serverModel.params[i].param; XTensor* param = serverModel.params[i].tensor;
if (param->grad != NULL) { if (param->grad != NULL) {
param->grad->SetZeroAll(); param->grad->SetZeroAll();
} }
...@@ -161,7 +161,7 @@ void XLeader::ResetParamGrad() ...@@ -161,7 +161,7 @@ void XLeader::ResetParamGrad()
XWorkerJob * worker = (XWorkerJob*)jworkers[j]; XWorkerJob * worker = (XWorkerJob*)jworkers[j];
XModel * model = worker->GetModel(); XModel * model = worker->GetModel();
for (int i = 0; i < model->paramNum; i++) { for (int i = 0; i < model->paramNum; i++) {
XTensor* param = model->params[i].param; XTensor* param = model->params[i].tensor;
if (param->grad != NULL) { if (param->grad != NULL) {
param->grad->SetZeroAll(); param->grad->SetZeroAll();
} }
...@@ -517,25 +517,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac ...@@ -517,25 +517,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
while (1) { while (1) {
for (int j = 0; j < serverModel.paramNum; j++) { for (int j = 0; j < serverModel.paramNum; j++) {
XParamKeeper &paramServer = serverModel.params[j]; XTensorKeeper &paramServer = serverModel.params[j];
/* isGradFinished is true only if the model finishes the computation /* isGradFinished is true only if the model finishes the computation
(in another process) */ (in another process) */
if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished) if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.tensor->isGradFinished)
continue; continue;
/* check if all the models (or part of them) are ready */ /* check if all the models (or part of them) are ready */
for (int i = 0; i < members.count; i++) { for (int i = 0; i < members.count; i++) {
XModel * source = (XModel*)members.GetItem(i); XModel * source = (XModel*)members.GetItem(i);
XParamKeeper &paramSource = source->params[j]; XTensorKeeper &paramSource = source->params[j];
/* isGradFinished is true only if the model finishes the computation /* isGradFinished is true only if the model finishes the computation
(in another process) */ (in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) { if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j); XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */ /* data transmit */
collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad); collecter->AddJobCollectDataP2P(jobQueue, paramSource.tensor->grad, paramServer.tensor->grad);
collecter->AddJobEnqueueFinished(jobQueue); collecter->AddJobEnqueueFinished(jobQueue);
/* reset the flag */ /* reset the flag */
......
...@@ -35,9 +35,9 @@ namespace nts { ...@@ -35,9 +35,9 @@ namespace nts {
/* constructor */ /* constructor */
XParamKeeper::XParamKeeper() XTensorKeeper::XTensorKeeper()
{ {
param = NULL; tensor = NULL;
flag = PARAM_STATE_NOT_READY; flag = PARAM_STATE_NOT_READY;
trainFlag = PARAM_STATE_NOT_READY; trainFlag = PARAM_STATE_NOT_READY;
MUTEX_INIT(accessLock); MUTEX_INIT(accessLock);
...@@ -45,7 +45,7 @@ XParamKeeper::XParamKeeper() ...@@ -45,7 +45,7 @@ XParamKeeper::XParamKeeper()
} }
/* constructor */ /* constructor */
XParamKeeper::~XParamKeeper() XTensorKeeper::~XTensorKeeper()
{ {
MUTEX_DELE(accessLock); MUTEX_DELE(accessLock);
MUTEX_DELE(trainLock); MUTEX_DELE(trainLock);
...@@ -124,14 +124,14 @@ void XModel::AddParam(XTensor* param) ...@@ -124,14 +124,14 @@ void XModel::AddParam(XTensor* param)
{ {
param->SetVarFlag(); param->SetVarFlag();
XParamKeeper * newParams = new XParamKeeper[paramNum + 1]; XTensorKeeper * newParams = new XTensorKeeper[paramNum + 1];
for (int i = 0; i < paramNum; i++) { for (int i = 0; i < paramNum; i++) {
newParams[i].param = params[i].param; newParams[i].tensor = params[i].tensor;
newParams[i].flag = params[i].flag; newParams[i].flag = params[i].flag;
} }
newParams[paramNum].param = param; newParams[paramNum].tensor = param;
newParams[paramNum].flag = PARAM_STATE_NOT_READY; newParams[paramNum].flag = PARAM_STATE_NOT_READY;
delete[] params; delete[] params;
...@@ -143,7 +143,7 @@ void XModel::AddParam(XTensor* param) ...@@ -143,7 +143,7 @@ void XModel::AddParam(XTensor* param)
bool XModel::CheckParam() bool XModel::CheckParam()
{ {
for (int i = 0; i < paramNum; i++) { for (int i = 0; i < paramNum; i++) {
XTensor * param = params[i].param; XTensor * param = params[i].tensor;
if (!param->isGrad) if (!param->isGrad)
return false; return false;
} }
...@@ -191,7 +191,7 @@ void XModel::WaitForUnlockedParams() ...@@ -191,7 +191,7 @@ void XModel::WaitForUnlockedParams()
void XModel::RefreshMe() void XModel::RefreshMe()
{ {
for (int i = 0; i < paramNum; i++) { for (int i = 0; i < paramNum; i++) {
params[i].param->isGradFinished = false; params[i].tensor->isGradFinished = false;
params[i].flag = PARAM_STATE_NOT_READY; params[i].flag = PARAM_STATE_NOT_READY;
params[i].trainFlag = PARAM_STATE_NOT_READY; params[i].trainFlag = PARAM_STATE_NOT_READY;
} }
......
...@@ -50,12 +50,12 @@ enum PARAM_STATE { PARAM_STATE_NOT_READY, ...@@ -50,12 +50,12 @@ enum PARAM_STATE { PARAM_STATE_NOT_READY,
PARAM_STATE_COLLECTED, PARAM_STATE_COLLECTED,
PARAM_STATE_UPDATED }; PARAM_STATE_UPDATED };
/* parameter keeper */ /* tensor keeper */
class XParamKeeper class XTensorKeeper
{ {
public: public:
/* the parameter */ /* the parameter */
XTensor * param; XTensor * tensor;
/* the parameter state */ /* the parameter state */
PARAM_STATE flag; PARAM_STATE flag;
...@@ -73,11 +73,10 @@ public: ...@@ -73,11 +73,10 @@ public:
public: public:
/* constructor */ /* constructor */
XParamKeeper(); XTensorKeeper();
/* constructor */ /* constructor */
~XParamKeeper(); ~XTensorKeeper();
}; };
/* a model template for training */ /* a model template for training */
...@@ -89,7 +88,7 @@ protected: ...@@ -89,7 +88,7 @@ protected:
public: public:
/* the list of model parameters */ /* the list of model parameters */
XParamKeeper * params; XTensorKeeper * params;
/* parameter number */ /* parameter number */
int paramNum; int paramNum;
......
...@@ -64,7 +64,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int pi ...@@ -64,7 +64,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int pi
XModel * target = (XModel*)targetList->GetItem(i); XModel * target = (XModel*)targetList->GetItem(i);
/* data transmit */ /* data transmit */
BroadcastP2P(source->params[pid].param, target->params[pid].param); BroadcastP2P(source->params[pid].tensor, target->params[pid].tensor);
/* update the flag */ /* update the flag */
target->params[pid].flag = PARAM_STATE_UPDATED; target->params[pid].flag = PARAM_STATE_UPDATED;
......
...@@ -64,7 +64,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, int pid, ...@@ -64,7 +64,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, int pid,
CheckNTErrors(server->params[pid].flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!"); CheckNTErrors(server->params[pid].flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");
XTensor * param = server->params[pid].param; XTensor * param = server->params[pid].tensor;
XTensor * grad = param->grad; XTensor * grad = param->grad;
CheckNTErrors(grad != NULL, "No gradient!"); CheckNTErrors(grad != NULL, "No gradient!");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论