Commit f866d06d by xiaotong

updates of XTensorKeeper

parent f41939be
......@@ -174,6 +174,17 @@ void XLeader::ResetParamGrad()
}
/*
prepare for running
>> config - the configuration
>> model - the model that we run
*/
void XLeader::MakeAll(XConfig * config, XModel * model)
{
SetServerModel(config, model);
ResetParamGrad();
}
/*
wait for finished states (i.e., all workers finish their jobs)
>> activeJobWorkers - indicates whether each job worker is active
>> isToUpdate - indicates whether the model is updated
......
......@@ -139,6 +139,9 @@ public:
/* set grad = 0 */
void ResetParamGrad();
/* prepare for running */
void MakeAll(XConfig * config, XModel * model);
/* wait for finished states (i.e., all workers finish their jobs) */
void WaitForFinishing(const int * activeJobWorkers, const int isToUpdate);
......
......@@ -33,24 +33,6 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
/* constructor */
XTensorKeeper::XTensorKeeper()
{
tensor = NULL;
flag = PARAM_STATE_NOT_READY;
trainFlag = PARAM_STATE_NOT_READY;
MUTEX_INIT(accessLock);
MUTEX_INIT(trainLock);
}
/* constructor */
XTensorKeeper::~XTensorKeeper()
{
MUTEX_DELE(accessLock);
MUTEX_DELE(trainLock);
}
/* constructor */
XModel::XModel()
{
......
......@@ -32,53 +32,13 @@
#ifndef __XMODEL_H__
#define __XMODEL_H__
#include "XTensorKeeper.h"
#include "../network/XNet.h"
#include "../tensor/XQueue.h"
#include "../tensor/XList.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
parameter state
1) not ready
2) ready
3) the parameter has been collected from other models
4) the updated parameter
*/
enum PARAM_STATE { PARAM_STATE_NOT_READY,
PARAM_STATE_READY,
PARAM_STATE_COLLECTED,
PARAM_STATE_UPDATED };
/* tensor keeper */
class XTensorKeeper
{
public:
/* the parameter */
XTensor * tensor;
/* the parameter state */
PARAM_STATE flag;
/* the state of the entire training process
(choosing from PARAM_STATE_NOT_READY and
PARAM_STATE_UPDATED */
PARAM_STATE trainFlag;
/* a mutex for locking and unlocking the parameter */
MUTEX_HANDLE accessLock;
/* a mutex of the overall training */
MUTEX_HANDLE trainLock;
public:
/* constructor */
XTensorKeeper();
/* constructor */
~XTensorKeeper();
};
/* a model template for training */
class XModel
{
......
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* We define a class that keeps a tensor (could be either a parameter or
* gradient).
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-25
*/
#include "XTensorKeeper.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XTensorKeeper::XTensorKeeper()
{
tensor = NULL;
flag = PARAM_STATE_NOT_READY;
trainFlag = PARAM_STATE_NOT_READY;
MUTEX_INIT(accessLock);
MUTEX_INIT(trainLock);
}
/* constructor */
XTensorKeeper::~XTensorKeeper()
{
MUTEX_DELE(accessLock);
MUTEX_DELE(trainLock);
}
}
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* We define a class that keeps a tensor (could be either a parameter or
* gradient).
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-25
* I will take the first shot of the COVID-19 vaccine this afternoon.
*/
#ifndef __XTENSORKEEPER_H__
#define __XTENSORKEEPER_H__
#include "../network/XNet.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
parameter state
1) not ready
2) ready
3) the parameter has been collected from other models
4) the updated parameter
*/
enum PARAM_STATE {
PARAM_STATE_NOT_READY,
PARAM_STATE_READY,
PARAM_STATE_COLLECTED,
PARAM_STATE_UPDATED
};
/* tensor keeper */
class XTensorKeeper
{
public:
/* the parameter */
XTensor * tensor;
/* the parameter state */
PARAM_STATE flag;
/* the state of the entire training process
(choosing from PARAM_STATE_NOT_READY and
PARAM_STATE_UPDATED */
PARAM_STATE trainFlag;
/* a mutex for locking and unlocking the parameter */
MUTEX_HANDLE accessLock;
/* a mutex of the overall training */
MUTEX_HANDLE trainLock;
public:
/* constructor */
XTensorKeeper();
/* constructor */
~XTensorKeeper();
};
}
#endif
\ No newline at end of file
......@@ -120,7 +120,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
leader.AddJobBroadcastWorker();
leader.AddJobParamterWorker(model->paramNum);
//leader.SetInstantRun();
leader.SetServerModel(config, model);
leader.MakeAll(config, model);
leader.Start();
/* learning rate scheduler */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论