Commit d39baab3 by xiaotong

new classes: XWorkerCollect

parent 97da3265
......@@ -106,7 +106,10 @@ run the model (for one time)
void XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * model, XOptimizer * optimizer)
{
/* feed the input to each worker and geneate the output */
/* Feed the input to each worker and geneate the output.
For each worker, we define a job queue and enqueue jobs
into it.
*/
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
XModel * model = worker->GetModel();
......
......@@ -36,12 +36,14 @@ namespace nts {
/* constructor */
XModel::XModel()
{
flags = NULL;
MUTEX_INIT(modelMutex);
}
/* de-constructor */
XModel::~XModel()
{
delete[] flags;
Clear();
MUTEX_DELE(modelMutex);
}
......@@ -80,6 +82,12 @@ void XModel::RefreshMe()
XTensor * param = params.GetItem(i);
param->isGradFinished = false;
}
delete[] flags;
flags = new PARAM_STATE[params.count];
for (int i = 0; i < params.count; i++) {
flags[i] = PARAM_STATE_NOT_READY;
}
}
/* wrapper of RefreshMe */
......
......@@ -38,6 +38,14 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
parameter state
1) not ready
2) ready
3) the parameter has been collected from other models
*/
enum PARAM_STATE { PARAM_STATE_NOT_READY, PARAM_STATE_READY, PARAM_STATE_COLLECTED };
/* a model template for training */
class XModel
{
......@@ -49,6 +57,9 @@ public:
/* the list of model parameters (pointers to the parameter tensor) */
TensorList params;
/* flags of the parameters */
PARAM_STATE * flags;
public:
/* constructor */
......
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that collects data from workers.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
*/
#include "XWorkerCollect.h"
#include "../tensor/core/CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XWorkerCollect::XWorkerCollect()
{
collectMode = DATA_COLLECT_P2P;
}
/* de-constructor */
XWorkerCollect::~XWorkerCollect()
{
}
/* set the collection type */
void XWorkerCollect::SetCollectMode(DATA_COLLECT_TYPE myMode)
{
collectMode = myMode;
}
/*
collect data
>> sourceList - the list of data tensors we collect data from
>> target - the target tensor we place the result, that is
target += \sum_i source_i
*/
void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleepTime)
{
TensorList & tp = target->params;
int finished = 0;
/* check */
for (int i = 0; i < sourceList->count; i++) {
TensorList & sp = ((XModel*)sourceList->GetItem(i))->params;
CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
}
/* This is a simple implementation of the wait-and-collect process. But
there is a risk that some models are not available, that is, the
loop would never stop. A solution might be that we force the loop
to break after waiting for a short time. */
while (1) {
if (collectMode == DATA_COLLECT_P2P) {
for (int j = 0; j < tp.count; j++) {
/* target->flags[j] is ready only if model finishes the computation
(in another process) */
if (target->flags[j] != PARAM_STATE_READY)
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < sourceList->count; i++) {
XModel * source = (XModel*)sourceList->GetItem(i);
TensorList & sp = source->params;
/* source->flags[j] is ready only if model finishes the computation
(in another process) */
if (source->flags[j] == PARAM_STATE_READY) {
/* data transmit */
CollectP2P(sp.GetItem(j), tp.GetItem(j));
/* reset the flag */
source->flags[j] = PARAM_STATE_COLLECTED;
finished++;
}
}
}
}
else if (collectMode == DATA_COLLECT_REDUCESUM) {
for (int j = 0; j < tp.count; j++) {
bool ready = true;
/* target->flags[j] is ready only if model finishes the computation
(in another process) */
if (target->flags[j] != PARAM_STATE_READY)
continue;
/* check if all the models (or part of them) are ready */
for (int i = 0; i < sourceList->count; i++) {
XModel * source = (XModel*)sourceList->GetItem(i);
TensorList & sp = source->params;
/* source->flags[j] is ready only if model finishes the computation
(in another process) */
if (source->flags[j] != PARAM_STATE_READY) {
ready = false;
break;
}
}
if (ready) {
XList tensorList(sourceList->count);
for (int i = 0; i < sourceList->count; i++) {
XModel * source = (XModel*)sourceList->GetItem(i);
TensorList & sp = source->params;
tensorList.Add(sp.GetItem(j));
}
/* data transmit */
CollectReduceSum(&tensorList, tp.GetItem(j));
/* reset the flags */
for (int i = 0; i < sourceList->count; i++) {
XModel * source = (XModel*)sourceList->GetItem(i);
source->flags[j] = PARAM_STATE_COLLECTED;
}
finished += sourceList->count;
}
}
}
else {
ShowNTErrors("Unsupported data collection mode!");
}
/* the collection finishes if all data tensors are processed */
if (finished == tp.count * sourceList->count)
break;
/* reset the flags */
for (int i = 0; i < tp.count; i++) {
target->flags[i] = PARAM_STATE_COLLECTED;
}
#ifdef _WIN32
Sleep((DWORD)sleepTime);
#else
sleep(sleepTime / 1000);
#endif
}
}
/* wrapper of CollectData */
void XWorkerCollect::Collect(XList * args)
{
XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0);
int sourceNum = args->GetItemInt(1);
/* the source models */
XList source(sourceNum);
for (int i = 0; i < sourceNum; i++) {
XModel * model = (XModel*)args->GetItem(2 + i);
source.Add(model);
}
/* the target model */
XModel * target = (XModel*)args->GetItem(2 + sourceNum);
collecter->CollectData(&source, target, SLEEP_TIME_IN_COLLECTING);
}
/*
P2P data collection
target += source
>> source - the source tensor
>> target - the target tensor
*/
void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
{
CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!");
/* target += source */
Sum(*source, *target, *source);
}
/*
sum-reduce for given tensors
target += source_0
target += source_1
...
target += source_n
>> source - the source tensor
>> target - the target tensor
*/
void XWorkerCollect::CollectReduceSum(XList * source, XTensor * target)
{
for (int i = 0; i < source->count; i++) {
XTensor * s = (XTensor*)source->GetItem(i);
CollectP2P(s, target);
}
}
/*
all-reduce: the well-known all-reduce method
every tensor is involved in every data transmition. The final outcome
is that all input tensors share the same value (i.e., the sum of them).
>> all - the tensors for sum
*/
void XWorkerCollect::CollectAllReduce(XList * all)
{
ShowNTErrors("TODO!");
}
}
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that collects data from workers.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
*/
#ifndef __XWORKERCOLLECT_H__
#define __XWORKERCOLLECT_H__
#include "XWorker.h"
#include "XModel.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#define SLEEP_TIME_IN_COLLECTING 10
/*
data collection method
1) point-to-point
2) reduce sum
3) all-reduce
*/
enum DATA_COLLECT_TYPE { DATA_COLLECT_P2P, DATA_COLLECT_REDUCESUM};
/* The class defines the collecting-data worker */
class XWorkerCollect : public XWorker
{
protected:
DATA_COLLECT_TYPE collectMode;
public:
/* constructor */
XWorkerCollect();
/* de-constructor */
~XWorkerCollect();
/* set the collection type */
void SetCollectMode(DATA_COLLECT_TYPE myMode);
/* collect data */
void CollectData(XList * sourceList, XModel * target, long sleepTime);
/* wrapper of CollectData */
static
void Collect(XList * args);
/* P2P data collection */
void CollectP2P(XTensor * source, XTensor * target);
/* sum-reduce for given tensors */
void CollectReduceSum(XList * source, XTensor * target);
/* all-reduce */
void CollectAllReduce(XList * all);
};
}
#endif
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that updates the model.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
*/
#include "XWorkerUpdate.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
}
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* The worker that updates the model.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
*/
#ifndef __XWORKERUPDATE_H__
#define __XWORKERUPDATE_H__
#include "XWorker.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* The class defines the model-update worker */
class XWorkerUpdate : public XWorker
{
public:
};
}
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论