Commit b2dcbed0 by xiaotong

new classes in trainer

parent b16b52f0
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This class maintains the parameters (and other stuff) for training. It
* could be used to manage the parameter copy and update in training. E.g.,
* one can use this class to keep the parameters on the server side, or
* treat it as an individual model on the worker side.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
*/
#include "XModel.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
/* constructor */
XModel::XModel()
{
}
/* de-constructor */
XModel::~XModel()
{
}
/* clear the model */
void XModel::Clear()
{
params.Clear();
}
/* reset the flag of parameters (the flag is used in data transfer) */
void XModel::RefreshMe()
{
for (int i = 0; i < params.count; i++) {
XTensor * param = params.GetItem(i);
param->isGradFinished = false;
}
}
/* wrapper of RefreshMe */
void XModel::Refresh(XList * args)
{
CheckNTErrors(args != NULL, "illegal arguments!");
CheckNTErrors(args->count == 1, "The number of arguments must be 1!");
XModel * model = (XModel*)args->GetItem(0);
model->RefreshMe();
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This class maintains the parameters (and other stuff) for training. It
* could be used to manage the parameter copy and update in training. E.g.,
* one can use this class to keep the parameters on the server side, or
* treat it as an individual model on the worker side.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
* I created more than one file today, hahaha
*/
#ifndef __XMODEL_H__
#define __XMODEL_H__
#include "../network/XNet.h"
#include "../tensor/XQueue.h"
#include "../tensor/XList.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* a model template for training */
class XModel
{
public:
/* the list of model parameters (pointers to the parameter tensor) */
TensorList params;
public:
/* constructor */
XModel();
/* de-constructor */
~XModel();
/* clear the model */
void Clear();
/* reset the flag of parameters (the flag is used in data transfer) */
void RefreshMe();
/* wrapper of RefreshMe */
static
void Refresh(XList * args);
};
}
#endif // __XMODEL_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
......
/* NiuTrans.Tensor - an open-source tensor library
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
......@@ -23,7 +24,7 @@
* Distributed training is supported.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-23
* I start coding in 2021 after one year since I typed last C code.
* I start coding in 2021 after one year since I typed last line of C code.
* BUT i was a GOOD tex writter in 2020 :)
*/
......@@ -32,6 +33,7 @@
#include "../network/XNet.h"
#include "../tensor/XQueue.h"
#include "XWorker.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The base class of worker. It maintains a job queue and offers utilities
* of controlling the working pipeline.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
*/
#include "XWorker.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
/* constructor */
XWorker::XWorker()
{
}
/* de-constructor */
XWorker::~XWorker()
{
}
/* set device id */
void XWorker::SetDeviceID(int myDevID)
{
devID = myDevID;
}
/* get device id */
int XWorker::GetDeviceID()
{
return devID;
}
/* set worker id */
void XWorker::SetID(int myID)
{
id = myID;
}
/* get worker id */
int XWorker::GetID()
{
return id;
}
/*
enqueue a new job
>> job - the job function
>> jobArgs - the arguments of the function
*/
void XWorker::AddJob(void * job, XList * jobArgs)
{
queue.EnqueueJob(job, jobArgs);
}
/* start the work */
void XWorker::Start()
{
queue.StopJobConsumer();
}
/* stop the work */
void XWorker::Stop()
{
queue.StopJobConsumer();
}
/* get the number of remaining jobs */
int XWorker::GetJobNum()
{
return queue.GetJobNum();
}
/* whether the job queue is empty? */
bool XWorker::IsEmpty()
{
return queue.IsEmpty();
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The base class of worker. It maintains a job queue and offers utilities
* of controlling the working pipeline.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
* People started to go back to the normal life after the Spring Festival.
* Traffic jams again.
*/
#ifndef __XWORKER_H__
#define __XWORKER_H__
#include "../tensor/XQueue.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* the worker class */
class XWorker
{
protected:
/* id of the device where we run the worker (we suppose that
the worker is insite. */
int devID;
/* id of the worker */
int id;
/* the queue */
XQueue queue;
public:
/* constructor */
XWorker();
/* de-constructor */
~XWorker();
/* set device id */
void SetDeviceID(int myDevID);
/* get device id */
int GetDeviceID();
/* set worker id */
void SetID(int myID);
/* get worker id */
int GetID();
/* enqueue a new job */
void AddJob(void * job, XList * jobArgs);
/* start the work */
void Start();
/* stop the work */
void Stop();
/* get the number of remaining jobs */
int GetJobNum();
/* whether the job queue is empty? */
bool IsEmpty();
};
}
#endif
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker of running the neural network.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
*/
#include "XWorkerJob.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XWorkerJob::XWorkerJob()
{
}
/* de-constructor */
XWorkerJob::~XWorkerJob()
{
}
/*
add a new job of model refreshment
>> paramKeeper - keeper of the model parameters
<< return - succeeded or not
*/
bool XWorkerJob::AddJobRefresh(XModel * paramKeeper)
{
CheckNTErrors(paramKeeper != NULL, "no parameter keeper!");
XList args(1);
args.Add(paramKeeper);
queue.EnqueueJob(&XModel::Refresh, &args);
return true;
}
/*
add a new job of input generation
>> inputGenerator - the input generator
>> rawData - the data to be processed (input of the generator)
>> inputs - the generated inputs (output of the generator)
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs)
{
CheckNTErrors(inputGenerator != NULL, "no input generator!");
XList args;
args.AddList(rawData);
args.AddList(inputs);
queue.EnqueueJob(inputGenerator, &args);
return true;
}
/*
add a new job of neural network forward and backward computation (with the input)
>> func - the function that runs the neural network
>> inputs - inputs of the neural network
>> outputs - outputs of the neural network
>> net - the neural network
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net)
{
CheckNTErrors(func != NULL, "no input function!");
CheckNTErrors(net != NULL, "no input neural network!");
XList args;
args.AddList(inputs);
args.AddList(outputs);
args.Add(net);
queue.EnqueueJob(func, &args);
return true;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker of running the neural network.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
* My son had new glasses yesterday.
*/
#ifndef __XWORDERJOB_H__
#define __XWORDERJOB_H__
#include "XWorker.h"
#include "XModel.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* a model template for training */
class XWorkerJob : public XWorker
{
private:
public:
/* constructor */
XWorkerJob();
/* de-constructor */
~XWorkerJob();
/* add a new job of model refreshment */
bool AddJobRefresh(XModel * paramKeeper);
/* add a new job of input generation */
bool AddJobNewInput(void * inputGenerator, XList * rawData, XList * inputs);
/* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(void * func, XList * inputs, XList * outputs, void * net);
};
}
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论