Commit 9b3209d8 by xiaotong

updates of XWorker and XWorkerJob

parent e3455593
...@@ -80,10 +80,20 @@ void XLeader::SetMode(XLEADER_MODE myMode) ...@@ -80,10 +80,20 @@ void XLeader::SetMode(XLEADER_MODE myMode)
/* /*
add a number of job workers (given their device ids) add a number of job workers (given their device ids)
>> model - the neural network >> model - the neural network
>> n - number of the models
>> ids - the array of device ids >> ids - the array of device ids
*/ */
void XLeader::AddJobWorker(XModel * model, int * ids) void XLeader::AddJobWorker(XModel * model, int n, int * ids)
{ {
/* we keep the input model */
if (n >= 1) {
jworkers.Add(model);
}
/* we clone the input model */
for (int i = 0; i < n - 1; i++) {
jworkers.Add(model->Clone(ids[i]));
}
} }
/* /*
...@@ -96,6 +106,25 @@ run the model (for one time) ...@@ -96,6 +106,25 @@ run the model (for one time)
void XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, void XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer) XModel * model, XOptimizer * optimizer)
{ {
/* feed the input to each worker and geneate the output */
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
XModel * model = worker->GetModel();
/* get a batch of samples */
dataDistributor->GetBatch(worker->GetInput());
/* job in the queue: refresh the model */
worker->AddJobRefresh(model);
/* job in the queue: run the model */
worker->AddJobNeuralNet(model, worker->GetInput(), worker->GetOutput());
/* clear it */
worker->Clear();
}
/* collect the (gradient) data and update the model */
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
void SetMode(XLEADER_MODE myMode); void SetMode(XLEADER_MODE myMode);
/* add a number of job workers (given their device ids) */ /* add a number of job workers (given their device ids) */
void AddJobWorker(XModel * model, int * ids); void AddJobWorker(XModel * model, int n, int * ids);
/* run the model (for one time) */ /* run the model (for one time) */
void Run(XConfig * config, DataDistributeBase * dataDistributor, void Run(XConfig * config, DataDistributeBase * dataDistributor,
......
...@@ -67,13 +67,13 @@ XModel * XModel::Clone(int devID) ...@@ -67,13 +67,13 @@ XModel * XModel::Clone(int devID)
run the neural network run the neural network
>> args - the arguments >> args - the arguments
*/ */
bool XModel::Run(XList * args) bool XModel::RunMe(XList * args)
{ {
ShowNTErrors("NetBase::Run must be overloaded!"); ShowNTErrors("NetBase::Run must be overloaded!");
return true; return true;
} }
/* reset the flag of parameters (the flag is used in data transfer) */ /* refresh the model */
void XModel::RefreshMe() void XModel::RefreshMe()
{ {
for (int i = 0; i < params.count; i++) { for (int i = 0; i < params.count; i++) {
...@@ -85,23 +85,18 @@ void XModel::RefreshMe() ...@@ -85,23 +85,18 @@ void XModel::RefreshMe()
/* wrapper of RefreshMe */ /* wrapper of RefreshMe */
void XModel::Refresh(XList * args) void XModel::Refresh(XList * args)
{ {
CheckNTErrors(args != NULL, "illegal arguments!"); CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh");
CheckNTErrors(args->count == 1, "The number of arguments must be 1!");
XModel * model = (XModel*)args->GetItem(0); XModel * model = (XModel*)args->GetItem(0);
model->RefreshMe(); model->RefreshMe();
} }
/* run the neural network (for multi-threading */ /* wrapper of Run() */
bool XModel::RunSafe(XList * args) bool XModel::Run(XList * args)
{ {
bool r; CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh");
XModel * model = (XModel*)args->GetItem(0);
MUTEX_LOCK(netMutex);
r = Run(args);
MUTEX_UNLOCK(netMutex);
return r; return model->Run(args);
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -67,21 +67,20 @@ public: ...@@ -67,21 +67,20 @@ public:
/* run the neural network (would be overloaded) */ /* run the neural network (would be overloaded) */
virtual virtual
bool Run(XList * args); bool RunMe(XList * args);
public: public:
/* reset the flag of parameters (the flag is used in data transfer) */ /* refresh the model */
void RefreshMe(); void RefreshMe();
/* wrapper of RefreshMe */ /* wrapper of RefreshMe() */
static static
void Refresh(XList * args); void Refresh(XList * args);
protected: /* wrapper of Run() */
static
/* run the neural network (for multi-threading) */ bool Run(XList * args);
bool RunSafe(XList * args);
}; };
} }
......
...@@ -53,6 +53,30 @@ XModel * XWorkerJob::GetModel() ...@@ -53,6 +53,30 @@ XModel * XWorkerJob::GetModel()
return model; return model;
} }
/* clear the worker */
void XWorkerJob::Clear()
{
for (int i = 0; i < inputs.count; i++)
delete (XTensor*)inputs[i];
inputs.Clear();
for (int i = 0; i < outputs.count; i++)
delete (XTensor*)outputs[i];
outputs.Clear();
}
/* get the input list */
XList * XWorkerJob::GetInput()
{
return &inputs;
}
/* get the output list */
XList * XWorkerJob::GetOutput()
{
return &outputs;
}
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myModel - the model >> myModel - the model
...@@ -65,30 +89,30 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel) ...@@ -65,30 +89,30 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
XList args(1); XList args(1);
args.Add(myModel); args.Add(myModel);
queue.EnqueueJob((void*)&myModel->Refresh, &args); queue.EnqueueJob(XModel::Refresh, &args);
return true; return true;
} }
/* /*
add a new job of neural network forward and backward computation (with the input) add a new job of neural network forward and backward computation (with the input)
>> func - the function that calls the run of the neural network
>> myModel - the model >> myModel - the model
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(void * func, XModel * myModel, XList * inputs, XList * outputs) bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs)
{ {
CheckNTErrors(func != NULL, "no input function!");
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
CheckNTErrors(inputs != NULL, "no inputs of the model!");
CheckNTErrors(outputs != NULL, "no outputs of the model!");
XList args; XList args;
args.Add(myModel);
args.AddList(inputs); args.AddList(inputs);
args.AddList(outputs); args.AddList(outputs);
args.Add(myModel);
queue.EnqueueJob(func, &args); queue.EnqueueJob(XModel::Run, &args);
return true; return true;
} }
......
...@@ -42,6 +42,12 @@ class XWorkerJob : public XWorker ...@@ -42,6 +42,12 @@ class XWorkerJob : public XWorker
protected: protected:
/* the model */ /* the model */
XModel * model; XModel * model;
/* the input tensors of the model */
XList inputs;
/* the output tensors of the model */
XList outputs;
public: public:
...@@ -57,11 +63,20 @@ public: ...@@ -57,11 +63,20 @@ public:
/* get the parameter keeper */ /* get the parameter keeper */
XModel * GetModel(); XModel * GetModel();
/* clear the worker */
void Clear();
/* get the input list */
XList * GetInput();
/* get the output list */
XList * GetOutput();
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XModel * myModel); bool AddJobRefresh(XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(void * func, XModel * myModel, XList * inputs, XList * outputs); bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论