NiuTrans / NiuTrans.Tensor · Commit 554bdfd6

authored Mar 02, 2021 by xiaotong

updates of XLeader

parent a864edb1
Showing 12 changed files with 460 additions and 25 deletions (+460 −25)
source/train/XLeader.cpp         +92   -9
source/train/XLeader.h           +19   -2
source/train/XModel.h            +5    -1
source/train/XOptimizer.cpp      +58   -1
source/train/XOptimizer.h        +41   -1
source/train/XTrainer.cpp        +73   -2
source/train/XTrainer.h          +8    -3
source/train/XWorker.cpp         +2    -1
source/train/XWorkerCollect.cpp  +25   -3
source/train/XWorkerCollect.h    +5    -1
source/train/XWorkerUpdate.cpp   +104  -1
source/train/XWorkerUpdate.h     +28   -0
source/train/XLeader.cpp  View file @ 554bdfd6
@@ -50,10 +50,17 @@ XLeader::~XLeader()

/* initialize the leader */
void XLeader::Init()
{
    for (int i = 0; i < jworkers.count; i++)
        delete (XWorkerJob*)jworkers.GetItem(i);
    jworkers.Clear();

    for (int i = 0; i < cworkers.count; i++)
        delete (XWorkerCollect*)cworkers.GetItem(i);
    cworkers.Clear();

    for (int i = 0; i < uworkers.count; i++)
        delete (XWorkerUpdate*)uworkers.GetItem(i);
    uworkers.Clear();
}
/* set id */
...
@@ -76,6 +83,24 @@ void XLeader::SetMode(XLEADER_MODE myMode)
{
    mode = myMode;
}

/* start the workers */
void XLeader::Start()
{
    for (int i = 0; i < jworkers.count; i++) {
        XWorkerJob * worker = (XWorkerJob*)jworkers.GetItem(i);
        worker->Start();
    }

    for (int i = 0; i < cworkers.count; i++) {
        XWorkerJob * worker = (XWorkerJob*)cworkers.GetItem(i);
        worker->Start();
    }

    for (int i = 0; i < uworkers.count; i++) {
        XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i);
        worker->Start();
    }
}
/*
add a number of job workers (given their device ids)
...
@@ -87,25 +112,55 @@ void XLeader::AddJobWorker(XModel * model, int n, int * ids)
{
    /* we keep the input model */
    if (n >= 1) {
        XWorkerJob * worker = new XWorkerJob();
        worker->SetModel(model);
        jworkers.Add(worker);
    }

    /* we clone the input model */
    for (int i = 0; i < n - 1; i++) {
        XWorkerJob * worker = new XWorkerJob();
        worker->SetModel(model->Clone(ids[i]));
        jworkers.Add(worker);
    }
}
/*
add a data-collecting worker
>> mode - the data-transfer mode of the worker
*/
void XLeader::AddJobCollectWorker(DATA_COLLECT_TYPE mode)
{
    XWorkerCollect * worker = new XWorkerCollect();
    worker->SetCollectMode(mode);
    cworkers.Add(worker);
}

/*
add a model-update worker
>> model - the model
>> optimizer - the optimizer
*/
void XLeader::AddJobUpdateWorker(XModel * model, XOptimizer * optimizer)
{
    XWorkerUpdate * worker = new XWorkerUpdate();
    worker->SetOptimizer(optimizer);
    uworkers.Add(worker);
}
/*
run the model (for one time)
>> config - the configuration
>> dataDistributor - data distributor
>> model - the neural network that we want to run
>> optimizer - the optimization method
<< return - if we can fetch the new data
*/
bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
                  XModel * model, XOptimizer * optimizer)
{
    bool isDataOK = true;

    /* Feed the input to each worker and generate the output.
       For each worker, we define a job queue and enqueue jobs
       into it.
...
@@ -115,19 +170,47 @@ void XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,

        XModel * model = worker->GetModel();

        /* get a batch of samples */
        bool fetched = dataDistributor->GetBatch(worker->GetInput());

        /* job in queue 1: refresh the model */
        worker->AddJobRefresh(model);

        /* job in queue 1: run the model */
        worker->AddJobNeuralNet(model, worker->GetInput(), worker->GetOutput());

        /* clear it */
        worker->Clear();

        if (!fetched)
            isDataOK = false;
    }

    XList members(jworkers.count);

    for (int i = 0; i < jworkers.count; i++) {
        XWorkerJob * worker = (XWorkerJob*)jworkers[i];
        if (worker->GetModel() != model)
            members.Add(worker->GetModel());
    }

    /* job in queue 2: collect the (gradient) data */
    if (cworkers.count > 0) {
        XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
        collecter->AddJobCollect(&members, model);
    }
    else {
        ShowNTErrors("No data-collecting workers!");
    }

    /* job in queue 3: update the model */
    if (uworkers.count > 0) {
        XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
        updater->AddJobUpdate(model, optimizer);
    }
    else {
        ShowNTErrors("No model-update workers!");
    }

    /* collect the (gradient) data and update the model */

    return isDataOK;
}

} /* end of the nts (NiuTrans.Tensor) namespace */
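With this change, one call to XLeader::Run enqueues three kinds of jobs per training step: refresh and forward jobs on the job workers (queue 1), a gradient-collecting job on the collect worker (queue 2), and a model-update job on the update worker (queue 3). The fragment below is a minimal wiring sketch, not part of the commit; it mirrors how XTrainer::Run (later in this commit) drives the leader, and assumes config, dataDistributor, model, optimizer, jobNum and ids are set up as in that function.

    // minimal sketch of driving the new XLeader interface
    XLeader leader;
    leader.Init();
    leader.AddJobWorker(model, jobNum, ids);      // one job worker per device
    leader.AddJobCollectWorker();                 // gradient collection (P2P by default)
    leader.AddJobUpdateWorker(model, optimizer);  // parameter update
    leader.Start();

    while (leader.Run(config, dataDistributor, model, optimizer)) {
        // Run returns false once the data distributor cannot fetch a new batch
    }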
source/train/XLeader.h  View file @ 554bdfd6
@@ -39,6 +39,8 @@
#include "XOptimizer.h"
#include "XBaseTemplate.h"
#include "XWorkerJob.h"
#include "XWorkerCollect.h"
#include "XWorkerUpdate.h"
#include "../tensor/XConfig.h"
#include "../tensor/XList.h"
@@ -65,9 +67,15 @@ protected:
    /* communication mode */
    XLEADER_MODE mode;

    /* job workers */
    XList jworkers;

    /* data-collecting workers */
    XList cworkers;

    /* model-update workers */
    XList uworkers;

public:
    /* constructor */
    XLeader();
@@ -84,14 +92,23 @@ public:
    /* get id */
    int GetID();

    /* start the workers */
    void Start();

    /* set the communication mode */
    void SetMode(XLEADER_MODE myMode);

    /* add a number of job workers (given their device ids) */
    void AddJobWorker(XModel * model, int n, int * ids);

    /* add a data-collecting worker */
    void AddJobCollectWorker(DATA_COLLECT_TYPE mode = DATA_COLLECT_P2P);

    /* add a model-update worker */
    void AddJobUpdateWorker(XModel * model, XOptimizer * optimizer);

    /* run the model (for one time) */
    bool Run(XConfig * config, DataDistributeBase * dataDistributor,
             XModel * model, XOptimizer * optimizer);
};
source/train/XModel.h  View file @ 554bdfd6
@@ -43,8 +43,12 @@ parameter state
1) not ready
2) ready
3) the parameter has been collected from other models
4) the updated parameter
*/
enum PARAM_STATE { PARAM_STATE_NOT_READY, PARAM_STATE_READY,
                   PARAM_STATE_COLLECTED, PARAM_STATE_UPDATED };

/* a model template for training */
class XModel
source/train/XOptimizer.cpp  View file @ 554bdfd6
@@ -23,11 +23,68 @@
 * This class defines the template of the update rule in gradient based methods
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
 * March comes but there was a snow last night.
 */

#include "XOptimizer.h"
#include "../tensor/core/CHeader.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/* constructor */
XOptimizer::XOptimizer()
{
    Clear();
}

/* de-constructor */
XOptimizer::~XOptimizer()
{
}

/*
initialize the optimizer
>> config - the configuration
*/
void XOptimizer::Init(XConfig * config)
{
}

/* clear the optimizer */
void XOptimizer::Clear()
{
    nstep = 0;
    nepoch = 0;
    lrate = 0;
}

/*
prepare for the update
>> model - the model that we want to update
*/
void XOptimizer::Prepare(XModel * model)
{
}

/*
record the update
>> model - the model that we want to update
*/
void XOptimizer::Note(XModel * model)
{
    nstep++;
}

/*
update a parameter matrix
>> param - the parameter matrix
>> grad - the gradient
>> pid - the id of the parameter matrix
*/
void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
{
    /* the delta rule
       \theta_new = \theta_old - \grad * \lrate */
    Sum(param, grad, param, -lrate);
}

}
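UpdateParam applies the delta rule param = param - lrate * grad with a single Sum call using the coefficient -lrate. Because Init, Clear, Prepare, Note and UpdateParam are virtual (see XOptimizer.h below), other update rules can be plugged in by subclassing. The following is an illustrative sketch only, not part of this commit: a hypothetical XSGDDecay that shrinks the learning rate as 1/sqrt(nstep + 1) before delegating to the base delta rule.

    #include <math.h>
    #include "XOptimizer.h"

    /* hypothetical subclass, shown only as a usage sketch of the virtual interface */
    class XSGDDecay : public XOptimizer
    {
    public:
        virtual void UpdateParam(XTensor * param, XTensor * grad, int pid)
        {
            /* assumed schedule: decay the base learning rate by 1/sqrt(nstep + 1);
               nstep is incremented by XOptimizer::Note after each update */
            float lrateBase = lrate;
            lrate = lrateBase / sqrtf((float)(nstep + 1));

            /* reuse the base delta rule: param = param - lrate * grad */
            XOptimizer::UpdateParam(param, grad, pid);

            lrate = lrateBase;
        }
    };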
source/train/XOptimizer.h  View file @ 554bdfd6
@@ -23,17 +23,57 @@
 * This class defines the template of the update rule in gradient based methods
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-01
 * March came finally but there was a snow last night.
 */

#ifndef __XOPTIMIZER_H__
#define __XOPTIMIZER_H__

#include "XModel.h"
#include "../tensor/XConfig.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/* this class defines a template of the optimizer and
   implements the simple delta-rule in SGD. */
class XOptimizer
{
public:
    /* update step number */
    int nstep;

    /* training epoch number */
    int nepoch;

    /* learning rate */
    float lrate;

public:
    /* constructor */
    XOptimizer();

    /* de-constructor */
    ~XOptimizer();

    /* initialize the optimizer */
    virtual void Init(XConfig * config);

    /* clear the optimizer */
    virtual void Clear();

    /* prepare for the update */
    virtual void Prepare(XModel * model);

    /* record the update */
    virtual void Note(XModel * model);

    /* update a parameter matrix */
    virtual void UpdateParam(XTensor * param, XTensor * grad, int pid);
};

}
source/train/XTrainer.cpp  View file @ 554bdfd6
@@ -40,14 +40,55 @@ XTrainer::~XTrainer()
{
}

/*
get the device ids of the jobs
>> config - configuration
>> ids - the array of device ids
>> num - number of the jobs
>> maxDevNum - the maximum number of devices
*/
void XTrainer::GetDevIDs(XConfig * config, int * ids, int & num, int maxDevNum)
{
    CheckNTErrors(maxDevNum > 0, "No data array for input!");

    num = 0;
    for (int i = 0; i < maxDevNum; i++) {
        char dev[16];
        sprintf(dev, "jobdev%d", i);
        int id = config->GetInt(dev, -128);
        if (id != -128) {
            ids[num++] = id;
        }
        else
            break;
    }

    if (num == 0) {
        char dev[16];
        sprintf(dev, "jobdev");
        int id = config->GetInt(dev, -128);
        if (id != -128)
            ids[num++] = id;
    }

    if (num == 0) {
        char dev[16];
        sprintf(dev, "dev");
        int id = config->GetInt(dev, -128);
        if (id != -128)
            ids[num++] = id;
    }
}

/*
run the trainer (this is the core process)
>> config - configuration
>> dataDistributor - the data distributor that generates an input for the net each time
>> model - the neural network
>> optimizer - the optimizer
*/
void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
                   XModel * model, XOptimizer * optimizer)
{
    CheckNTErrors(config != NULL, "No input config!");
    CheckNTErrors(dataDistributor != NULL, "No input data distributor!");
...
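For illustration (the key names come from GetDevIDs above; the concrete values are assumed): a job over two devices could set jobdev0=0 and jobdev1=1 in the XConfig; if no jobdevN entries are present, the single key jobdev is tried, and finally dev, with -128 serving as the internal "not set" sentinel returned by GetInt.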
@@ -57,11 +98,41 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
    int nstep = config->GetInt("nstep", 100000);
    int epoch = 0;
    int step = 0;
    int jobNum = 0;
    int * ids = new int[MAX_DEVICE_NUM_TRAINING];

    GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);

    /* create the server and workers */
    XLeader leader;
    leader.Init();
    leader.AddJobWorker(model, jobNum, ids);
    leader.AddJobCollectWorker();
    leader.AddJobUpdateWorker(model, optimizer);
    leader.Start();

    /* train the model */
    for (int epoch = 0; epoch < nepoch; epoch++) {
        bool ok = true;
        dataDistributor->Start();

        while (ok) {
            /* one step of update */
            ok = leader.Run(config, dataDistributor, model, optimizer);

            if (step++ >= nstep)
                break;
        }

        dataDistributor->End();

        if (step >= nstep)
            break;
    }

    delete[] ids;
}

} /* end of the nts (NiuTrans.Tensor) namespace */
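Since the trainer now takes the optimizer explicitly, a caller sets it up alongside the model and data distributor. A minimal calling sketch under assumptions: TrainWithOptimizer is a hypothetical helper, the config, model and distributor are assumed to be constructed elsewhere, and the learning-rate value is illustrative.

    /* hypothetical helper showing the new four-argument XTrainer::Run */
    void TrainWithOptimizer(XConfig * config, DataDistributeBase * data, XModel * model)
    {
        XOptimizer optimizer;        // base optimizer: plain delta-rule SGD
        optimizer.Init(config);      // the base Init is empty; subclasses would read the config here
        optimizer.lrate = 0.001F;    // assumed learning rate, set on the public field

        XTrainer trainer;
        trainer.Run(config, data, model, &optimizer);
    }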
source/train/XTrainer.h  View file @ 554bdfd6
@@ -38,6 +38,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)

#define MAX_DEVICE_NUM_TRAINING 128

/*
Training of neural networks with gradient methods. Here we suppose that we
are training NLP models. The routine could be:
...
@@ -59,8 +61,6 @@ the job to the workers and maintain the model.
*/
class XTrainer
{
public:
    /* constructor */
    XTrainer();
@@ -68,9 +68,13 @@ public:
    /* de-constructor */
    ~XTrainer();

    /* get the device ids of the jobs */
    void GetDevIDs(XConfig * config, int * ids, int & num, int maxDevNum);

    /* run the leader (this is the core process) */
    void Run(XConfig * config, DataDistributeBase * dataDistributor,
             XModel * model, XOptimizer * optimizer);
};

}

#endif // __XTRAINER_H__
\ No newline at end of file
source/train/XWorker.cpp  View file @ 554bdfd6
@@ -39,6 +39,7 @@ XWorker::XWorker()
/* de-constructor */
XWorker::~XWorker()
{
    Stop();
}

/* set device id */
@@ -78,7 +79,7 @@ void XWorker::AddJob(void * job, XList * jobArgs)
/* start the work */
void XWorker::Start()
{
    queue.RunJobConsumer();
}

/* stop the work */
source/train/XWorkerCollect.cpp  View file @ 554bdfd6
@@ -91,7 +91,7 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
            if (source->flags[j] != PARAM_STATE_COLLECTED && sp[j]->isGradFinished) {

                /* data transmit */
                CollectP2P(sp.GetItem(j)->grad, tp.GetItem(j)->grad);

                /* reset the flag */
                source->flags[j] = PARAM_STATE_COLLECTED;
...
@@ -134,11 +134,11 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
            for (int i = 0; i < sourceList->count; i++) {
                XModel * source = (XModel*)sourceList->GetItem(i);
                TensorList & sp = source->params;
                tensorList.Add(sp.GetItem(j)->grad);
            }

            /* data transmit */
            CollectReduceSum(&tensorList, tp.GetItem(j)->grad);

            /* reset the flags */
            for (int i = 0; i < sourceList->count; i++) {
...
@@ -146,6 +146,7 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
                source->flags[j] = PARAM_STATE_COLLECTED;
            }
            target->flags[j] = PARAM_STATE_COLLECTED;

            finished += sourceList->count;
        }
    }
...
@@ -236,4 +237,25 @@ void XWorkerCollect::CollectAllReduce(XList * all)
    ShowNTErrors("TODO!");
}

/*
add a new job of collecting data
>> sourceList - the list of models that we want to collect data from
>> target - the destination of the collection
*/
bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
{
    CheckNTErrors(sourceList != NULL, "no input source model list!");
    CheckNTErrors(target != NULL, "no input target model!");

    XList args;
    args.Add(this);
    args.AddInt(sourceList->count);
    args.AddList(sourceList);
    args.Add(target);

    queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);

    return true;
}

}
source/train/XWorkerCollect.h  View file @ 554bdfd6
@@ -22,7 +22,8 @@
/*
 * The worker that collects data from workers.
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-02
 * minus 10 degrees centigrade comes again!
 */

#ifndef __XWORKERCOLLECT_H__
@@ -74,6 +75,9 @@ public:
    /* all-reduce */
    void CollectAllReduce(XList * all);

    /* add a new job of collecting data */
    bool AddJobCollect(XList * sourceList, XModel * target);
};

}
source/train/XWorkerUpdate.cpp  View file @ 554bdfd6
@@ -27,5 +27,107 @@
#include "XWorkerUpdate.h"

namespace nts { // namespace nts (NiuTrans.Tensor)

/* constructor */
XWorkerUpdate::XWorkerUpdate()
{
    optimizer = NULL;
}

/* de-constructor */
XWorkerUpdate::~XWorkerUpdate()
{
}

/* set the optimizer */
void XWorkerUpdate::SetOptimizer(XOptimizer * myOptimizer)
{
    optimizer = myOptimizer;
}

/* get the optimizer */
XOptimizer * XWorkerUpdate::GetOptimizer()
{
    return optimizer;
}

/*
update the model
>> model - the model that we want to update
>> optimizer - the optimizer
>> sleepTime - waiting time in each update
*/
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime)
{
    int finished = 0;
    TensorList & params = model->params;
    PARAM_STATE * flags = model->flags;

    optimizer->Prepare(model);

    while (1) {
        for (int i = 0; i < params.count; i++) {
            if (flags[i] == PARAM_STATE_COLLECTED) {
                XTensor * param = params.GetItem(i);
                XTensor * grad = param->grad;

                /* update the parameter */
                optimizer->UpdateParam(param, grad, i);

                /* set the flag */
                flags[i] = PARAM_STATE_UPDATED;
                finished++;
            }
        }

        if (finished == params.count)
            break;

#ifdef _WIN32
        Sleep((DWORD)sleepTime);
#else
        sleep((unsigned)sleepTime / 1000);
#endif
    }

    optimizer->Note(model);
}

/*
wrapper of UpdateModel
>> args - arguments of the update
*/
void XWorkerUpdate::Update(XList * args)
{
    CheckNTErrors(args != NULL && args->count > 3, "Illegal argument list!");

    XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0);
    XModel * model = (XModel*)args->GetItem(1);
    XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);

    updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
}

/*
add a new job of model update
>> model - the model that we want to update
>> optimizer - the optimizer
*/
bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
{
    CheckNTErrors(model != NULL, "No input model!");
    CheckNTErrors(optimizer != NULL, "No optimizer!");

    XList args;
    args.Add(this);
    args.Add(model);
    args.Add(optimizer);

    queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);

    return true;
}

}
\ No newline at end of file
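Taken together with XWorkerCollect::CollectData above, the flag array implements a simple handshake on the target model's parameters: the collect worker marks a parameter PARAM_STATE_COLLECTED once its gradient has been transferred or reduced, and UpdateModel polls those flags, applies optimizer->UpdateParam to each collected parameter, marks it PARAM_STATE_UPDATED, and sleeps between passes until all params.count parameters are done, after which optimizer->Note records the completed step.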
source/train/XWorkerUpdate.h  View file @ 554bdfd6
@@ -29,13 +29,41 @@
#define __XWORKERUPDATE_H__

#include "XWorker.h"
#include "XOptimizer.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

#define SLEEP_TIME_IN_MODEL_UPDATE 10

/* The class defines the model-update worker */
class XWorkerUpdate : public XWorker
{
protected:
    /* the optimizer */
    XOptimizer * optimizer;

public:
    /* constructor */
    XWorkerUpdate();

    /* de-constructor */
    ~XWorkerUpdate();

    /* set the optimizer */
    void SetOptimizer(XOptimizer * myOptimizer);

    /* get the optimizer */
    XOptimizer * GetOptimizer();

    /* update the model */
    void UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime);

    /* wrapper of UpdateModel */
    static void Update(XList * args);

    /* add a new job of model update */
    bool AddJobUpdate(XModel * model, XOptimizer * optimizer);
};

}