Commit 2736ad16 by xiaotong

bug fixes and a new class XConfig

parent 327fb820
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* this class keeps a batch of paramters.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-28
*/
#include "XConfig.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* this class defines a parameter keeper.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-28
* A new semester begins today.
*/
#ifndef __XCONFIG_H__
#define __XCONFIG_H__
#include "XGlobal.h"
#include "XUtility.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
#endif
\ No newline at end of file
......@@ -40,7 +40,7 @@ namespace nts {
XLeader::XLeader()
{
id = -1;
dataLoader = NULL;
dataDistributor = NULL;
}
/* de-constructor */
......@@ -60,11 +60,26 @@ int XLeader::GetID()
return id;
}
/*
set the communication mode
>> myMode - the mode
*/
void XLeader::SetMode(XLEADER_MODE myMode)
{
mode = myMode;
}
/* set the data loader */
void XLeader::SetDataLoader(DataLoaderBase* myDataLoader)
void XLeader::SetDataDistributor(DataDistributeBase* myDataDistributor)
{
CheckNTErrors(myDataDistributor != NULL,
"The input of XLeader::SetDataLoader should not be NULL!");
dataDistributor = myDataDistributor;
}
/* run the leader (this is the core process) */
void XLeader::Run()
{
CheckNTErrors(myDataLoader != NULL, "The input of XLeader::SetDataLoader should not be NULL!");
dataLoader = myDataLoader;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -55,8 +55,11 @@ protected:
/* id of the leader */
int id;
/* communication mode */
XLEADER_MODE mode;
/* data loader */
DataLoaderBase* dataLoader;
DataDistributeBase* dataDistributor;
public:
......@@ -72,8 +75,15 @@ public:
/* get id */
int GetID();
/* set the communication mode */
void SetMode(XLEADER_MODE myMode);
/* set the data loader */
void SetDataLoader(DataLoaderBase * myDataLoader);
void SetDataDistributor(DataDistributeBase * myDataLoader);
/* run the leader (this is the core process) */
virtual
void Run();
};
......
......@@ -35,45 +35,45 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
*******************************/
/* constructor */
DataLoaderBase::DataLoaderBase()
DataDistributeBase::DataDistributeBase()
{
MUTEX_INIT(loadMutex);
}
/* de-constructor */
DataLoaderBase::~DataLoaderBase()
DataDistributeBase::~DataDistributeBase()
{
MUTEX_DELE(loadMutex);
}
/* open data file */
bool DataLoaderBase::Open(XList * args)
/* * start the job (e.g., open the file) */
bool DataDistributeBase::Start(XList * args)
{
ShowNTErrors("DataLoaderBase::Open must be overloaded!");
ShowNTErrors("DataDistributeBase::Start must be overloaded!");
return true;
}
/* close data file */
bool DataLoaderBase::Close(XList * args)
/* end the job (e.g., close the file) */
bool DataDistributeBase::End(XList * args)
{
ShowNTErrors("DataLoaderBase::Close must be overloaded!");
ShowNTErrors("DataDistributeBase::End must be overloaded!");
return true;
}
/* load a batch of samples */
bool DataLoaderBase::LoadBatch(XList * args)
/* get a batch of samples */
bool DataDistributeBase::GetBatch(XList * args)
{
ShowNTErrors("DataLoaderBase::LoadBatch must be overloaded!");
ShowNTErrors("DataDistributeBase::GetBatch must be overloaded!");
return true;
}
/* load a batch of samples (for multi-threading) */
bool DataLoaderBase::LoadBatchSafe(XList * args)
/* get a batch of samples (for multi-threading) */
bool DataDistributeBase::GetBatchSafe(XList * args)
{
bool r;
MUTEX_LOCK(loadMutex);
r = LoadBatch(args);
r = GetBatch(args);
MUTEX_UNLOCK(loadMutex);
return r;
......
......@@ -35,8 +35,16 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* data loader template */
class DataLoaderBase
/*
data distributor template. It distribute batches of data to workers.
The use of data distributor follows:
Start() -> GetBatch() -> ... -> GetBatch() -> End()
In addition, GetBatch() should be thread-safe, and thus could be
called by different threads simultaneously.
*/
class DataDistributeBase
{
protected:
/* mutex of batch loading */
......@@ -44,26 +52,26 @@ protected:
public:
/* constructor */
DataLoaderBase();
DataDistributeBase();
/* de-constructor */
~DataLoaderBase();
~DataDistributeBase();
/* open data file */
/* start the job (e.g., open the file) */
virtual
bool Open(XList * args);
bool Start(XList * args);
/* close data file */
/* end the job (e.g., close the file) */
virtual
bool Close(XList * args);
bool End(XList * args);
/* load a batch of samples */
/* get a batch of samples */
virtual
bool LoadBatch(XList * args);
bool GetBatch(XList * args);
protected:
/* load a batch of samples (for multi-threading) */
bool LoadBatchSafe(XList * args);
/* get a batch of samples (for multi-threading) */
bool GetBatchSafe(XList * args);
};
/* neural network template */
......
......@@ -57,34 +57,14 @@ bool XWorkerJob::AddJobRefresh(XModel * paramKeeper)
}
/*
add a new job of input generation
>> inputGenerator - the input generator
>> rawData - the data to be processed (input of the generator)
>> inputs - the generated inputs (output of the generator)
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs)
{
CheckNTErrors(inputGenerator != NULL, "no input generator!");
XList args;
args.AddList(rawData);
args.AddList(inputs);
queue.EnqueueJob(inputGenerator, &args);
return true;
}
/*
add a new job of neural network forward and backward computation (with the input)
>> func - the function that runs the neural network
>> func - the function that calls the run of the neural network
>> net - the neural network
>> inputs - inputs of the neural network
>> outputs - outputs of the neural network
>> net - the neural network
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net)
bool XWorkerJob::AddJobNeuralNet(void * func, void * net, XList * inputs, XList * outputs)
{
CheckNTErrors(func != NULL, "no input function!");
CheckNTErrors(net != NULL, "no input neural network!");
......
......@@ -50,11 +50,8 @@ public:
/* add a new job of model refreshment */
bool AddJobRefresh(XModel * paramKeeper);
/* add a new job of input generation */
bool AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs);
/* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net);
bool AddJobNeuralNet(void * func, void * net, XList * inputs, XList * outputs);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论