NiuTrans / NiuTrans.Tensor / Commits

Commit 9eda6d83, authored Mar 06, 2021 by xiaotong

    bug fixes and a new class XNNRecord

Parent: 5f345e87

Showing 16 changed files with 471 additions and 49 deletions (+471 −49)
source/train/TTrain.cpp              +6   -4
source/train/TTrain.h                +1   -1
source/train/XLeader.cpp             +77  -35
source/train/XLeader.h               +9   -0
source/train/XModel.cpp              +5   -2
source/train/XModel.h                +1   -1
source/train/XNNRecord.cpp           +60  -0
source/train/XNNRecord.h             +65  -0
source/train/XTrainer.cpp            +7   -1
source/train/XWorker.cpp             +3   -0
source/train/XWorker.h               +11  -0
source/train/XWorkerBroadcast.cpp    +1   -1
source/train/XWorkerCollect.cpp      +83  -1
source/train/XWorkerCollect.h        +13  -1
source/train/XWorkerJob.cpp          +95  -1
source/train/XWorkerJob.h            +34  -1
source/train/TTrain.cpp
@@ -276,17 +276,19 @@ run the neural network
 >> inputs - inputs of the model
 >> outputs - outputs of the model
 >> golds - gold standards
+>> losses - losses of the output respect to the gold standards
 */
-bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
+bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses)
 {
     CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
+    CheckNTErrors(losses != NULL && losses->count >= 1, "Wrong arguments!");

     XTensor * input = (XTensor*)inputs->GetItem(0);
     XTensor * output = (XTensor*)outputs->GetItem(0);
     XTensor * gold = (XTensor*)golds->GetItem(0);
-    XTensor loss;
+    XTensor * loss = (XTensor*)losses->GetItem(0);
     XTensor goldOneHot;
     XNet net;

@@ -301,9 +303,9 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
     dims[goldOneHot.order - 2] = goldOneHot.GetDim(goldOneHot.order - 1);
     goldOneHot.Reshape(goldOneHot.order - 1, dims);

-    loss = CrossEntropy(output, goldOneHot);
+    *loss = CrossEntropy(*output, goldOneHot);

-    net.Backward(loss);
+    net.Backward(*loss);

     delete[] dims;
source/train/TTrain.h
@@ -146,7 +146,7 @@ public:
     XModel * Clone(int devID);

     /* run the neural network */
-    bool RunSimple(XList * inputs, XList * outputs, XList * golds);
+    bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
 };

 /* */
source/train/XLeader.cpp
@@ -65,6 +65,8 @@ void XLeader::Init()
     for (int i = 0; i < bworkers.count; i++)
         delete (XWorkerBroadcast*)bworkers.GetItem(i);
     bworkers.Clear();
+
+    serverRecord.Clear();
 }

 /* set id */

@@ -109,6 +111,18 @@ void XLeader::SetServerModel(XConfig * config, XModel * model)
     SetServerModel(config, model, &members);
 }

+/* get loss */
+float XLeader::GetLoss()
+{
+    return serverRecord.lossAll;
+}
+
+/* get prediction number */
+int XLeader::GetPredictNum()
+{
+    return serverRecord.predictNum;
+}
+
 /*
 set the communication mode
 >> myMode - the mode

@@ -117,6 +131,7 @@ void XLeader::SetMode(XLEADER_MODE myMode)
 {
     mode = myMode;
 }
+
 /* start the workers */
 void XLeader::Start()
 {

@@ -195,7 +210,7 @@ void XLeader::AddJobBroadcastWorker()
 }

 /*
-run the model (for one time)
+run the model (for one time). Basically this is a map-reduce process.
 >> config - the configuration
 >> dataDistributor - data distributor
 >> model - the neural network that we want to run

@@ -207,6 +222,10 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
 {
     bool isDataOK = true;
+    int activeJobCount = 0;
+    int * active = new int[jworkers.count];
+
+    for (int i = 0; i < jworkers.count; i++)
+        active[i] = 0;

     /* Feed the input to each worker and geneate the output.
        For each worker, we define a job queue and enqueue jobs

@@ -226,55 +245,78 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
             worker->AddJobRefresh(jmodel);

             /* job in queue 1: run the model */
-            worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput(), worker->GetGold());
+            worker->AddJobNeuralNet(jmodel,
+                                    worker->GetInput(), worker->GetOutput(),
+                                    worker->GetGold(), worker->GetLoss());
+
+            /* job in queue 1: make a record of the run */
+            worker->AddJobRecord();
+
+            active[i] = 1;
+            activeJobCount++;
         }
     }

+    if (activeJobCount == 0)
+        return false;

-    XList members(jworkers.count);
-    for (int i = 0; i < jworkers.count; i++) {
-        XWorkerJob * worker = (XWorkerJob*)jworkers[i];
-        members.Add(worker->GetModel());
-    }
+    if (activeJobCount >= 0) {
+        /* member models that are active in this run */
+        XList members(jworkers.count);

-    /* job in queue 2: collect the (gradient) data */
-    if (cworkers.count > 0) {
-        XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
-        collecter->AddJobCollect(&members, &serverModel);
-    }
-    else {
-        ShowNTErrors("No data-collecting workers!");
-    }
+        /* all member models */
+        XList membersAll(jworkers.count);

-    /* job in queue 3: update the model */
-    if (uworkers.count > 0) {
-        XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
-        updater->AddJobUpdate(&serverModel, optimizer);
-    }
-    else {
-        ShowNTErrors("No model-update workers!");
-    }
+        /* records of the active member models */
+        XList memberRecords(jworkers.count);

-    /* job in queue 4: broadcast the lastest parameters to workers */
-    if (bworkers.count > 0) {
-        XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
-        broadcaster->AddJobBroadcast(&serverModel, &members);
-    }
-    else {
-        ShowNTErrors("No data-broadcasting workers!");
-    }
+        for (int i = 0; i < jworkers.count; i++) {
+            XWorkerJob * worker = (XWorkerJob*)jworkers[i];
+            membersAll.Add(worker->GetModel());
+            if (active[i] == 1) {
+                members.Add(worker->GetModel());
+                memberRecords.Add(worker->GetRecord());
+            }
+        }

-    WaitForFinishing();
+        /* jobs in queue 2: collect the (gradient) data and other stuff. This
+           is a reduce process. */
+        if (cworkers.count > 0) {
+            XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
+            collecter->AddJobCollect(&members, &serverModel);
+            collecter->AddJobCollectOther(&memberRecords, &serverRecord);
+        }
+        else {
+            ShowNTErrors("No data-collecting workers!");
+        }
+
+        /* job in queue 3: update the model */
+        if (uworkers.count > 0) {
+            XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
+            updater->AddJobUpdate(&serverModel, optimizer);
+        }
+        else {
+            ShowNTErrors("No model-update workers!");
+        }
+
+        /* job in queue 4: broadcast the lastest parameters to workers. NOTE that
+           we would update a worker to the laster model parameters, even if it is
+           not involved in this run. */
+        if (bworkers.count > 0) {
+            XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
+            broadcaster->AddJobBroadcast(&serverModel, &membersAll);
+        }
+        else {
+            ShowNTErrors("No data-broadcasting workers!");
+        }
+
+        WaitForFinishing();
+    }

     for (int i = 0; i < jworkers.count; i++) {
         XWorkerJob * worker = (XWorkerJob*)jworkers[i];
         worker->Clear();
     }

+    delete[] active;
+
     return isDataOK;
 }
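The reworked XLeader::Run keeps an `active` flag per job worker: gradients and run records are reduced only over the members that actually received data in this round, while the refreshed server parameters are broadcast to all members. Below is a minimal standalone sketch of that bookkeeping pattern, using only the standard library; `Worker`, `hasData`, and the vectors are illustrative names, not NiuTrans API.

#include <cstdio>
#include <vector>

// illustrative stand-ins for member models and their per-run records
struct Worker { bool hasData; float loss; int model; };

int main()
{
    std::vector<Worker> workers = { {true, 0.7f, 0}, {false, 0.0f, 1}, {true, 0.4f, 2} };

    std::vector<int> activeIds;   // members that ran this round (reduce over these)
    std::vector<int> allIds;      // every member (broadcast to these)
    float lossAll = 0;

    for (size_t i = 0; i < workers.size(); i++) {
        allIds.push_back(workers[i].model);
        if (workers[i].hasData) {             // mirrors active[i] = 1 in XLeader::Run
            activeIds.push_back(workers[i].model);
            lossAll += workers[i].loss;       // the "collect other data" reduce step
        }
    }

    if (activeIds.empty())
        return 1;                             // mirrors the early "return false"

    printf("reduced over %zu active members, loss=%.2f\n", activeIds.size(), lossAll);
    printf("broadcasting new parameters to all %zu members\n", allIds.size());
    return 0;
}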
source/train/XLeader.h
@@ -69,6 +69,9 @@ protected:
     /* a model that keeps the parameters (as a server) */
     XModel serverModel;

+    /* a record that keeps the information of the run */
+    XNNRecord serverRecord;
+
     /* communication mode */
     XLEADER_MODE mode;

@@ -106,6 +109,12 @@ public:
     /* set the server model */
     void SetServerModel(XConfig * config, XModel * model);

+    /* get loss */
+    float GetLoss();
+
+    /* get prediction number */
+    int GetPredictNum();
+
     /* start the workers */
     void Start();
source/train/XModel.cpp
@@ -69,8 +69,10 @@ XModel * XModel::Clone(int devID)
 run the neural network
 >> inputs - inputs of the model
 >> outputs - outputs of the model
 >> golds - gold standards
+>> losses - losses of the input with respect to the gold standards
 */
-bool XModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
+bool XModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses)
 {
     return false;
 }

@@ -86,8 +88,9 @@ bool XModel::RunMe(XList * args)
     XList * inputs = (XList*)args->GetItem(0);
     XList * outputs = (XList*)args->GetItem(1);
     XList * golds = (XList*)args->GetItem(2);
+    XList * losses = (XList*)args->GetItem(3);

-    if (RunSimple(inputs, outputs, golds))
+    if (RunSimple(inputs, outputs, golds, losses))
         return true;

     ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !");
source/train/XModel.h
@@ -82,7 +82,7 @@ public:
     /* run the neural network */
-    virtual bool RunSimple(XList * inputs, XList * outputs, XList * golds);
+    virtual bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);

 protected:
     /* run the neural network */
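With the extra `losses` argument, a user model now writes its loss tensor through the list handed in by the job worker, as TTModel does above. The following is a minimal sketch of such an override, compiled against the library; `MyModel`, the omitted forward pass, the single-tensor layout, and the include used for CrossEntropy are all assumptions for illustration.

#include "XModel.h"                       // assumed relative location of the header
#include "../tensor/core/CHeader.h"       // assumed: header providing CrossEntropy as used in TTrain.cpp

using namespace nts;

/* a hypothetical subclass; the actual forward computation is omitted */
class MyModel : public XModel
{
public:
    bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses)
    {
        XTensor * output = (XTensor*)outputs->GetItem(0);
        XTensor * gold   = (XTensor*)golds->GetItem(0);
        XTensor * loss   = (XTensor*)losses->GetItem(0);

        /* ... run the network on `inputs` and fill `output` here ... */

        /* write the loss back through the pointer taken from `losses`,
           in the same way as TTModel::RunSimple */
        *loss = CrossEntropy(*output, *gold);
        return true;
    }
};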
source/train/XNNRecord.cpp
new file (mode 100644)

/*
 * NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2016-2021
 * Natural Language Processing Lab, Northeastern University
 * and
 * NiuTrans Research
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * A record that keeps some information in running and training neural networks
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-06
 * I will climb mountains with my wife and son this afternoon, hahaha :)
 */

#include "XNNRecord.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/* constructor */
XNNRecord::XNNRecord()
{
    Clear();
}

/* de-constructor */
XNNRecord::~XNNRecord()
{
}

/* clear it */
void XNNRecord::Clear()
{
    lossAll = 0;
    predictNum = 0;
    state = XWORKER_UNSTARTED;
}

/* update me with another record */
void XNNRecord::Update(XNNRecord & record)
{
    lossAll += record.lossAll;
    predictNum += record.predictNum;
}

}
source/train/XNNRecord.h
new file (mode 100644)

/*
 * NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2016-2021
 * Natural Language Processing Lab, Northeastern University
 * and
 * NiuTrans Research
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * A record that keeps some information in running and training neural networks
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-06
 * I will climb mountains with my wife and son this afternoon, hahaha :)
 */

#ifndef __XNNRECORD_H__
#define __XNNRECORD_H__

#include "XWorker.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/* a record of keeping some stuff during training */
class XNNRecord
{
public:
    /* loss over all samples */
    float lossAll;

    /* prediction number */
    int predictNum;

    /* state */
    XWORKER_STATE state;

public:
    /* constructor */
    XNNRecord();

    /* de-constructor */
    ~XNNRecord();

    /* clear it */
    void Clear();

    /* update me with another record */
    void Update(XNNRecord & record);
};
}

#endif
\ No newline at end of file
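XNNRecord is a plain accumulator: each job worker fills one record (loss sum and prediction count), and the collector folds the worker records into the leader's serverRecord via Update. A small usage sketch, assuming the header above is on the include path and the program is linked against the library; the record values are made-up numbers.

#include <cstdio>
#include "XNNRecord.h"    // the new header added by this commit

using namespace nts;

int main()
{
    XNNRecord serverRecord;        // plays the role of XLeader::serverRecord
    serverRecord.Clear();

    XNNRecord workerRecord[2];     // per-worker records, as kept by XWorkerJob
    workerRecord[0].lossAll = 12.5F;
    workerRecord[0].predictNum = 32;
    workerRecord[1].lossAll = 10.0F;
    workerRecord[1].predictNum = 32;

    /* the reduce step performed by XWorkerCollect::CollectOtherData */
    for (int i = 0; i < 2; i++)
        serverRecord.Update(workerRecord[i]);

    /* per-prediction loss, as reported in XTrainer::Run */
    printf("loss: %f predict: %d\n",
           serverRecord.lossAll / serverRecord.predictNum, serverRecord.predictNum);
    return 0;
}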
source/train/XTrainer.cpp
@@ -103,6 +103,9 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
     int * ids = new int[MAX_DEVICE_NUM_TRAINING];
     GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);

+    float lossAll = 0;
+    int predictNum = 0;
+
     /* create the server and workers */
     XLeader leader;
     leader.Init();

@@ -124,8 +127,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
             /* one step of udpate */
             ok = leader.Run(config, dataDistributor, model, optimizer);

+            float loss = leader.GetLoss() / leader.GetPredictNum();
+
             if ((step + 1) % 100 == 0)
-                fprintf(stderr, "epoch:%d step:%d\n", epoch + 1, step + 1);
+                fprintf(stderr, "epoch:%d step:%d loss:%f predict:%d\n",
+                        epoch + 1, step + 1, loss, leader.GetPredictNum());

             if (step++ >= nstep)
                 break;
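The reported loss is the reduced loss sum divided by the reduced prediction count. A tiny arithmetic sketch with assumed numbers (in the real loop the values come from XLeader, and the division assumes GetPredictNum() is nonzero):

#include <cstdio>

int main()
{
    /* assumed numbers for one training step */
    float lossAll    = 359.2F;   // what XLeader::GetLoss() would return
    int   predictNum = 128;      // what XLeader::GetPredictNum() would return

    /* same arithmetic as the new fprintf in XTrainer::Run */
    float loss = lossAll / predictNum;   // about 2.806, loss per prediction
    fprintf(stderr, "loss:%f predict:%d\n", loss, predictNum);
    return 0;
}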
source/train/XWorker.cpp
@@ -34,6 +34,9 @@ namespace nts {
 /* constructor */
 XWorker::XWorker()
 {
     devID = -1;
     id = -1;
+    state = XWORKER_UNSTARTED;
 }

 /* de-constructor */
source/train/XWorker.h
@@ -35,6 +35,14 @@
 namespace nts { // namespace nts(NiuTrans.Tensor)

+/*
+state of a worker
+1) unstarted
+2) started
+3) finished
+*/
+enum XWORKER_STATE { XWORKER_UNSTARTED, XWORKER_STARTED, XWORKER_FINISHED };
+
 /* the worker class */
 class XWorker
 {

@@ -49,6 +57,9 @@ protected:
     /* the queue */
     XQueue queue;

+    /* state of the worker */
+    XWORKER_STATE state;
+
 public:
     /* constructor */
     XWorker();
source/train/XWorkerBroadcast.cpp
@@ -100,7 +100,7 @@ void XWorkerBroadcast::Broadcast(XList * args)
     /* target models */
     int targetNum = args->GetItemInt(2);
-    XList target(targetNum);
+    XList target;

     for (int i = 0; i < targetNum; i++) {
         XModel * model = (XModel*)args->GetItem(3 + i);
         target.Add(model);
source/train/XWorkerCollect.cpp
@@ -177,7 +177,7 @@ void XWorkerCollect::Collect(XList * args)
     int sourceNum = args->GetItemInt(1);

     /* the source models */
-    XList source(sourceNum);
+    XList source;

     for (int i = 0; i < sourceNum; i++) {
         XModel * model = (XModel*)args->GetItem(2 + i);
         source.Add(model);

@@ -257,4 +257,86 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
     return true;
 }

+/*
+collect the data of the run (i.e., loss). This is a reducer.
+>> sourceList - the list of record
+>> target - the record that we keep the reduce result
+>> sleepTime - waiting time in collecting data
+*/
+void XWorkerCollect::CollectOtherData(XList * sourceList, XNNRecord * target, long sleepTime)
+{
+    int finished = 0;
+    int * flags = new int[sourceList->count];
+
+    for (int i = 0; i < sourceList->count; i++)
+        flags[i] = 0;
+
+    while (1) {
+        for (int i = 0; i < sourceList->count; i++) {
+            if (flags[i] != 0)
+                continue;
+
+            XNNRecord * source = (XNNRecord*)sourceList->GetItem(i);
+            if (source->state == XWORKER_FINISHED) {
+                if (target != source)
+                    target->Update(*source);
+                flags[i] = 1;
+                finished++;
+            }
+        }
+
+        if (finished == sourceList->count)
+            break;
+
+#ifdef _WIN32
+        Sleep((DWORD)sleepTime);
+#else
+        sleep((unsigned)sleepTime / 1000);
+#endif
+    }
+
+    delete[] flags;
+}
+
+/* wrapper of CollectOtherData */
+void XWorkerCollect::CollectOther(XList * args)
+{
+    XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0);
+    int sourceNum = args->GetItemInt(1);
+
+    /* the source records */
+    XList source;
+    for (int i = 0; i < sourceNum; i++) {
+        XNNRecord * record = (XNNRecord*)args->GetItem(2 + i);
+        source.Add(record);
+    }
+
+    /* the target record */
+    XNNRecord * target = (XNNRecord*)args->GetItem(2 + sourceNum);
+
+    collecter->CollectOtherData(&source, target, SLEEP_TIME_IN_COLLECTING_OTHER);
+}
+
+/*
+add a new job of collecting data of the run (i.e., loss)
+collect the data of the run (i.e., loss). This is a reducer.
+>> sourceList - the list of record
+>> target - the record that we keep the reduce result
+*/
+bool XWorkerCollect::AddJobCollectOther(XList * sourceList, XNNRecord * target)
+{
+    CheckNTErrors(sourceList != NULL, "no input source record list!");
+    CheckNTErrors(target != NULL, "no input target record!");
+
+    XList args;
+    args.Add(this);
+    args.AddInt(sourceList->count);
+    args.AddList(sourceList);
+    args.Add(target);
+
+    queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
+
+    return true;
+}
+
 }
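CollectOtherData is a polling reducer: it repeatedly scans the source records, folds each one into the target as soon as its worker reports XWORKER_FINISHED, and sleeps between scans. Below is a self-contained sketch of the same pattern using only the standard library; the `Record` type and std::atomic states are illustrative stand-ins for XNNRecord and XWORKER_STATE, not NiuTrans API.

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

enum State { UNSTARTED, STARTED, FINISHED };

struct Record { float lossAll = 0; int predictNum = 0; std::atomic<State> state{UNSTARTED}; };

// fold finished source records into `target`, polling until all of them are done
void CollectOther(std::vector<Record> & sources, Record & target, long sleepTimeMs)
{
    std::vector<int> done(sources.size(), 0);
    size_t finished = 0;

    while (true) {
        for (size_t i = 0; i < sources.size(); i++) {
            if (done[i] || sources[i].state != FINISHED)
                continue;
            target.lossAll += sources[i].lossAll;       // the Update() step
            target.predictNum += sources[i].predictNum;
            done[i] = 1;
            finished++;
        }
        if (finished == sources.size())
            break;
        std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs));
    }
}

int main()
{
    std::vector<Record> records(2);
    Record server;

    std::thread worker([&] {                            // a worker finishing its records
        records[0].lossAll = 3.5F; records[0].predictNum = 8;
        records[0].state = FINISHED;
        records[1].lossAll = 4.5F; records[1].predictNum = 8;
        records[1].state = FINISHED;
    });

    CollectOther(records, server, 10);
    worker.join();
    printf("lossAll=%.1f predictNum=%d\n", server.lossAll, server.predictNum);
    return 0;
}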
source/train/XWorkerCollect.h
@@ -31,10 +31,12 @@
 #include "XWorker.h"
 #include "XModel.h"
 #include "XWorkerJob.h"

 namespace nts { // namespace nts(NiuTrans.Tensor)

 #define SLEEP_TIME_IN_COLLECTING 10
+#define SLEEP_TIME_IN_COLLECTING_OTHER 10

 /*
 data collection method

@@ -61,7 +63,7 @@ public:
     /* set the collection type */
     void SetCollectMode(DATA_COLLECT_TYPE myMode);

-    /* collect data */
+    /* collect the gradient data (i.e., a reducer) */
     void CollectData(XList * sourceList, XModel * target, long sleepTime);

     /* wrapper of CollectData */

@@ -79,6 +81,16 @@ public:
     /* add a new job of collecting data */
     bool AddJobCollect(XList * sourceList, XModel * target);

+    /* collect the data of the run (i.e., loss). This is a reducer. */
+    void CollectOtherData(XList * sourceList, XNNRecord * target, long sleepTime);
+
+    /* wrapper of CollectOtherData */
+    static void CollectOther(XList * args);
+
+    /* add a new job of collecting data of the run (i.e., loss) */
+    bool AddJobCollectOther(XList * sourceList, XNNRecord * target);
+
 };
 }
source/train/XWorkerJob.cpp
@@ -27,6 +27,7 @@
 #include "XWorkerJob.h"
 #include "../tensor/XList.h"
+#include "../tensor/core/CHeader.h"

 namespace nts { // namespace nts(NiuTrans.Tensor)

@@ -47,6 +48,9 @@ XWorkerJob::~XWorkerJob()
     for (int i = 0; i < golds.count; i++)
         delete (XTensor*)golds[i];
+
+    for (int i = 0; i < losses.count; i++)
+        delete (XTensor*)losses[i];
 }

 /* set the model */

@@ -61,6 +65,13 @@ XModel * XWorkerJob::GetModel()
     return model;
 }

+/* set the state of the worker */
+void XWorkerJob::SetState(XWORKER_STATE myState)
+{
+    state = myState;
+    record.state = myState;
+}
+
 /* clear the worker */
 void XWorkerJob::Clear()
 {

@@ -78,6 +89,15 @@ void XWorkerJob::Clear()
         delete (XTensor*)golds[i];
     golds.Clear();
     golds.Add(new XTensor());

+    for (int i = 0; i < losses.count; i++)
+        delete (XTensor*)losses[i];
+    losses.Clear();
+    losses.Add(new XTensor());
+
+    record.Clear();
+
+    SetState(XWORKER_UNSTARTED);
 }

 /* get the input list */

@@ -98,6 +118,52 @@ XList * XWorkerJob::GetGold()
     return &golds;
 }

+/* get the loss */
+XList * XWorkerJob::GetLoss()
+{
+    return &losses;
+}
+
+/* get the record of the run */
+XNNRecord * XWorkerJob::GetRecord()
+{
+    return &record;
+}
+
+/* record some stuff */
+void XWorkerJob::RecordMe()
+{
+    float lossAll = 0;
+    for (int i = 0; i < losses.count; i++) {
+        XTensor * loss = (XTensor*)losses[i];
+        lossAll += ReduceSumAllValue(*loss);
+    }
+
+    record.lossAll = lossAll;
+
+    int predictNum = 0;
+    for (int i = 0; i < outputs.count; i++) {
+        XTensor * output = (XTensor*)outputs[i];
+        predictNum += output->GetSize();
+    }
+
+    record.predictNum = predictNum;
+}
+
+/* get the sum of losses over samples */
+float XWorkerJob::GetLossAll()
+{
+    return record.lossAll;
+}
+
+/* get the number of outputs (predictoins) */
+int XWorkerJob::GetPredictNum()
+{
+    return record.predictNum;
+}
+
 /*
 add a new job of model refreshment
 >> myModel - the model

@@ -121,9 +187,11 @@ add a new job of neural network forward and backward computation (with the input
 >> inputs - inputs of the neural network
 >> outputs - outputs of the neural network
 >> golds - gold standards
+>> losses - losses of the outputs respect to the gold standards
 << return - succeeded or not
 */
-bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds)
+bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
+                                 XList * inputs, XList * outputs, XList * golds, XList * losses)
 {
     CheckNTErrors(myModel != NULL, "no input neural network!");
     CheckNTErrors(inputs != NULL, "no inputs of the model!");

@@ -134,11 +202,37 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu
     args.Add(inputs);
     args.Add(outputs);
     args.Add(golds);
+    args.Add(losses);

     queue.EnqueueJob((void*)(char*)XModel::Run, &args);

+    SetState(XWORKER_STARTED);
+
     return true;
 }

+/* add a new job of recording the running of the nerual network */
+bool XWorkerJob::AddJobRecord()
+{
+    XList args;
+    args.Add(this);
+
+    queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
+
+    return true;
+}
+
+/* wrapper of RecordMe */
+void XWorkerJob::RecordMeStatic(XList * args)
+{
+    CheckNTErrors(args != NULL && args->count > 0, "Illegal arguments!");
+
+    XWorkerJob * worker = (XWorkerJob*)args->GetItem(0);
+
+    worker->RecordMe();
+    worker->SetState(XWORKER_FINISHED);
+}
+
 }
 /* end of the nts (NiuTrans.Tensor) namespace */
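The worker queue executes enqueued jobs asynchronously: AddJobNeuralNet and AddJobRecord push a function pointer plus an argument list, and the worker's queue thread pops and runs them in order (the run job sets XWORKER_STARTED, the record job ends with XWORKER_FINISHED). Below is a compact standalone sketch of that job-queue idea, using std::function and a sentinel instead of the raw function-pointer cast in EnqueueJob; all names here are illustrative.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// a minimal job queue: producers enqueue closures, one worker thread runs them in order
class JobQueue {
public:
    void Enqueue(std::function<void()> job) {
        { std::lock_guard<std::mutex> lock(mtx); jobs.push(std::move(job)); }
        cv.notify_one();
    }
    void RunAll() {                       // worker loop: stop when an empty job arrives
        while (true) {
            std::function<void()> job;
            {
                std::unique_lock<std::mutex> lock(mtx);
                cv.wait(lock, [&] { return !jobs.empty(); });
                job = std::move(jobs.front());
                jobs.pop();
            }
            if (!job) break;
            job();
        }
    }
private:
    std::queue<std::function<void()>> jobs;
    std::mutex mtx;
    std::condition_variable cv;
};

int main()
{
    JobQueue queue;
    std::thread worker(&JobQueue::RunAll, &queue);

    // jobs roughly in the spirit of queue 1 in XWorkerJob: run the model, then record
    queue.Enqueue([] { printf("job 1: run the model\n"); });
    queue.Enqueue([] { printf("job 2: record loss and prediction count\n"); });
    queue.Enqueue(nullptr);               // sentinel: tell the worker loop to stop

    worker.join();
    return 0;
}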
source/train/XWorkerJob.h
@@ -31,6 +31,7 @@
 #include "XWorker.h"
 #include "XModel.h"
+#include "XNNRecord.h"
 #include "XBaseTemplate.h"
 #include "../tensor/XList.h"

@@ -51,6 +52,12 @@ protected:
     /* the gold standard */
     XList golds;

+    /* the loss */
+    XList losses;
+
+    /* record the information in running the neural network */
+    XNNRecord record;
+
 public:

@@ -66,6 +73,9 @@ public:
     /* get the parameter keeper */
     XModel * GetModel();

+    /* set the state of the worker */
+    void SetState(XWORKER_STATE myState);
+
     /* clear the worker */
     void Clear();

@@ -78,11 +88,34 @@ public:
     /* get the gold standard */
     XList * GetGold();

+    /* get the loss */
+    XList * GetLoss();
+
+    /* get the record of the run */
+    XNNRecord * GetRecord();
+
+    /* record some stuff */
+    void RecordMe();
+
+    /* get the sum of losses over samples */
+    float GetLossAll();
+
+    /* get the number of outputs (predictoins) */
+    int GetPredictNum();
+
     /* add a new job of model refreshment */
     bool AddJobRefresh(XModel * myModel);

     /* add a new job of neural network forward and backward computation (with the input) */
-    bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds);
+    bool AddJobNeuralNet(XModel * myModel,
+                         XList * inputs, XList * outputs, XList * golds, XList * losses);
+
+    /* add a new job of recording the running of the nerual network */
+    bool AddJobRecord();
+
+private:
+    /* wrapper of RecordMe */
+    static void RecordMeStatic(XList * args);
+
 };
 }