Commit 2736ad16 by xiaotong

bug fixes and a new class XConfig

parent 327fb820
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* this class keeps a batch of paramters.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-28
*/
#include "XConfig.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* this class defines a parameter keeper.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-28
* A new semester begins today.
*/
#ifndef __XCONFIG_H__
#define __XCONFIG_H__
#include "XGlobal.h"
#include "XUtility.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
#endif
\ No newline at end of file
...@@ -40,7 +40,7 @@ namespace nts { ...@@ -40,7 +40,7 @@ namespace nts {
XLeader::XLeader() XLeader::XLeader()
{ {
id = -1; id = -1;
dataLoader = NULL; dataDistributor = NULL;
} }
/* de-constructor */ /* de-constructor */
...@@ -60,11 +60,26 @@ int XLeader::GetID() ...@@ -60,11 +60,26 @@ int XLeader::GetID()
return id; return id;
} }
/*
set the communication mode
>> myMode - the mode
*/
void XLeader::SetMode(XLEADER_MODE myMode)
{
mode = myMode;
}
/* set the data loader */ /* set the data loader */
void XLeader::SetDataLoader(DataLoaderBase* myDataLoader) void XLeader::SetDataDistributor(DataDistributeBase* myDataDistributor)
{
CheckNTErrors(myDataDistributor != NULL,
"The input of XLeader::SetDataLoader should not be NULL!");
dataDistributor = myDataDistributor;
}
/* run the leader (this is the core process) */
void XLeader::Run()
{ {
CheckNTErrors(myDataLoader != NULL, "The input of XLeader::SetDataLoader should not be NULL!");
dataLoader = myDataLoader;
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -55,8 +55,11 @@ protected: ...@@ -55,8 +55,11 @@ protected:
/* id of the leader */ /* id of the leader */
int id; int id;
/* communication mode */
XLEADER_MODE mode;
/* data loader */ /* data loader */
DataLoaderBase* dataLoader; DataDistributeBase* dataDistributor;
public: public:
...@@ -72,8 +75,15 @@ public: ...@@ -72,8 +75,15 @@ public:
/* get id */ /* get id */
int GetID(); int GetID();
/* set the communication mode */
void SetMode(XLEADER_MODE myMode);
/* set the data loader */ /* set the data loader */
void SetDataLoader(DataLoaderBase * myDataLoader); void SetDataDistributor(DataDistributeBase * myDataLoader);
/* run the leader (this is the core process) */
virtual
void Run();
}; };
......
...@@ -35,45 +35,45 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -35,45 +35,45 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
*******************************/ *******************************/
/* constructor */ /* constructor */
DataLoaderBase::DataLoaderBase() DataDistributeBase::DataDistributeBase()
{ {
MUTEX_INIT(loadMutex); MUTEX_INIT(loadMutex);
} }
/* de-constructor */ /* de-constructor */
DataLoaderBase::~DataLoaderBase() DataDistributeBase::~DataDistributeBase()
{ {
MUTEX_DELE(loadMutex); MUTEX_DELE(loadMutex);
} }
/* open data file */ /* * start the job (e.g., open the file) */
bool DataLoaderBase::Open(XList * args) bool DataDistributeBase::Start(XList * args)
{ {
ShowNTErrors("DataLoaderBase::Open must be overloaded!"); ShowNTErrors("DataDistributeBase::Start must be overloaded!");
return true; return true;
} }
/* close data file */ /* end the job (e.g., close the file) */
bool DataLoaderBase::Close(XList * args) bool DataDistributeBase::End(XList * args)
{ {
ShowNTErrors("DataLoaderBase::Close must be overloaded!"); ShowNTErrors("DataDistributeBase::End must be overloaded!");
return true; return true;
} }
/* load a batch of samples */ /* get a batch of samples */
bool DataLoaderBase::LoadBatch(XList * args) bool DataDistributeBase::GetBatch(XList * args)
{ {
ShowNTErrors("DataLoaderBase::LoadBatch must be overloaded!"); ShowNTErrors("DataDistributeBase::GetBatch must be overloaded!");
return true; return true;
} }
/* load a batch of samples (for multi-threading) */ /* get a batch of samples (for multi-threading) */
bool DataLoaderBase::LoadBatchSafe(XList * args) bool DataDistributeBase::GetBatchSafe(XList * args)
{ {
bool r; bool r;
MUTEX_LOCK(loadMutex); MUTEX_LOCK(loadMutex);
r = LoadBatch(args); r = GetBatch(args);
MUTEX_UNLOCK(loadMutex); MUTEX_UNLOCK(loadMutex);
return r; return r;
......
...@@ -35,8 +35,16 @@ ...@@ -35,8 +35,16 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* data loader template */ /*
class DataLoaderBase data distributor template. It distribute batches of data to workers.
The use of data distributor follows:
Start() -> GetBatch() -> ... -> GetBatch() -> End()
In addition, GetBatch() should be thread-safe, and thus could be
called by different threads simultaneously.
*/
class DataDistributeBase
{ {
protected: protected:
/* mutex of batch loading */ /* mutex of batch loading */
...@@ -44,26 +52,26 @@ protected: ...@@ -44,26 +52,26 @@ protected:
public: public:
/* constructor */ /* constructor */
DataLoaderBase(); DataDistributeBase();
/* de-constructor */ /* de-constructor */
~DataLoaderBase(); ~DataDistributeBase();
/* open data file */ /* start the job (e.g., open the file) */
virtual virtual
bool Open(XList * args); bool Start(XList * args);
/* close data file */ /* end the job (e.g., close the file) */
virtual virtual
bool Close(XList * args); bool End(XList * args);
/* load a batch of samples */ /* get a batch of samples */
virtual virtual
bool LoadBatch(XList * args); bool GetBatch(XList * args);
protected: protected:
/* load a batch of samples (for multi-threading) */ /* get a batch of samples (for multi-threading) */
bool LoadBatchSafe(XList * args); bool GetBatchSafe(XList * args);
}; };
/* neural network template */ /* neural network template */
......
...@@ -57,34 +57,14 @@ bool XWorkerJob::AddJobRefresh(XModel * paramKeeper) ...@@ -57,34 +57,14 @@ bool XWorkerJob::AddJobRefresh(XModel * paramKeeper)
} }
/* /*
add a new job of input generation
>> inputGenerator - the input generator
>> rawData - the data to be processed (input of the generator)
>> inputs - the generated inputs (output of the generator)
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs)
{
CheckNTErrors(inputGenerator != NULL, "no input generator!");
XList args;
args.AddList(rawData);
args.AddList(inputs);
queue.EnqueueJob(inputGenerator, &args);
return true;
}
/*
add a new job of neural network forward and backward computation (with the input) add a new job of neural network forward and backward computation (with the input)
>> func - the function that runs the neural network >> func - the function that calls the run of the neural network
>> net - the neural network
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
>> net - the neural network
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net) bool XWorkerJob::AddJobNeuralNet(void * func, void * net, XList * inputs, XList * outputs)
{ {
CheckNTErrors(func != NULL, "no input function!"); CheckNTErrors(func != NULL, "no input function!");
CheckNTErrors(net != NULL, "no input neural network!"); CheckNTErrors(net != NULL, "no input neural network!");
......
...@@ -50,11 +50,8 @@ public: ...@@ -50,11 +50,8 @@ public:
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XModel * paramKeeper); bool AddJobRefresh(XModel * paramKeeper);
/* add a new job of input generation */
bool AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net); bool AddJobNeuralNet(void * func, void * net, XList * inputs, XList * outputs);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论