Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
052a62b5
Commit
052a62b5
authored
Mar 22, 2021
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fixes in XQueue and XThread
parent
78307f09
显示空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
169 行增加
和
54 行删除
+169
-54
source/tensor/XQueue.cpp
+23
-2
source/tensor/XQueue.h
+8
-1
source/tensor/XThread.cpp
+20
-0
source/tensor/XThread.h
+4
-0
source/train/XLeader.cpp
+18
-13
source/train/XLeader.h
+1
-0
source/train/XModel.cpp
+1
-1
source/train/XWorker.cpp
+6
-0
source/train/XWorker.h
+3
-0
source/train/XWorkerCollect.cpp
+73
-34
source/train/XWorkerCollect.h
+12
-3
没有找到文件。
source/tensor/XQueue.cpp
查看文件 @
052a62b5
...
...
@@ -176,8 +176,9 @@ void XQueue::RunJobConsumer(int jobDevID)
jobDequeuer
.
SetFunc
((
TFunction
)
DequeueJobs
,
jobDequeuerArgs
);
jobDequeuer
.
Start
();
jobDequeuer
.
LetItGo
();
//jobDequeuer.Start();
//jobDequeuer.LetItGo();
jobDequeuer
.
StartNow
();
}
/* stop the job consumer */
...
...
@@ -257,4 +258,24 @@ int XQueue::GetJobNum()
return
c
;
}
/*
return how many items are currently stored in the queue.
This differs from GetJobNum(): an "item" is any element that
has been put into the queue, whereas a "job" only makes sense
when the queue is running as a job queue.
<< return - the value of itemCount, read while both the enqueue
            mutex and the dequeue mutex are held
*/
int XQueue::GetItemNum()
{
    /* take both locks for the read; they are released in reverse order */
    MUTEX_LOCK(enqueueMutex);
    MUTEX_LOCK(dequeueMutex);

    const int itemNum = itemCount;

    MUTEX_UNLOCK(dequeueMutex);
    MUTEX_UNLOCK(enqueueMutex);

    return itemNum;
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
source/tensor/XQueue.h
查看文件 @
052a62b5
...
...
@@ -144,8 +144,15 @@ public:
/* get the break flag */
bool
GetJobBreak
();
/* get the number of jobs */
/* get the number of
running
jobs */
int
GetJobNum
();
/* get the number of items in the queue. Note that
this function is not the same as GetJobNum() because
"items" are the real elements we put into the queue.
"jobs" only make sense when the queue is running as a
job queue. */
int
GetItemNum
();
};
}
/* end of the nts (NiuTrans.Tensor) namespace */
...
...
source/tensor/XThread.cpp
查看文件 @
052a62b5
...
...
@@ -225,6 +225,26 @@ void XThread::LetItGo()
#endif
}
/*
create the thread and run it immediately, i.e., a combination of
Start() and LetItGo()
<< return - always true
*/
bool XThread::StartNow()
{
    /* jobCount must be zero here; otherwise the thread is considered running */
    CheckNTErrors(jobCount == 0, "Cannot start a thread again when it is running!");

    ++jobCount;
    Start();

#ifdef _WIN32
    /* on Windows only: reset the job condition while holding the working
       mutex, then signal it after the mutex has been released */
    MUTEX_LOCK(workingMutex);
    COND_RESET(jobCond);
    MUTEX_UNLOCK(workingMutex);

    COND_SIGNAL(jobCond);
#endif

    return true;
}
/* wait for a signal */
void
XThread
::
Wait
(
COND_HANDLE
*
c
,
MUTEX_HANDLE
*
m
)
{
...
...
source/tensor/XThread.h
查看文件 @
052a62b5
...
...
@@ -143,6 +143,10 @@ public:
/* let the thread process a job */
void
LetItGo
();
/* create the thread and run it immediately (a combination of
Start() and LetItGo()) */
bool
StartNow
();
/* wait for a signal */
static
void
Wait
(
COND_HANDLE
*
c
,
MUTEX_HANDLE
*
m
);
...
...
source/train/XLeader.cpp
查看文件 @
052a62b5
...
...
@@ -182,25 +182,30 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers, const int isToUpdate
XWorker
*
worker
=
(
XWorker
*
)
jworkers
[
i
];
worker
->
DequeueFinishedJob
();
activeCount
++
;
CheckNTErrors
(
worker
->
GetFinishedNumInQueue
()
==
0
,
"Incorrect job number!"
);
}
}
if
(
activeCount
>
0
&&
isToUpdate
)
{
for
(
int
i
=
0
;
i
<
cworkers
.
count
;
i
++
)
{
XWorker
*
worker
=
(
XWorker
*
)
cworkers
[
i
];
for
(
int
j
=
0
;
j
<
serverModel
.
paramNum
*
activeCount
;
j
++
)
worker
->
DequeueFinishedJob
();
CheckNTErrors
(
worker
->
GetFinishedNumInQueue
()
==
0
,
"Incorrect job number!"
);
}
for
(
int
i
=
0
;
i
<
uworkers
.
count
;
i
++
)
{
XWorker
*
worker
=
(
XWorker
*
)
uworkers
[
i
];
for
(
int
j
=
0
;
j
<
serverModel
.
paramNum
;
j
++
)
worker
->
DequeueFinishedJob
();
CheckNTErrors
(
worker
->
GetFinishedNumInQueue
()
==
0
,
"Incorrect job number!"
);
}
for
(
int
i
=
0
;
i
<
bworkers
.
count
;
i
++
)
{
XWorker
*
worker
=
(
XWorker
*
)
bworkers
[
i
];
for
(
int
j
=
0
;
j
<
serverModel
.
paramNum
;
j
++
)
worker
->
DequeueFinishedJob
();
CheckNTErrors
(
worker
->
GetFinishedNumInQueue
()
==
0
,
"Incorrect job number!"
);
}
}
}
...
...
@@ -373,7 +378,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptim
CheckNTErrors
(
bworkers
.
count
>
0
,
"No bworkers!"
);
CheckNTErrors
(
pworkers
.
count
>
0
,
"No pworkers!"
);
bool
isDataOK
=
true
;
bool
isToUpdate
=
(
optimizer
!=
NULL
);
int
activeJobCount
=
0
;
int
*
active
=
new
int
[
jworkers
.
count
];
...
...
@@ -526,9 +530,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* sp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if
(
paramSource
.
flag
==
PARAM_STATE_NOT_READY
&&
paramSource
.
param
->
isGradFinished
)
{
XQueue
*
jobQueue
=
(
XQueue
*
)
jobQueues
.
GetItem
(
j
);
/* data transmit */
CollectP2P
(
paramSource
.
param
->
grad
,
paramServer
.
param
->
grad
);
collecter
->
AddJobCollectDataP2P
(
jobQueue
,
paramSource
.
param
->
grad
,
paramServer
.
param
->
grad
);
collecter
->
AddJobEnqueueFinished
();
/* reset the flag */
paramSource
.
flag
=
PARAM_STATE_COLLECTED
;
...
...
@@ -538,21 +544,20 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if
(
finishedCount
[
j
]
==
member
Active
->
count
)
{
if
(
finishedCount
[
j
]
==
member
s
.
count
)
{
paramServer
.
flag
=
PARAM_STATE_COLLECTED
;
if
(
updater
!=
NULL
)
{
XQueue
*
jobQueue
=
(
XQueue
*
)
jobQueues
->
GetItem
(
j
);
/* update the parameters */
updater
->
AddJobUpdate
(
jobQueue
,
server
,
j
,
optimizer
);
updater
->
AddJobUpdate
(
jobQueue
,
&
serverModel
,
j
,
optimizer
);
updater
->
AddJobEnqueueFinished
(
jobQueue
);
/* broadcast the new parameter to other models*/
broadcaster
->
AddJobBroadcastSingle
(
jobQueue
,
server
,
member
All
,
j
);
broadcaster
->
AddJobBroadcastSingle
(
jobQueue
,
&
serverModel
,
&
members
All
,
j
);
broadcaster
->
AddJobEnqueueFinished
(
jobQueue
);
}
}
else
if
(
finishedCount
[
j
]
>
member
Active
->
count
)
{
else
if
(
finishedCount
[
j
]
>
member
s
.
count
)
{
ShowNTErrors
(
"Something is wrong with finishedCount!"
);
}
}
...
...
@@ -560,10 +565,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
}
/* the collection finishes if all data tensors are processed */
if
(
finished
==
server
->
paramNum
*
memberActive
->
count
)
if
(
finished
==
server
Model
.
paramNum
*
members
.
count
)
break
;
XSleep
(
sleepTime
);
XSleep
(
SLEEP_TIME_IN_WAITING_JOB_WORKERS
);
}
delete
[]
finishedCount
;
...
...
@@ -576,10 +581,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
broadcast the latest parameters to workers. NOTE that we would update
a worker to the latest model parameters, even if it is not involved
in this run. */
collecter
->
AddJobUpdateAll
(
&
jobQueues
,
&
members
,
&
membersAll
,
&
serverModel
,
optimizer
,
updater
,
broadcaster
);
collecter
->
AddJobEnqueueFinished
();
//
collecter->AddJobUpdateAll(&jobQueues,
//
&members, &membersAll, &serverModel,
//
optimizer, updater, broadcaster);
//
collecter->AddJobEnqueueFinished();
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
source/train/XLeader.h
查看文件 @
052a62b5
...
...
@@ -50,6 +50,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MAX_NUM_OF_WORKERS 1024
#define SLEEP_TIME_IN_WAITING_FOR_JOBS 20
#define SLEEP_TIME_IN_WAITING_JOB_WORKERS 5
/*
communication mode of a leader. This offers a way of organizing a hierarchy of the work
...
...
source/train/XModel.cpp
查看文件 @
052a62b5
...
...
@@ -112,7 +112,7 @@ bool XModel::RunMe(XList * args)
if
(
RunSimple
(
inputs
,
outputs
,
golds
,
losses
))
return
true
;
ShowNTErrors
(
"You must
be
overload one of these: XModel::RunSimple ... !"
);
ShowNTErrors
(
"You must overload one of these: XModel::RunSimple ... !"
);
return
false
;
}
...
...
source/train/XWorker.cpp
查看文件 @
052a62b5
...
...
@@ -177,4 +177,10 @@ void XWorker::AddJobDequeueFinished(XQueue* jobQueue)
}
/* get number of unflaged finished job */
int
XWorker
::
GetFinishedNumInQueue
()
{
return
finishedQueue
.
GetItemNum
();
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
source/train/XWorker.h
查看文件 @
052a62b5
...
...
@@ -126,6 +126,9 @@ public:
/* add a job of dequeuing and counting a finished job */
void
AddJobDequeueFinished
(
XQueue
*
jobQueue
=
NULL
);
/* get the number of unflagged finished jobs */
int
GetFinishedNumInQueue
();
};
}
...
...
source/train/XWorkerCollect.cpp
查看文件 @
052a62b5
...
...
@@ -202,6 +202,53 @@ void XWorkerCollect::UpdateAll(XList * args)
}
/*
add a new job of collecting data, updating the parameters and
broadcasting the new parameters
>> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
                 models
<< return - successful or not (always true when the argument checks pass)
*/
bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
                                     XList * memberActive, XList * memberAll,
                                     XModel * server, XOptimizer * optimizer,
                                     XWorkerUpdate * updater,
                                     XWorkerBroadcast * broadcaster)
{
    /* jobQueues is dereferenced below, so it is validated like the other inputs */
    CheckNTErrors(jobQueues != NULL, "No input job queue list!");
    CheckNTErrors(memberActive != NULL, "No input (active) member list!");
    CheckNTErrors(memberAll != NULL, "No input (all) member list!");
    CheckNTErrors(server != NULL, "No input server model!");
    CheckNTErrors(optimizer != NULL, "No input optimizer!");
    CheckNTErrors(updater != NULL, "No input updater!");
    CheckNTErrors(broadcaster != NULL, "No input broadcaster!");

    /* pack the arguments in this order: the collecter itself, each list
       preceded by its length, then the server, optimizer, updater and
       broadcaster */
    XList args;

    args.Add(this);
    args.AddInt(jobQueues->count);
    args.AddList(jobQueues);
    args.AddInt(memberActive->count);
    args.AddList(memberActive);
    args.AddInt(memberAll->count);
    args.AddList(memberAll);
    args.Add(server);
    args.Add(optimizer);
    args.Add(updater);
    args.Add(broadcaster);

    /* run the job right now, or put it into the queue of this worker.
       NOTE(review): args is a local list; this presumably relies on
       EnqueueJob copying it before the function returns - confirm */
    if (isInstantRun)
        XWorkerCollect::UpdateAll(&args);
    else
        queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);

    return true;
}
/*
P2P data collection
target += source
...
...
@@ -259,49 +306,41 @@ void XWorkerCollect::CollectAllReduce(XList * all)
ShowNTErrors
(
"TODO!"
);
}
/*
wrapper of CollectP2P, taking its arguments from a list
>> args - the argument list: item 0 is the collecter (XWorkerCollect*),
          item 1 is the source tensor and item 2 is the target tensor
*/
void XWorkerCollect::CollectDataP2P(XList * args)
{
    XWorkerCollect * worker = (XWorkerCollect*)args->GetItem(0);
    XTensor * source = (XTensor*)args->GetItem(1);
    XTensor * target = (XTensor*)args->GetItem(2);

    /* nothing to do without a collecter */
    if (worker == NULL)
        return;

    worker->CollectP2P(source, target);
}
/*
add a new job of collecting data, update the parameter and
broadcast the new parameter
>> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
models
<< return - successful or not
add a new job of collecting data
>> jobQueue - the queue where we run the job
>> source - where we collect the data from
>> target - where we place the data (on the server end)
*/
bool
XWorkerCollect
::
AddJobUpdateAll
(
XList
*
jobQueues
,
XList
*
memberActive
,
XList
*
memberAll
,
XModel
*
server
,
XOptimizer
*
optimizer
,
XWorkerUpdate
*
updater
,
XWorkerBroadcast
*
broadcaster
)
bool
XWorkerCollect
::
AddJobCollectDataP2P
(
XQueue
*
jobQueue
,
XTensor
*
source
,
XTensor
*
target
)
{
CheckNTErrors
(
memberActive
!=
NULL
,
"No input (active) member list!"
);
CheckNTErrors
(
memberAll
!=
NULL
,
"No input (all) member list!"
);
CheckNTErrors
(
server
!=
NULL
,
"No input server model!"
);
CheckNTErrors
(
optimizer
!=
NULL
,
"No input optimizer!"
);
CheckNTErrors
(
updater
!=
NULL
,
"No input updater!"
);
CheckNTErrors
(
broadcaster
!=
NULL
,
"No input broadcaster!"
);
CheckNTErrors
(
source
!=
NULL
,
"No input soure tensor!"
);
CheckNTErrors
(
target
!=
NULL
,
"No input target tensor!"
);
XList
args
;
args
.
Add
(
this
);
args
.
AddInt
(
jobQueues
->
count
);
args
.
AddList
(
jobQueues
);
args
.
AddInt
(
memberActive
->
count
);
args
.
AddList
(
memberActive
);
args
.
AddInt
(
memberAll
->
count
);
args
.
AddList
(
memberAll
);
args
.
Add
(
server
);
args
.
Add
(
optimizer
);
args
.
Add
(
updater
);
args
.
Add
(
broadcaster
);
args
.
Add
(
source
);
args
.
Add
(
target
);
XQueue
&
queueRun
=
jobQueue
!=
NULL
?
*
jobQueue
:
queue
;
if
(
isInstantRun
)
XWorkerCollect
::
UpdateAll
(
&
args
);
XWorkerCollect
::
CollectDataP2P
(
&
args
);
else
queue
.
EnqueueJob
((
void
*
)(
char
*
)
XWorkerCollect
::
UpdateAll
,
&
args
);
queue
Run
.
EnqueueJob
((
void
*
)(
char
*
)
XWorkerCollect
::
CollectDataP2P
,
&
args
);
return
true
;
}
...
...
source/train/XWorkerCollect.h
查看文件 @
052a62b5
...
...
@@ -78,6 +78,10 @@ public:
static
void
UpdateAll
(
XList
*
args
);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool
AddJobUpdateAll
(
XList
*
jobQueues
,
XList
*
memberActive
,
XList
*
memberAll
,
XModel
*
server
,
XOptimizer
*
optimizer
,
XWorkerUpdate
*
updater
,
XWorkerBroadcast
*
broadcaster
);
/* P2P data collection */
void
CollectP2P
(
XTensor
*
source
,
XTensor
*
target
);
...
...
@@ -87,9 +91,14 @@ public:
/* all-reduce */
void
CollectAllReduce
(
XList
*
all
);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool
AddJobUpdateAll
(
XList
*
jobQueues
,
XList
*
memberActive
,
XList
*
memberAll
,
XModel
*
server
,
XOptimizer
*
optimizer
,
XWorkerUpdate
*
updater
,
XWorkerBroadcast
*
broadcaster
);
/* wrapper of Collect */
static
void
CollectDataP2P
(
XList
*
args
);
/* add a new job of collecting data */
bool
AddJobCollectDataP2P
(
XQueue
*
jobQueue
,
XTensor
*
source
,
XTensor
*
target
);
};
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论