Commit 9b3209d8 by xiaotong

updates of XWorker and XWorkerJob

parent e3455593
......@@ -80,10 +80,20 @@ void XLeader::SetMode(XLEADER_MODE myMode)
/*
add a number of job workers (given their device ids)
>> model - the neural network
>> n - number of the models
>> ids - the array of device ids
*/
void XLeader::AddJobWorker(XModel * model, int * ids)
void XLeader::AddJobWorker(XModel * model, int n, int * ids)
{
/* we keep the input model */
if (n >= 1) {
jworkers.Add(model);
}
/* we clone the input model */
for (int i = 0; i < n - 1; i++) {
jworkers.Add(model->Clone(ids[i]));
}
}
/*
......@@ -96,6 +106,25 @@ run the model (for one time)
void XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer)
{
/* feed the input to each worker and geneate the output */
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
XModel * model = worker->GetModel();
/* get a batch of samples */
dataDistributor->GetBatch(worker->GetInput());
/* job in the queue: refresh the model */
worker->AddJobRefresh(model);
/* job in the queue: run the model */
worker->AddJobNeuralNet(model, worker->GetInput(), worker->GetOutput());
/* clear it */
worker->Clear();
}
/* collect the (gradient) data and update the model */
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -88,7 +88,7 @@ public:
void SetMode(XLEADER_MODE myMode);
/* add a number of job workers (given their device ids) */
void AddJobWorker(XModel * model, int * ids);
void AddJobWorker(XModel * model, int n, int * ids);
/* run the model (for one time) */
void Run(XConfig * config, DataDistributeBase * dataDistributor,
......
......@@ -67,13 +67,13 @@ XModel * XModel::Clone(int devID)
run the neural network
>> args - the arguments
*/
bool XModel::Run(XList * args)
bool XModel::RunMe(XList * args)
{
ShowNTErrors("NetBase::Run must be overloaded!");
return true;
}
/* reset the flag of parameters (the flag is used in data transfer) */
/* refresh the model */
void XModel::RefreshMe()
{
for (int i = 0; i < params.count; i++) {
......@@ -85,23 +85,18 @@ void XModel::RefreshMe()
/* wrapper of RefreshMe */
void XModel::Refresh(XList * args)
{
CheckNTErrors(args != NULL, "illegal arguments!");
CheckNTErrors(args->count == 1, "The number of arguments must be 1!");
CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh");
XModel * model = (XModel*)args->GetItem(0);
model->RefreshMe();
}
/* run the neural network (for multi-threading */
bool XModel::RunSafe(XList * args)
/* wrapper of Run() */
bool XModel::Run(XList * args)
{
bool r;
MUTEX_LOCK(netMutex);
r = Run(args);
MUTEX_UNLOCK(netMutex);
CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh");
XModel * model = (XModel*)args->GetItem(0);
return r;
return model->Run(args);
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -67,21 +67,20 @@ public:
/* run the neural network (would be overloaded) */
virtual
bool Run(XList * args);
bool RunMe(XList * args);
public:
/* reset the flag of parameters (the flag is used in data transfer) */
/* refresh the model */
void RefreshMe();
/* wrapper of RefreshMe */
/* wrapper of RefreshMe() */
static
void Refresh(XList * args);
protected:
/* run the neural network (for multi-threading) */
bool RunSafe(XList * args);
/* wrapper of Run() */
static
bool Run(XList * args);
};
}
......
......@@ -53,6 +53,30 @@ XModel * XWorkerJob::GetModel()
return model;
}
/* clear the worker */
void XWorkerJob::Clear()
{
for (int i = 0; i < inputs.count; i++)
delete (XTensor*)inputs[i];
inputs.Clear();
for (int i = 0; i < outputs.count; i++)
delete (XTensor*)outputs[i];
outputs.Clear();
}
/* get the input list */
XList * XWorkerJob::GetInput()
{
return &inputs;
}
/* get the output list */
XList * XWorkerJob::GetOutput()
{
return &outputs;
}
/*
add a new job of model refreshment
>> myModel - the model
......@@ -65,30 +89,30 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
XList args(1);
args.Add(myModel);
queue.EnqueueJob((void*)&myModel->Refresh, &args);
queue.EnqueueJob(XModel::Refresh, &args);
return true;
}
/*
add a new job of neural network forward and backward computation (with the input)
>> func - the function that calls the run of the neural network
>> myModel - the model
>> inputs - inputs of the neural network
>> outputs - outputs of the neural network
<< return - succeeded or not
*/
bool XWorkerJob::AddJobNeuralNet(void * func, XModel * myModel, XList * inputs, XList * outputs)
bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs)
{
CheckNTErrors(func != NULL, "no input function!");
CheckNTErrors(myModel != NULL, "no input neural network!");
CheckNTErrors(inputs != NULL, "no inputs of the model!");
CheckNTErrors(outputs != NULL, "no outputs of the model!");
XList args;
args.Add(myModel);
args.AddList(inputs);
args.AddList(outputs);
args.Add(myModel);
queue.EnqueueJob(func, &args);
queue.EnqueueJob(XModel::Run, &args);
return true;
}
......
......@@ -42,6 +42,12 @@ class XWorkerJob : public XWorker
protected:
/* the model */
XModel * model;
/* the input tensors of the model */
XList inputs;
/* the output tensors of the model */
XList outputs;
public:
......@@ -57,11 +63,20 @@ public:
/* get the parameter keeper */
XModel * GetModel();
/* clear the worker */
void Clear();
/* get the input list */
XList * GetInput();
/* get the output list */
XList * GetOutput();
/* add a new job of model refreshment */
bool AddJobRefresh(XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(void * func, XModel * myModel, XList * inputs, XList * outputs);
bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论