NiuTrans.Tensor, commit 052a62b5
authored Mar 22, 2021 by xiaotong
parent 78307f09

    bug fixes in XQueue and XThread
Showing 11 changed files with 178 additions and 64 deletions (+178, -64; whitespace-only changes hidden).
source/tensor/XQueue.cpp         +23   -2
source/tensor/XQueue.h            +9   -3
source/tensor/XThread.cpp        +20   -0
source/tensor/XThread.h           +4   -0
source/train/XLeader.cpp         +22  -17
source/train/XLeader.h            +1   -0
source/train/XModel.cpp           +1   -1
source/train/XWorker.cpp          +6   -0
source/train/XWorker.h            +3   -0
source/train/XWorkerCollect.cpp  +77  -38
source/train/XWorkerCollect.h    +12   -3

source/tensor/XQueue.cpp

@@ -176,8 +176,9 @@ void XQueue::RunJobConsumer(int jobDevID)
     jobDequeuer.SetFunc((TFunction)DequeueJobs, jobDequeuerArgs);
 
-    jobDequeuer.Start();
-    jobDequeuer.LetItGo();
+    //jobDequeuer.Start();
+    //jobDequeuer.LetItGo();
+    jobDequeuer.StartNow();
 }
 
 /* stop the job consumer */

@@ -256,5 +257,25 @@ int XQueue::GetJobNum()
     return c;
 }
 
+/*
+get the number of items in the queue. Note that
+this function is not the same as GetJobNum() because
+"items" are the real elements we put into the queue.
+"jobs" only make sense when the queue is running as a
+job queue.
+*/
+int XQueue::GetItemNum()
+{
+    MUTEX_LOCK(enqueueMutex);
+    MUTEX_LOCK(dequeueMutex);
+
+    int c = itemCount;
+
+    MUTEX_UNLOCK(dequeueMutex);
+    MUTEX_UNLOCK(enqueueMutex);
+
+    return c;
+}
+
 } /* end of the nts (NiuTrans.Tensor) namespace */
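
The new XQueue::GetItemNum() takes both queue mutexes, enqueue side first, so the reported count is a consistent snapshot even while producers and consumers are active. The following is a standalone sketch of that locking discipline using std::mutex; it is an illustration only, not NiuTrans code, and the member names are made up to mirror the ones above.

#include <mutex>

/* a stripped-down analogue of the counting above: producers take the enqueue
   mutex, consumers take the dequeue mutex, and a count snapshot takes both,
   always in the same enqueue-then-dequeue order so two snapshots cannot
   deadlock against each other */
struct TinyQueue {
    std::mutex enqueueMutex;
    std::mutex dequeueMutex;
    int itemCount = 0;

    /* producer side: only the enqueue mutex is needed */
    void Enqueue() {
        std::lock_guard<std::mutex> enq(enqueueMutex);
        itemCount++;
    }

    /* consistent snapshot, as in XQueue::GetItemNum() */
    int GetItemNum() {
        std::lock_guard<std::mutex> enq(enqueueMutex);
        std::lock_guard<std::mutex> deq(dequeueMutex);
        return itemCount;
    }
};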

source/tensor/XQueue.h

@@ -144,10 +144,17 @@ public:
     /* get the break flag */
     bool GetJobBreak();
 
-    /* get the number of jobs */
+    /* get the number of running jobs */
     int GetJobNum();
 
+    /* get the number of items in the queue. Note that
+       this function is not the same as GetJobNum() because
+       "items" are the real elements we put into the queue.
+       "jobs" only make sense when the queue is running as a
+       job queue. */
+    int GetItemNum();
 };
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
 
 #endif
\ No newline at end of file

source/tensor/XThread.cpp

@@ -224,6 +224,26 @@ void XThread::LetItGo()
 #endif
 #endif
 }
 
+/*
+create the thread and run it immediately (a combination of
+Start() and LetItGo() */
+bool XThread::StartNow()
+{
+    CheckNTErrors(jobCount == 0, "Cannot start a thread again when it is running!");
+
+    jobCount++;
+
+    Start();
+
+#ifdef _WIN32
+    MUTEX_LOCK(workingMutex);
+    COND_RESET(jobCond);
+    MUTEX_UNLOCK(workingMutex);
+    COND_SIGNAL(jobCond);
+#endif
+
+    return true;
+}
+
 /* waith for a singal */
 void XThread::Wait(COND_HANDLE * c, MUTEX_HANDLE * m)
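
StartNow() bundles the Start()/LetItGo() pair that callers previously had to issue themselves, which is how XQueue::RunJobConsumer now uses it in the first hunk above. A minimal usage sketch, assuming the XThread.h and XList.h headers from this repository; MyJob, RunOnce, and the argument list are hypothetical, and the (TFunction) cast and second SetFunc argument follow the pattern used in RunJobConsumer rather than a verified signature.

#include "XList.h"            /* from source/tensor/ in this repository */
#include "XThread.h"          /* adjust include paths to the build setup */

using namespace nts;

/* a hypothetical routine for the thread to run */
void MyJob(XList * args)
{
    /* do some work with args here */
}

void RunOnce(XList * myArgs)
{
    XThread t;
    t.SetFunc((TFunction)MyJob, myArgs);

    /* before this commit the caller issued two calls:
           t.Start();      create the thread
           t.LetItGo();    signal it to process the job
       with this commit a single call does both: */
    t.StartNow();
}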

source/tensor/XThread.h

@@ -142,6 +142,10 @@ public:
     /* let the thread process a job */
     void LetItGo();
 
+    /* create the thread and run it immediately (a combination of
+       Start() and LetItGo() */
+    bool StartNow();
+
     /* waith for a singal */
     static

source/train/XLeader.cpp

@@ -182,25 +182,30 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers, const int isToUpdate
             XWorker * worker = (XWorker*)jworkers[i];
             worker->DequeueFinishedJob();
             activeCount++;
+            CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
         }
     }
 
     if (activeCount > 0 && isToUpdate) {
         for (int i = 0; i < cworkers.count; i++) {
             XWorker * worker = (XWorker*)cworkers[i];
-            worker->DequeueFinishedJob();
+            for (int j = 0; j < serverModel.paramNum * activeCount; j++)
+                worker->DequeueFinishedJob();
+            CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
         }
 
         for (int i = 0; i < uworkers.count; i++) {
             XWorker * worker = (XWorker*)uworkers[i];
             for (int j = 0; j < serverModel.paramNum; j++)
                 worker->DequeueFinishedJob();
+            CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
         }
 
         for (int i = 0; i < bworkers.count; i++) {
             XWorker * worker = (XWorker*)bworkers[i];
             for (int j = 0; j < serverModel.paramNum; j++)
                 worker->DequeueFinishedJob();
+            CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
         }
     }
 }

@@ -373,7 +378,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptim
     CheckNTErrors(bworkers.count > 0, "No bworkers!");
     CheckNTErrors(pworkers.count > 0, "No pworkers!");
 
-    bool isDataOK = true;
     bool isToUpdate = (optimizer != NULL);
     int activeJobCount = 0;
     int * active = new int[jworkers.count];

@@ -430,8 +434,8 @@ int XLeader::RunModel(XConfig * config, DataDistributeBase * dataDistributor, in
         /* job in queue 1: run the model */
         worker->AddJobNeuralNet(jmodel,
-                                worker->GetInput(), worker->GetOutput(),
-                                worker->GetGold(), worker->GetLoss());
+                                worker->GetInput(), worker->GetOutput(),
+                                worker->GetGold(), worker->GetLoss());
 
         /* job in queue 1: make a record of the run */
         worker->AddJobRecord(&serverRecord);

@@ -526,9 +530,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
                 /* sp[j]->isGradFinished is true only if the model finishes the computation
                    (in another process) */
                 if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
+                    XQueue * jobQueue = (XQueue*)jobQueues.GetItem(j);
 
                     /* data transmit */
-                    CollectP2P(paramSource.param->grad, paramServer.param->grad);
+                    collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad);
+                    collecter->AddJobEnqueueFinished();
 
                     /* reset the flag */
                     paramSource.flag = PARAM_STATE_COLLECTED;

@@ -538,21 +544,20 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
                 /* we call model update (in another thread) and then
                    broadcast the new parameters to member models
                    (in another thread) */
-                if (finishedCount[j] == memberActive->count) {
+                if (finishedCount[j] == members.count) {
                     paramServer.flag = PARAM_STATE_COLLECTED;
                     if (updater != NULL) {
                         XQueue * jobQueue = (XQueue*)jobQueues->GetItem(j);
 
                         /* update the parameters */
-                        updater->AddJobUpdate(jobQueue, server, j, optimizer);
+                        updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
                         updater->AddJobEnqueueFinished(jobQueue);
 
                         /* broadcast the new parameter to other models*/
-                        broadcaster->AddJobBroadcastSingle(jobQueue, server, memberAll, j);
+                        broadcaster->AddJobBroadcastSingle(jobQueue, &serverModel, &membersAll, j);
                         broadcaster->AddJobEnqueueFinished(jobQueue);
                     }
                 }
-                else if (finishedCount[j] > memberActive->count) {
+                else if (finishedCount[j] > members.count) {
                     ShowNTErrors("Something is wrong with finishedCount!");
                 }

@@ -560,10 +565,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
         }
 
         /* the collection finishes if all data tensors are processed */
-        if (finished == server->paramNum * memberActive->count)
+        if (finished == serverModel.paramNum * members.count)
             break;
 
-        XSleep(sleepTime);
+        XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
     }
 
     delete[] finishedCount;

@@ -576,10 +581,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
        broadcast the lastest parameters to workers. NOTE that we would update
        a worker to the laster model parameters, even if it is not involved
        in this run. */
-    collecter->AddJobUpdateAll(&jobQueues,
-                               &members, &membersAll, &serverModel,
-                               optimizer, updater, broadcaster);
-    collecter->AddJobEnqueueFinished();
+    //collecter->AddJobUpdateAll(&jobQueues,
+    //                           &members, &membersAll, &serverModel,
+    //                           optimizer, updater, broadcaster);
+    //collecter->AddJobEnqueueFinished();
 }
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
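
A note on the bookkeeping behind the new CheckNTErrors calls in WaitForFinishing: every active member now produces one finished collect entry per server parameter, while update and broadcast workers post one finished entry per parameter regardless of how many members were active, so the leader drains paramNum * activeCount entries from each collect worker and paramNum entries from each update or broadcast worker. Below is a standalone sketch of that arithmetic, plain C++ with no NiuTrans headers and illustrative function names.

#include <cassert>

/* finished entries a collect worker is expected to hold: one per
   (parameter, active member) pair, matching the
   serverModel.paramNum * activeCount loop added above */
int ExpectedCollectFinished(int paramNum, int activeCount)
{
    return paramNum * activeCount;
}

/* finished entries an update or broadcast worker is expected to hold:
   one per parameter, independent of the number of active members */
int ExpectedUpdateFinished(int paramNum)
{
    return paramNum;
}

int main()
{
    /* e.g. 4 parameters and 2 active members: drain 8 entries from each
       collect worker and 4 from each update/broadcast worker, after which
       GetFinishedNumInQueue() should report 0 */
    assert(ExpectedCollectFinished(4, 2) == 8);
    assert(ExpectedUpdateFinished(4) == 4);
    return 0;
}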

source/train/XLeader.h

@@ -50,6 +50,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 #define MAX_NUM_OF_WORKERS 1024
 #define SLEEP_TIME_IN_WAITING_FOR_JOBS 20
+#define SLEEP_TIME_IN_WAITING_JOB_WORKERS 5
 
 /*
 conmmunication mode of a leader. This offers a way of organizing a hierachy of the work

source/train/XModel.cpp

@@ -112,7 +112,7 @@ bool XModel::RunMe(XList * args)
     if (RunSimple(inputs, outputs, golds, losses))
         return true;
 
-    ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !");
+    ShowNTErrors("You must overload one of these: XModel::RunSimple ... !");
 
     return false;
 }

source/train/XWorker.cpp

@@ -176,5 +176,11 @@ void XWorker::AddJobDequeueFinished(XQueue* jobQueue)
     queueRun.EnqueueJob((void*)(char*)XWorker::DequeueFinished, &args);
 }
 
+/* get number of unflaged finished job */
+int XWorker::GetFinishedNumInQueue()
+{
+    return finishedQueue.GetItemNum();
+}
+
 } /* end of the nts (NiuTrans.Tensor) namespace */

source/train/XWorker.h

@@ -126,6 +126,9 @@ public:
     /* add a job of dequeuing a counting a finished job */
     void AddJobDequeueFinished(XQueue* jobQueue = NULL);
 
+    /* get number of unflaged finished job */
+    int GetFinishedNumInQueue();
+
 };
 
 }

source/train/XWorkerCollect.cpp

@@ -200,6 +200,53 @@ void XWorkerCollect::UpdateAll(XList * args)
                   optimizer, updater, broadcaster,
                   SLEEP_TIME_IN_COLLECTING);
 }
 
+/*
+add a new job of collecting data, update the parameter and
+broadcast the new parameter
+>> jobQueues - the queues that we would use in following jobs
+>> memberActive - member models that are active, i.e., have generated gradients
+>> memberAll - all member models
+>> server - the server model
+>> optimizer - the optimizer
+>> updater - the worker that updates the parameters
+>> broadcaster - the worker that broadcasts the new parameters to all member
+                 models
+<< return - successful or not
+*/
+bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
+                                     XList * memberActive, XList * memberAll, XModel * server,
+                                     XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
+{
+    CheckNTErrors(memberActive != NULL, "No input (active) member list!");
+    CheckNTErrors(memberAll != NULL, "No input (all) member list!");
+    CheckNTErrors(server != NULL, "No input server model!");
+    CheckNTErrors(optimizer != NULL, "No input optimizer!");
+    CheckNTErrors(updater != NULL, "No input updater!");
+    CheckNTErrors(broadcaster != NULL, "No input broadcaster!");
+
+    XList args;
+    args.Add(this);
+    args.AddInt(jobQueues->count);
+    args.AddList(jobQueues);
+    args.AddInt(memberActive->count);
+    args.AddList(memberActive);
+    args.AddInt(memberAll->count);
+    args.AddList(memberAll);
+    args.Add(server);
+    args.Add(optimizer);
+    args.Add(updater);
+    args.Add(broadcaster);
+
+    if (isInstantRun)
+        XWorkerCollect::UpdateAll(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
+
+    return true;
+}
+
 /*
 P2P data collection

@@ -258,51 +305,43 @@ void XWorkerCollect::CollectAllReduce(XList * all)
 {
     ShowNTErrors("TODO!");
 }
 
+/* wrapper of Collect */
+void XWorkerCollect::CollectDataP2P(XList * args)
+{
+    int paramCount = 0;
+
+    XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(paramCount++);
+    XTensor * source = (XTensor*)args->GetItem(paramCount++);
+    XTensor * target = (XTensor*)args->GetItem(paramCount++);
+
+    if (collecter != NULL)
+        collecter->CollectP2P(source, target);
+}
+
 /*
-add a new job of collecting data, update the parameter and
-broadcast the new parameter
->> jobQueues - the queues that we would use in following jobs
->> memberActive - member models that are active, i.e., have generated gradients
->> memberAll - all member models
->> server - the server model
->> optimizer - the optimizer
->> updater - the worker that updates the parameters
->> broadcaster - the worker that broadcasts the new parameters to all member
-                 models
+add a new job of collecting data
+>> jobQueue - the queue where we run the job
+>> source - where we collect the data from
+>> target - where we place the data (on the server end)
 << return - successful or not
 */
-bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
-                                     XList * memberActive, XList * memberAll, XModel * server,
-                                     XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
+bool XWorkerCollect::AddJobCollectDataP2P(XQueue * jobQueue, XTensor * source, XTensor * target)
 {
-    CheckNTErrors(memberActive != NULL, "No input (active) member list!");
-    CheckNTErrors(memberAll != NULL, "No input (all) member list!");
-    CheckNTErrors(server != NULL, "No input server model!");
-    CheckNTErrors(optimizer != NULL, "No input optimizer!");
-    CheckNTErrors(updater != NULL, "No input updater!");
-    CheckNTErrors(broadcaster != NULL, "No input broadcaster!");
+    CheckNTErrors(source != NULL, "No input soure tensor!");
+    CheckNTErrors(target != NULL, "No input target tensor!");
 
     XList args;
     args.Add(this);
-    args.AddInt(jobQueues->count);
-    args.AddList(jobQueues);
-    args.AddInt(memberActive->count);
-    args.AddList(memberActive);
-    args.AddInt(memberAll->count);
-    args.AddList(memberAll);
-    args.Add(server);
-    args.Add(optimizer);
-    args.Add(updater);
-    args.Add(broadcaster);
+    args.Add(source);
+    args.Add(target);
+
+    XQueue& queueRun = jobQueue != NULL ? *jobQueue : queue;
 
     if (isInstantRun)
-        XWorkerCollect::UpdateAll(&args);
+        XWorkerCollect::CollectDataP2P(&args);
     else
-        queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
+        queueRun.EnqueueJob((void*)(char*)XWorkerCollect::CollectDataP2P, &args);
 
     return true;
 }
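
AddJobCollectDataP2P and the static CollectDataP2P wrapper communicate through a positional XList: the caller packs {this, source, target} and the wrapper unpacks them in the same order with GetItem(paramCount++), so the two sides must stay in sync. A minimal sketch of posting the new job from the leader side, assuming the headers from this commit; PostCollectJob and its parameters are hypothetical, while the two worker calls are the ones added above.

#include "XQueue.h"           /* from source/tensor/ in this repository */
#include "XWorkerCollect.h"   /* from source/train/; adjust include paths */

using namespace nts;

/* hypothetical helper: enqueue a P2P copy of one member gradient into the
   server-side gradient, then mark the round as finished */
void PostCollectJob(XWorkerCollect * collecter, XQueue * jobQueue,
                    XTensor * memberGrad, XTensor * serverGrad)
{
    /* the worker packs {collecter, memberGrad, serverGrad} into an XList and
       CollectDataP2P later unpacks them in the same positional order */
    collecter->AddJobCollectDataP2P(jobQueue, memberGrad, serverGrad);

    /* post the "finished" marker the leader counts in WaitForFinishing,
       mirroring the call sequence in XLeader::RunUpdate */
    collecter->AddJobEnqueueFinished();
}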

source/train/XWorkerCollect.h

@@ -77,6 +77,10 @@ public:
     /* wrapper of UpdateDataAll */
     static
     void UpdateAll(XList * args);
 
+    /* add a new job of collecting data, update the parameter and broadcast the new parameter */
+    bool AddJobUpdateAll(XList * jobQueues,
+                         XList * memberActive, XList * memberAll, XModel * server,
+                         XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
+
     /* P2P data collection */
     void CollectP2P(XTensor * source, XTensor * target);

@@ -86,10 +90,15 @@ public:
     /* all-reduce */
     void CollectAllReduce(XList * all);
 
+    /* wrapper of Collect */
+    static
+    void CollectDataP2P(XList * args);
+
+    /* add a new job of collecting data */
+    bool AddJobCollectDataP2P(XQueue * jobQueue, XTensor * source, XTensor * target);
 
-    /* add a new job of collecting data, update the parameter and broadcast the new parameter */
-    bool AddJobUpdateAll(XList * jobQueues,
-                         XList * memberActive, XList * memberAll, XModel * server,
-                         XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
 };