Commit ee577b38 by xiaotong

rename XParamKeeper as XTensorKeeper

parent 69478776
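
The commit is a mechanical rename: the per-parameter keeper class XParamKeeper becomes XTensorKeeper, and its XTensor* member param becomes tensor; the state flags and the surrounding logic are left as they were. The sketch below compiles on its own and shows the shape of the renamed keeper as it can be reconstructed from the hunks in this diff (Tensor is a stand-in type, members are simplified):

    /* stand-in for XTensor; only the fields this diff touches are modeled */
    struct Tensor {
        bool     isGradFinished = false;   /* set when backward finishes */
        Tensor * grad           = nullptr; /* gradient buffer */
    };

    /* mirrors PARAM_STATE in the diff */
    enum State { NOT_READY, COLLECTED, UPDATED };

    /* the renamed keeper: one tensor plus its synchronization state */
    struct TensorKeeper {
        Tensor * tensor    = nullptr;      /* was named "param" before this commit */
        State    flag      = NOT_READY;
        State    trainFlag = NOT_READY;
    };

    int main() {
        Tensor t;
        TensorKeeper k;
        k.tensor = &t;                     /* call sites now read keeper.tensor, not keeper.param */
        return 0;
    }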
@@ -96,7 +96,7 @@ void XLeader::SetServerModel(XConfig * config, XModel * model, XList * memberMod
 {
     serverModel.Clear();
     for (int i = 0; i < model->paramNum; i++) {
-        XTensor * param = model->params[i].param;
+        XTensor * param = model->params[i].tensor;
         serverModel.AddParam(param);
     }
@@ -151,7 +151,7 @@ void XLeader::InitForRun()
 void XLeader::ResetParamGrad()
 {
     for (int i = 0; i < serverModel.paramNum; i++) {
-        XTensor* param = serverModel.params[i].param;
+        XTensor* param = serverModel.params[i].tensor;
         if (param->grad != NULL) {
             param->grad->SetZeroAll();
         }
@@ -161,7 +161,7 @@ void XLeader::ResetParamGrad()
         XWorkerJob * worker = (XWorkerJob*)jworkers[j];
         XModel * model = worker->GetModel();
         for (int i = 0; i < model->paramNum; i++) {
-            XTensor* param = model->params[i].param;
+            XTensor* param = model->params[i].tensor;
             if (param->grad != NULL) {
                 param->grad->SetZeroAll();
             }
@@ -517,25 +517,25 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
     while (1) {
         for (int j = 0; j < serverModel.paramNum; j++) {
-            XParamKeeper &paramServer = serverModel.params[j];
+            XTensorKeeper &paramServer = serverModel.params[j];

             /* isGradFinished is true only if the model finishes the computation
                (in another process) */
-            if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.param->isGradFinished)
+            if (paramServer.flag != PARAM_STATE_NOT_READY || !paramServer.tensor->isGradFinished)
                 continue;

             /* check if all the models (or part of them) are ready */
             for (int i = 0; i < members.count; i++) {
                 XModel * source = (XModel*)members.GetItem(i);
-                XParamKeeper &paramSource = source->params[j];
+                XTensorKeeper &paramSource = source->params[j];

                 /* isGradFinished is true only if the model finishes the computation
                    (in another process) */
-                if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
+                if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.tensor->isGradFinished) {
                     XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);

                     /* data transmit */
-                    collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad);
+                    collecter->AddJobCollectDataP2P(jobQueue, paramSource.tensor->grad, paramServer.tensor->grad);
                     collecter->AddJobEnqueueFinished(jobQueue);

                     /* reset the flag */
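The RunUpdate hunk above carries most of the rename's weight: the leader polls every parameter slot, and only when a member model's keeper is still PARAM_STATE_NOT_READY while its tensor reports isGradFinished does it enqueue a P2P job that copies the gradient into the server copy. The sketch below reproduces just that gating pattern with stand-in types (Keeper, collect); it is an illustration of the logic, not NiuTensor's API:

    #include <cstdio>
    #include <vector>

    /* mirrors PARAM_STATE in the diff */
    enum State { NOT_READY, COLLECTED, UPDATED };

    /* stand-in for XTensorKeeper: one slot tracking a tensor's gradient state */
    struct Keeper {
        bool  gradFinished;   /* plays the role of tensor->isGradFinished */
        State flag;           /* plays the role of the keeper's flag */
    };

    /* pretend P2P transfer: just report which slot would be collected */
    static void collect(int member, int slot) {
        std::printf("collect gradient of member %d, slot %d\n", member, slot);
    }

    int main() {
        /* two member models, three parameter slots each */
        std::vector<std::vector<Keeper>> members(2, std::vector<Keeper>(3, Keeper{false, NOT_READY}));
        members[0][1].gradFinished = true;   /* one gradient just finished */

        /* the gate from XLeader::RunUpdate: act only on slots that are still
           NOT_READY but whose backward computation has finished */
        for (int m = 0; m < (int)members.size(); m++) {
            for (int j = 0; j < (int)members[m].size(); j++) {
                Keeper &k = members[m][j];
                if (k.flag == NOT_READY && k.gradFinished) {
                    collect(m, j);
                    k.flag = COLLECTED;      /* mark the slot as collected */
                }
            }
        }
        return 0;
    }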
@@ -35,9 +35,9 @@ namespace nts {

 /* constructor */
-XParamKeeper::XParamKeeper()
+XTensorKeeper::XTensorKeeper()
 {
-    param = NULL;
+    tensor = NULL;
     flag = PARAM_STATE_NOT_READY;
     trainFlag = PARAM_STATE_NOT_READY;
     MUTEX_INIT(accessLock);
@@ -45,7 +45,7 @@ XParamKeeper::XParamKeeper()
 }

 /* constructor */
-XParamKeeper::~XParamKeeper()
+XTensorKeeper::~XTensorKeeper()
 {
     MUTEX_DELE(accessLock);
     MUTEX_DELE(trainLock);
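The constructor/destructor hunks only change class names; the keeper still creates two locks with MUTEX_INIT and releases them with MUTEX_DELE (NiuTensor macros). In plain standard C++ the same ownership is usually expressed with std::mutex members whose lifetime the compiler manages. The sketch below shows that RAII equivalent under stated assumptions; what each lock actually protects is not spelled out in this diff, so the comments are deliberately non-committal:

    #include <mutex>

    enum State { NOT_READY, COLLECTED, UPDATED };
    struct Tensor { };

    /* a keeper whose two locks are managed by RAII rather than explicit
       MUTEX_INIT / MUTEX_DELE calls in the constructor and destructor */
    class KeeperSketch {
    public:
        Tensor *   tensor    = nullptr;
        State      flag      = NOT_READY;
        State      trainFlag = NOT_READY;
        std::mutex accessLock;   /* counterpart of the diff's accessLock */
        std::mutex trainLock;    /* counterpart of the diff's trainLock */
    };

    int main() {
        KeeperSketch k;                               /* locks constructed here */
        {
            std::lock_guard<std::mutex> guard(k.accessLock);
            k.flag = COLLECTED;                       /* state change under the lock */
        }
        return 0;                                     /* locks destroyed with k */
    }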
@@ -124,14 +124,14 @@ void XModel::AddParam(XTensor* param)
 {
     param->SetVarFlag();

-    XParamKeeper * newParams = new XParamKeeper[paramNum + 1];
+    XTensorKeeper * newParams = new XTensorKeeper[paramNum + 1];

     for (int i = 0; i < paramNum; i++) {
-        newParams[i].param = params[i].param;
+        newParams[i].tensor = params[i].tensor;
         newParams[i].flag = params[i].flag;
     }

-    newParams[paramNum].param = param;
+    newParams[paramNum].tensor = param;
     newParams[paramNum].flag = PARAM_STATE_NOT_READY;

     delete[] params;
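The AddParam hunk keeps the existing grow-by-one storage scheme: allocate a keeper array of size paramNum + 1, copy the old tensor/flag pairs, append the new entry as NOT_READY, and free the old array. A compact, self-contained sketch of that reallocation pattern with a stand-in Keeper type (not the real XTensorKeeper):

    enum State { NOT_READY, COLLECTED, UPDATED };

    struct Tensor { };                    /* stand-in for XTensor */
    struct Keeper {                       /* stand-in for XTensorKeeper */
        Tensor * tensor = nullptr;
        State    flag   = NOT_READY;
    };

    /* append one tensor to a heap-allocated keeper array, the way
       XModel::AddParam grows its params array by one slot */
    static void addParam(Keeper *&params, int &paramNum, Tensor *t) {
        Keeper * newParams = new Keeper[paramNum + 1];
        for (int i = 0; i < paramNum; i++) {
            newParams[i].tensor = params[i].tensor;   /* carry over the tensor pointer */
            newParams[i].flag   = params[i].flag;     /* and its state flag */
        }
        newParams[paramNum].tensor = t;               /* the freshly added tensor */
        newParams[paramNum].flag   = NOT_READY;
        delete[] params;                              /* release the old array */
        params = newParams;
        paramNum++;
    }

    int main() {
        Keeper * params = nullptr;
        int paramNum = 0;
        Tensor a, b;
        addParam(params, paramNum, &a);
        addParam(params, paramNum, &b);
        delete[] params;
        return 0;
    }

For the handful of parameters a model registers, the O(n) copy per insertion is harmless; a std::vector would amortize the growth, but the commit leaves the storage scheme as it was.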
@@ -143,7 +143,7 @@ void XModel::AddParam(XTensor* param)
 bool XModel::CheckParam()
 {
     for (int i = 0; i < paramNum; i++) {
-        XTensor * param = params[i].param;
+        XTensor * param = params[i].tensor;
         if (!param->isGrad)
             return false;
     }
@@ -191,7 +191,7 @@ void XModel::WaitForUnlockedParams()
 void XModel::RefreshMe()
 {
     for (int i = 0; i < paramNum; i++) {
-        params[i].param->isGradFinished = false;
+        params[i].tensor->isGradFinished = false;
         params[i].flag = PARAM_STATE_NOT_READY;
         params[i].trainFlag = PARAM_STATE_NOT_READY;
     }
@@ -50,12 +50,12 @@ enum PARAM_STATE { PARAM_STATE_NOT_READY,
                    PARAM_STATE_COLLECTED,
                    PARAM_STATE_UPDATED };

-/* parameter keeper */
-class XParamKeeper
+/* tensor keeper */
+class XTensorKeeper
 {
 public:
     /* the parameter */
-    XTensor * param;
+    XTensor * tensor;

     /* the parameter state */
     PARAM_STATE flag;
@@ -73,11 +73,10 @@ public:
 public:
     /* constructor */
-    XParamKeeper();
+    XTensorKeeper();

     /* constructor */
-    ~XParamKeeper();
+    ~XTensorKeeper();
 };

 /* a model template for training */
@@ -89,7 +88,7 @@ protected:
 public:
     /* the list of model parameters */
-    XParamKeeper * params;
+    XTensorKeeper * params;

     /* parameter number */
     int paramNum;
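Taken together, the hunks show a small per-slot lifecycle driven by PARAM_STATE: judging from this diff, a keeper starts at PARAM_STATE_NOT_READY, is expected to be PARAM_STATE_COLLECTED before XWorkerUpdate runs, is marked PARAM_STATE_UPDATED once XWorkerBroadcast has pushed the new value back, and XModel::RefreshMe returns it to NOT_READY before the next step. A compact sketch of that lifecycle with stand-in names, not NiuTensor's API:

    #include <cassert>

    /* mirrors PARAM_STATE in the diff */
    enum State { NOT_READY, COLLECTED, UPDATED };

    struct Keeper {
        bool  gradFinished = false;   /* stands in for tensor->isGradFinished */
        State flag         = NOT_READY;
    };

    int main() {
        Keeper k;

        /* backward pass finishes elsewhere */
        k.gradFinished = true;

        /* collection: gradient gathered to the server copy */
        if (k.flag == NOT_READY && k.gradFinished)
            k.flag = COLLECTED;

        /* update + broadcast: new value pushed back to the member model */
        if (k.flag == COLLECTED)
            k.flag = UPDATED;

        /* refresh before the next step, as XModel::RefreshMe does */
        k.gradFinished = false;
        k.flag = NOT_READY;

        assert(k.flag == NOT_READY);
        return 0;
    }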
@@ -64,7 +64,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int pi
         XModel * target = (XModel*)targetList->GetItem(i);

         /* data transmit */
-        BroadcastP2P(source->params[pid].param, target->params[pid].param);
+        BroadcastP2P(source->params[pid].tensor, target->params[pid].tensor);

         /* update the flag */
         target->params[pid].flag = PARAM_STATE_UPDATED;
@@ -64,7 +64,7 @@ void XWorkerUpdate::UpdateParameter(XModel * server, int pid,
     CheckNTErrors(server->params[pid].flag == PARAM_STATE_COLLECTED, "The state of the parameter is wrong!");

-    XTensor * param = server->params[pid].param;
+    XTensor * param = server->params[pid].tensor;
     XTensor * grad = param->grad;

     CheckNTErrors(grad != NULL, "No gradient!");
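The UpdateParameter hunk only fetches the parameter tensor and its gradient; the arithmetic of the step itself is delegated to the optimizer and does not appear in this diff. Purely for orientation, a generic SGD-style step over a parameter/gradient pair looks like the sketch below; the function and names are illustrative assumptions, not XOptimizer's code:

    #include <cstddef>

    /* generic SGD-style step: w <- w - lr * g, for illustration only */
    static void sgdStep(float * param, const float * grad, std::size_t n, float lr) {
        for (std::size_t i = 0; i < n; i++)
            param[i] -= lr * grad[i];
    }

    int main() {
        float w[3] = {0.5f, -1.0f, 2.0f};
        float g[3] = {0.1f,  0.2f, -0.3f};
        sgdStep(w, g, 3, 0.01f);   /* w becomes {0.499, -1.002, 2.003} */
        return 0;
    }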