Commit 6ec2d28c by xiaotong

new class: XWorkerBroadcast

parent 40f5462d
......@@ -61,6 +61,10 @@ void XLeader::Init()
for (int i = 0; i < uworkers.count; i++)
delete (XWorkerUpdate*)uworkers.GetItem(i);
uworkers.Clear();
for (int i = 0; i < bworkers.count; i++)
delete (XWorkerBroadcast*)bworkers.GetItem(i);
bworkers.Clear();
}
/* set id */
......@@ -76,6 +80,36 @@ int XLeader::GetID()
}
/*
Set the server model. It distributes the server-side parameters on different devices.
>> config - the configuration
>> model - the base model
>> memberModels - the models that run on different devices. We can place
the server-side parameters on different member models.
*/
void XLeader::SetServerModel(XConfig * config, XModel * model, XList * memberModels)
{
serverModel.params.Clear();
serverModel.params.AddList(&model->params);
/* TODO: we can place parameters on different devices */
}
/*
set the server model. It distributes the server-side parameters on different devices.
>> config - the configuration
>> model - the base model*/
void XLeader::SetServerModel(XConfig * config, XModel * model)
{
XList members;
for (int i = 0; i < jworkers.count; i++) {
XModel * member = (XModel*)jworkers.GetItem(i);
members.Add(member);
}
SetServerModel(config, model, &members);
}
/*
set the communication mode
>> myMode - the mode
*/
......@@ -100,6 +134,11 @@ void XLeader::Start()
XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i);
worker->Start();
}
for (int i = 0; i < bworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i);
worker->Start();
}
}
/*
......@@ -148,6 +187,13 @@ void XLeader::AddJobUpdateWorker(XModel * model, XOptimizer * optimizer)
uworkers.Add(worker);
}
/* add a data-broadcasting worker */
void XLeader::AddJobBroadcastWorker()
{
XWorkerBroadcast * worker = new XWorkerBroadcast();
bworkers.Add(worker);
}
/*
run the model (for one time)
>> config - the configuration
......@@ -167,16 +213,16 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
*/
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
XModel * model = worker->GetModel();
XModel * jmodel = worker->GetModel();
/* get a batch of samples */
bool fetched = dataDistributor->GetBatch(worker->GetInput());
/* job in queue 1: refresh the model */
worker->AddJobRefresh(model);
worker->AddJobRefresh(jmodel);
/* job in queue 1: run the model */
worker->AddJobNeuralNet(model, worker->GetInput(), worker->GetOutput());
worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput());
/* clear it */
worker->Clear();
......@@ -188,14 +234,13 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XList members(jworkers.count);
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
if (worker->GetModel() != model)
members.Add(worker->GetModel());
members.Add(worker->GetModel());
}
/* job in queue 2: collect the (gradient) data */
if (cworkers.count > 0) {
XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
collecter->AddJobCollect(&members, model);
collecter->AddJobCollect(&members, &serverModel);
}
else {
ShowNTErrors("No data-collecting workers!");
......@@ -204,12 +249,21 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* job in queue 3: update the model */
if (uworkers.count > 0) {
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
updater->AddJobUpdate(model, optimizer);
updater->AddJobUpdate(&serverModel, optimizer);
}
else {
ShowNTErrors("No model-update workers!");
}
/* job in queue 4: broadcast the lastest parameters to workers */
if (bworkers.count > 0) {
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
broadcaster->AddJobBroadcast(&serverModel, &members);
}
else {
ShowNTErrors("No data-broadcasting workers!");
}
return isDataOK;
}
......
......@@ -41,6 +41,7 @@
#include "XWorkerJob.h"
#include "XWorkerCollect.h"
#include "XWorkerUpdate.h"
#include "XWorkerBroadcast.h"
#include "../tensor/XConfig.h"
#include "../tensor/XList.h"
......@@ -64,6 +65,9 @@ protected:
/* id of the leader */
int id;
/* a model that keeps the parameters (as a server) */
XModel serverModel;
/* communication mode */
XLEADER_MODE mode;
......@@ -76,6 +80,9 @@ protected:
/* model-update workers */
XList uworkers;
/* data-broadcasting workers */
XList bworkers;
public:
/* constructor */
XLeader();
......@@ -92,6 +99,12 @@ public:
/* get id */
int GetID();
/* set the server model */
void SetServerModel(XConfig * config, XModel * model, XList * memberModels);
/* set the server model */
void SetServerModel(XConfig * config, XModel * model);
/* start the workers */
void Start();
......@@ -107,6 +120,9 @@ public:
/* add a model-update worker */
void AddJobUpdateWorker(XModel * model, XOptimizer * optimizer);
/* add a data-broadcasting worker */
void AddJobBroadcastWorker();
/* run the model (for one time) */
bool Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer);
......
......@@ -109,6 +109,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
leader.AddJobWorker(model, jobNum, ids);
leader.AddJobCollectWorker();
leader.AddJobUpdateWorker(model, optimizer);
leader.AddJobBroadcastWorker();
leader.SetServerModel(config, model);
leader.Start();
/* train the model */
......
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that boradcast the lastest parameters from the server to
* the workers.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-03
*/
#include "XWorkerBroadcast.h"
#include "../tensor/core/CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XWorkerBroadcast::XWorkerBroadcast()
{
}
/* de-constructor */
XWorkerBroadcast::~XWorkerBroadcast()
{
}
/* set the broadcasting type */
void XWorkerBroadcast::SetBroadcastMode(DATA_BROADCAST_TYPE myMode)
{
broadcastMode = myMode;
}
/*
broadcast data
>> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data
>> sleepTime - the waiting time in broadcasting
*/
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long sleepTime)
{
TensorList & sp = source->params;
int finished = 0;
/* check */
for (int i = 0; i < targetList->count; i++) {
TensorList & tp = ((XModel*)targetList->GetItem(i))->params;
CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
}
/* the major body of broadcasting */
while (1) {
for (int i = 0; i < sp.count; i++) {
if (source->flags[i] == PARAM_STATE_UPDATED) {
for (int j = 0; j < targetList->count; j++) {
XModel * target = (XModel*)targetList->GetItem(j);
TensorList & tp = target->params;
/* data transmit */
BroadcastP2P(sp.GetItem(i), tp.GetItem(i));
/* update the flag */
target->flags[i] = PARAM_STATE_UPDATED;
finished++;
}
}
}
if (finished == sp.count * targetList->count)
break;
}
}
/*
wrapper of BroadcastData
>> args - the list of arguments
*/
void XWorkerBroadcast::Broadcast(XList * args)
{
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0);
XModel * source = (XModel*)args->GetItem(1);
/* target models */
int targetNum = args->GetItemInt(2);
XList target(targetNum);
for (int i = 0; i < targetNum; i++) {
XModel * model = (XModel*)args->GetItem(3 + i);
target.Add(model);
}
broadcaster->BroadcastData(source, &target, SLEEP_TIME_IN_BROADCASTING);
}
/*
P2P data broadcasting
>> source - the source data
>> target - the target data
*/
void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target)
{
CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!");
CopyValues(*source, *target);
}
/*
add a new job of broadcasting data
>> source - the data that we want to broadcast
>> targetList - the target places that we recieve the data
*/
bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
{
CheckNTErrors(source != NULL, "no input source tensor!");
CheckNTErrors(targetList != NULL, "no input target tensor list!");
XList args;
args.Add(this);
args.Add(source);
args.AddInt(targetList->count);
args.AddList(targetList);
queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
return true;
}
}
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that boradcast the lastest parameters from the server to
* the workers.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-03
* Several visiters will come today, so i have less time for coding.
*/
#ifndef __XWORKERBROADCAST_H__
#define __XWORKERBROADCAST_H__
#include "XWorker.h"
#include "XModel.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#define SLEEP_TIME_IN_BROADCASTING 10
/*
data broadcasting method
1) point-to-point
*/
enum DATA_BROADCAST_TYPE { DATA_BROADCAST_P2P };
/* This class defines a broadcaster that transmits parameters from
a server to workers. */
class XWorkerBroadcast : public XWorker
{
protected:
DATA_BROADCAST_TYPE broadcastMode;
public:
/* constructor */
XWorkerBroadcast();
/* de-constructor */
~XWorkerBroadcast();
/* set the broadcasting type */
void SetBroadcastMode(DATA_BROADCAST_TYPE myMode);
/* broadcast data */
void BroadcastData(XModel * source, XList * targetList, long sleepTime);
/* wrapper of BroadcastData */
static
void Broadcast(XList * args);
/* P2P data broadcasting */
void BroadcastP2P(XTensor * source, XTensor * target);
/* add a new job of broadcasting data */
bool AddJobBroadcast(XModel * source, XList * targetList);
};
}
#endif
\ No newline at end of file
......@@ -99,7 +99,6 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
}
}
}
}
else if (collectMode == DATA_COLLECT_REDUCESUM) {
for (int j = 0; j < tp.count; j++) {
......
......@@ -44,7 +44,8 @@ data collection method
*/
enum DATA_COLLECT_TYPE { DATA_COLLECT_P2P, DATA_COLLECT_REDUCESUM};
/* The class defines the collecting-data worker */
/* The class defines the collecting-data worker. It collect (gradient) data
from workers for the leader (server). */
class XWorkerCollect : public XWorker
{
protected:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论