Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
e9d68683
Commit
e9d68683
authored
Mar 05, 2021
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fixes
parent
fbb4331c
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
120 行增加
和
55 行删除
+120
-55
source/Main.cpp
+1
-1
source/train/TTrain.cpp
+32
-34
source/train/TTrain.h
+3
-3
source/train/XBaseTemplate.cpp
+20
-2
source/train/XBaseTemplate.h
+5
-1
source/train/XLeader.cpp
+2
-2
source/train/XModel.cpp
+26
-3
source/train/XModel.h
+5
-1
source/train/XOptimizer.cpp
+1
-1
source/train/XWorkerJob.cpp
+22
-4
source/train/XWorkerJob.h
+2
-2
source/train/XWorkerUpdate.cpp
+1
-1
没有找到文件。
source/Main.cpp
查看文件 @
e9d68683
...
@@ -42,7 +42,7 @@ int main( int argc, const char ** argv )
...
@@ -42,7 +42,7 @@ int main( int argc, const char ** argv )
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
Test
();
Test
();
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-testtrain"
))
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-testtrain"
))
TestTrain
(
argc
-
1
,
argv
+
1
);
TestTrain
();
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-fnnlm"
))
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-fnnlm"
))
FNNLMMain
(
argc
-
1
,
argv
+
1
);
FNNLMMain
(
argc
-
1
,
argv
+
1
);
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-t2t"
))
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-t2t"
))
...
...
source/train/TTrain.cpp
查看文件 @
e9d68683
...
@@ -68,12 +68,12 @@ void GeneateTTrainData(const char * fileName)
...
@@ -68,12 +68,12 @@ void GeneateTTrainData(const char * fileName)
}
}
/* run the test */
/* run the test */
void
TestTrain
(
int
argc
,
const
char
**
argv
)
void
TestTrain
()
{
{
GeneateTTrainData
(
"ttrain.txt"
);
GeneateTTrainData
(
"ttrain.txt"
);
XConfig
config
;
XConfig
config
;
config
.
Create
(
argc
,
argv
);
config
.
Add
(
"dev"
,
-
1
);
TTDataLoader
loader
;
TTDataLoader
loader
;
loader
.
SetFileName
(
"ttrain.txt"
);
loader
.
SetFileName
(
"ttrain.txt"
);
...
@@ -141,35 +141,19 @@ bool TTDataLoader::End()
...
@@ -141,35 +141,19 @@ bool TTDataLoader::End()
return
true
;
return
true
;
}
}
/* get a batch of samples */
/*
bool
TTDataLoader
::
GetBatch
(
XList
*
args
)
get a batch of samples
>> inputs - inputs of the model
>> golds - gold standards
*/
bool
TTDataLoader
::
GetBatchSimple
(
XList
*
inputs
,
XList
*
golds
)
{
{
CheckNTErrors
(
file
!=
NULL
,
"No input file specificed!"
);
CheckNTErrors
(
file
!=
NULL
,
"No input file specificed!"
);
CheckNTErrors
(
inputs
!=
NULL
&&
inputs
->
count
>=
1
,
"Wrong argument!"
);
CheckNTErrors
(
golds
!=
NULL
&&
golds
->
count
>=
1
,
"Wrong argument!"
);
XTensor
*
input
=
NULL
;
XTensor
*
input
=
(
XTensor
*
)
inputs
->
GetItem
(
0
);
XTensor
*
gold
=
NULL
;
XTensor
*
gold
=
(
XTensor
*
)
golds
->
GetItem
(
0
);
XTensor
*
output
=
NULL
;
if
(
args
->
count
==
0
)
{
input
=
new
XTensor
();
args
->
Add
(
input
);
}
else
input
=
(
XTensor
*
)
args
->
GetItem
(
0
);
if
(
args
->
count
==
1
)
{
output
=
new
XTensor
();
args
->
Add
(
output
);
}
if
(
args
->
count
==
2
)
{
gold
=
new
XTensor
();
args
->
Add
(
gold
);
}
else
gold
=
(
XTensor
*
)
args
->
GetItem
(
1
);
int
count
=
0
;
int
count
=
0
;
int
sampleSize
=
MAX_SAMPLE_SIZE
;
int
sampleSize
=
MAX_SAMPLE_SIZE
;
...
@@ -249,9 +233,16 @@ void TTModel::Forward(int devID, XTensor * input, XTensor * output)
...
@@ -249,9 +233,16 @@ void TTModel::Forward(int devID, XTensor * input, XTensor * output)
XTensor
embeddingCat
;
XTensor
embeddingCat
;
XTensor
hidden
;
XTensor
hidden
;
/* [e_0, e_1, e_2] = w_e * input(one-hot) */
embedding
=
Gather
(
embeddingW
,
*
input
);
embedding
=
Gather
(
embeddingW
,
*
input
);
/* e = merge(e_0, e_1, e_2) */
embeddingCat
=
Merge
(
embedding
,
0
,
1
);
embeddingCat
=
Merge
(
embedding
,
0
,
1
);
/* h = e * w_h */
hidden
=
MMul
(
embeddingCat
,
hiddenW
);
hidden
=
MMul
(
embeddingCat
,
hiddenW
);
/* output = Softmax(h) */
*
output
=
Softmax
(
hidden
,
0
);
*
output
=
Softmax
(
hidden
,
0
);
}
}
...
@@ -271,14 +262,21 @@ XModel * TTModel::Clone(int devID)
...
@@ -271,14 +262,21 @@ XModel * TTModel::Clone(int devID)
return
model
;
return
model
;
}
}
/* run the neural network */
/*
bool
TTModel
::
RunMe
(
XList
*
args
)
run the neural network
>> inputs - inputs of the model
>> outputs - outputs of the model
>> golds - gold standards
*/
bool
TTModel
::
RunSimple
(
XList
*
inputs
,
XList
*
outputs
,
XList
*
golds
)
{
{
CheckNTErrors
(
args
!=
NULL
&&
args
->
count
>=
3
,
"Illegal input arguments!"
);
CheckNTErrors
(
inputs
!=
NULL
&&
inputs
->
count
>=
1
,
"Wrong arguments!"
);
CheckNTErrors
(
outputs
!=
NULL
&&
outputs
->
count
>=
1
,
"Wrong arguments!"
);
CheckNTErrors
(
golds
!=
NULL
&&
golds
->
count
>=
1
,
"Wrong arguments!"
);
XTensor
*
input
=
(
XTensor
*
)
arg
s
->
GetItem
(
0
);
XTensor
*
input
=
(
XTensor
*
)
input
s
->
GetItem
(
0
);
XTensor
*
output
=
(
XTensor
*
)
args
->
GetItem
(
1
);
XTensor
*
output
=
(
XTensor
*
)
outputs
->
GetItem
(
0
);
XTensor
*
gold
=
(
XTensor
*
)
args
->
GetItem
(
2
);
XTensor
*
gold
=
(
XTensor
*
)
golds
->
GetItem
(
0
);
XTensor
loss
;
XTensor
loss
;
XNet
net
;
XNet
net
;
...
...
source/train/TTrain.h
查看文件 @
e9d68683
...
@@ -57,7 +57,7 @@ void GeneateTTrainData(const char * fileName);
...
@@ -57,7 +57,7 @@ void GeneateTTrainData(const char * fileName);
/* run the test */
/* run the test */
extern
extern
void
TestTrain
(
int
argc
,
const
char
**
argv
);
void
TestTrain
();
/* data loader */
/* data loader */
class
TTDataLoader
:
public
DataDistributeBase
class
TTDataLoader
:
public
DataDistributeBase
...
@@ -92,7 +92,7 @@ public:
...
@@ -92,7 +92,7 @@ public:
bool
End
();
bool
End
();
/* get a batch of samples */
/* get a batch of samples */
bool
GetBatch
(
XList
*
arg
s
);
bool
GetBatch
Simple
(
XList
*
inputs
,
XList
*
gold
s
);
};
};
/* the model */
/* the model */
...
@@ -134,7 +134,7 @@ public:
...
@@ -134,7 +134,7 @@ public:
XModel
*
Clone
(
int
devID
);
XModel
*
Clone
(
int
devID
);
/* run the neural network */
/* run the neural network */
bool
Run
Me
(
XList
*
arg
s
);
bool
Run
Simple
(
XList
*
inputs
,
XList
*
outputs
,
XList
*
gold
s
);
};
};
/* */
/* */
...
...
source/train/XBaseTemplate.cpp
查看文件 @
e9d68683
...
@@ -60,11 +60,29 @@ bool DataDistributeBase::End()
...
@@ -60,11 +60,29 @@ bool DataDistributeBase::End()
return
true
;
return
true
;
}
}
/*
get a batch of samples
>> inputs - inputs of the model
>> golds - gold standards
*/
bool
DataDistributeBase
::
GetBatchSimple
(
XList
*
inputs
,
XList
*
golds
)
{
return
false
;
}
/* get a batch of samples */
/* get a batch of samples */
bool
DataDistributeBase
::
GetBatch
(
XList
*
args
)
bool
DataDistributeBase
::
GetBatch
(
XList
*
args
)
{
{
ShowNTErrors
(
"DataDistributeBase::GetBatch must be overloaded!"
);
CheckNTErrors
(
args
->
count
>=
2
,
"More input arguments are required!"
);
return
true
;
XList
*
input
=
(
XList
*
)
args
->
GetItem
(
0
);
XList
*
gold
=
(
XList
*
)
args
->
GetItem
(
1
);
if
(
GetBatchSimple
(
input
,
gold
))
return
true
;
ShowNTErrors
(
"You must be overload one of these: DataDistributeBase::GetBatchSimple ... !"
);
return
false
;
}
}
/* get a batch of samples (for multi-threading) */
/* get a batch of samples (for multi-threading) */
...
...
source/train/XBaseTemplate.h
查看文件 @
e9d68683
...
@@ -69,9 +69,13 @@ public:
...
@@ -69,9 +69,13 @@ public:
/* get a batch of samples */
/* get a batch of samples */
virtual
virtual
bool
GetBatchSimple
(
XList
*
inputs
,
XList
*
golds
);
public
:
/* get a batch of samples */
bool
GetBatch
(
XList
*
args
);
bool
GetBatch
(
XList
*
args
);
protected
:
/* get a batch of samples (for multi-threading) */
/* get a batch of samples (for multi-threading) */
bool
GetBatchSafe
(
XList
*
args
);
bool
GetBatchSafe
(
XList
*
args
);
};
};
...
...
source/train/XLeader.cpp
查看文件 @
e9d68683
...
@@ -216,13 +216,13 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
...
@@ -216,13 +216,13 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel
*
jmodel
=
worker
->
GetModel
();
XModel
*
jmodel
=
worker
->
GetModel
();
/* get a batch of samples */
/* get a batch of samples */
bool
fetched
=
dataDistributor
->
GetBatch
(
worker
->
GetInput
());
bool
fetched
=
dataDistributor
->
GetBatch
Simple
(
worker
->
GetInput
(),
worker
->
GetGold
());
/* job in queue 1: refresh the model */
/* job in queue 1: refresh the model */
worker
->
AddJobRefresh
(
jmodel
);
worker
->
AddJobRefresh
(
jmodel
);
/* job in queue 1: run the model */
/* job in queue 1: run the model */
worker
->
AddJobNeuralNet
(
jmodel
,
worker
->
GetInput
(),
worker
->
GetOutput
());
worker
->
AddJobNeuralNet
(
jmodel
,
worker
->
GetInput
(),
worker
->
GetOutput
()
,
worker
->
GetGold
()
);
/* clear it */
/* clear it */
worker
->
Clear
();
worker
->
Clear
();
...
...
source/train/XModel.cpp
查看文件 @
e9d68683
...
@@ -67,12 +67,31 @@ XModel * XModel::Clone(int devID)
...
@@ -67,12 +67,31 @@ XModel * XModel::Clone(int devID)
/*
/*
run the neural network
run the neural network
>> inputs - inputs of the model
>> outputs - outputs of the model
*/
bool
XModel
::
RunSimple
(
XList
*
inputs
,
XList
*
outputs
,
XList
*
golds
)
{
return
false
;
}
/*
run the neural network
>> args - the arguments
>> args - the arguments
*/
*/
bool
XModel
::
RunMe
(
XList
*
args
)
bool
XModel
::
RunMe
(
XList
*
args
)
{
{
ShowNTErrors
(
"NetBase::Run must be overloaded!"
);
CheckNTErrors
(
args
->
count
>=
3
,
"More arguments are required!"
);
return
true
;
XList
*
inputs
=
(
XList
*
)
args
->
GetItem
(
0
);
XList
*
outputs
=
(
XList
*
)
args
->
GetItem
(
1
);
XList
*
golds
=
(
XList
*
)
args
->
GetItem
(
2
);
if
(
RunSimple
(
inputs
,
outputs
,
golds
))
return
true
;
ShowNTErrors
(
"You must be overload one of these: XModel::RunSimple ... !"
);
return
false
;
}
}
/* refresh the model */
/* refresh the model */
...
@@ -103,8 +122,12 @@ bool XModel::Run(XList * args)
...
@@ -103,8 +122,12 @@ bool XModel::Run(XList * args)
{
{
CheckNTErrors
(
args
!=
NULL
||
args
->
count
==
0
,
"no arguments for XModel::Refresh"
);
CheckNTErrors
(
args
!=
NULL
||
args
->
count
==
0
,
"no arguments for XModel::Refresh"
);
XModel
*
model
=
(
XModel
*
)
args
->
GetItem
(
0
);
XModel
*
model
=
(
XModel
*
)
args
->
GetItem
(
0
);
XList
newArgs
;
for
(
int
i
=
1
;
i
<
args
->
count
;
i
++
)
newArgs
.
Add
(
args
->
GetItem
(
i
));
return
model
->
Run
(
a
rgs
);
return
model
->
Run
Me
(
&
newA
rgs
);
}
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
}
/* end of the nts (NiuTrans.Tensor) namespace */
source/train/XModel.h
查看文件 @
e9d68683
...
@@ -80,8 +80,12 @@ public:
...
@@ -80,8 +80,12 @@ public:
virtual
virtual
XModel
*
Clone
(
int
devID
);
XModel
*
Clone
(
int
devID
);
/* run the neural network
(would be overloaded)
*/
/* run the neural network */
virtual
virtual
bool
RunSimple
(
XList
*
inputs
,
XList
*
outputs
,
XList
*
golds
);
protected
:
/* run the neural network */
bool
RunMe
(
XList
*
args
);
bool
RunMe
(
XList
*
args
);
public
:
public
:
...
...
source/train/XOptimizer.cpp
查看文件 @
e9d68683
...
@@ -84,7 +84,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
...
@@ -84,7 +84,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
{
{
/* the delta rule
/* the delta rule
\theta_new = \theta_old - \grad * \lrate */
\theta_new = \theta_old - \grad * \lrate */
Sum
(
param
,
grad
,
param
,
-
lrate
);
Sum
(
*
param
,
*
grad
,
*
param
,
-
lrate
);
}
}
}
}
source/train/XWorkerJob.cpp
查看文件 @
e9d68683
...
@@ -33,7 +33,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -33,7 +33,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
/* constructor */
XWorkerJob
::
XWorkerJob
()
XWorkerJob
::
XWorkerJob
()
{
{
Clear
();
}
}
/* de-constructor */
/* de-constructor */
...
@@ -44,6 +44,9 @@ XWorkerJob::~XWorkerJob()
...
@@ -44,6 +44,9 @@ XWorkerJob::~XWorkerJob()
for
(
int
i
=
0
;
i
<
outputs
.
count
;
i
++
)
for
(
int
i
=
0
;
i
<
outputs
.
count
;
i
++
)
delete
(
XTensor
*
)
outputs
[
i
];
delete
(
XTensor
*
)
outputs
[
i
];
for
(
int
i
=
0
;
i
<
golds
.
count
;
i
++
)
delete
(
XTensor
*
)
golds
[
i
];
}
}
/* set the model */
/* set the model */
...
@@ -64,10 +67,17 @@ void XWorkerJob::Clear()
...
@@ -64,10 +67,17 @@ void XWorkerJob::Clear()
for
(
int
i
=
0
;
i
<
inputs
.
count
;
i
++
)
for
(
int
i
=
0
;
i
<
inputs
.
count
;
i
++
)
delete
(
XTensor
*
)
inputs
[
i
];
delete
(
XTensor
*
)
inputs
[
i
];
inputs
.
Clear
();
inputs
.
Clear
();
inputs
.
Add
(
new
XTensor
());
for
(
int
i
=
0
;
i
<
outputs
.
count
;
i
++
)
for
(
int
i
=
0
;
i
<
outputs
.
count
;
i
++
)
delete
(
XTensor
*
)
outputs
[
i
];
delete
(
XTensor
*
)
outputs
[
i
];
outputs
.
Clear
();
outputs
.
Clear
();
outputs
.
Add
(
new
XTensor
());
for
(
int
i
=
0
;
i
<
golds
.
count
;
i
++
)
delete
(
XTensor
*
)
golds
[
i
];
golds
.
Clear
();
golds
.
Add
(
new
XTensor
());
}
}
/* get the input list */
/* get the input list */
...
@@ -82,6 +92,12 @@ XList * XWorkerJob::GetOutput()
...
@@ -82,6 +92,12 @@ XList * XWorkerJob::GetOutput()
return
&
outputs
;
return
&
outputs
;
}
}
/* get the gold standard */
XList
*
XWorkerJob
::
GetGold
()
{
return
&
golds
;
}
/*
/*
add a new job of model refreshment
add a new job of model refreshment
>> myModel - the model
>> myModel - the model
...
@@ -104,9 +120,10 @@ add a new job of neural network forward and backward computation (with the input
...
@@ -104,9 +120,10 @@ add a new job of neural network forward and backward computation (with the input
>> myModel - the model
>> myModel - the model
>> inputs - inputs of the neural network
>> inputs - inputs of the neural network
>> outputs - outputs of the neural network
>> outputs - outputs of the neural network
>> golds - gold standards
<< return - succeeded or not
<< return - succeeded or not
*/
*/
bool
XWorkerJob
::
AddJobNeuralNet
(
XModel
*
myModel
,
XList
*
inputs
,
XList
*
outputs
)
bool
XWorkerJob
::
AddJobNeuralNet
(
XModel
*
myModel
,
XList
*
inputs
,
XList
*
outputs
,
XList
*
golds
)
{
{
CheckNTErrors
(
myModel
!=
NULL
,
"no input neural network!"
);
CheckNTErrors
(
myModel
!=
NULL
,
"no input neural network!"
);
CheckNTErrors
(
inputs
!=
NULL
,
"no inputs of the model!"
);
CheckNTErrors
(
inputs
!=
NULL
,
"no inputs of the model!"
);
...
@@ -114,8 +131,9 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu
...
@@ -114,8 +131,9 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu
XList
args
;
XList
args
;
args
.
Add
(
myModel
);
args
.
Add
(
myModel
);
args
.
AddList
(
inputs
);
args
.
Add
(
inputs
);
args
.
AddList
(
outputs
);
args
.
Add
(
outputs
);
args
.
Add
(
golds
);
queue
.
EnqueueJob
((
void
*
)(
char
*
)
XModel
::
Run
,
&
args
);
queue
.
EnqueueJob
((
void
*
)(
char
*
)
XModel
::
Run
,
&
args
);
...
...
source/train/XWorkerJob.h
查看文件 @
e9d68683
...
@@ -50,7 +50,7 @@ protected:
...
@@ -50,7 +50,7 @@ protected:
XList
outputs
;
XList
outputs
;
/* the gold standard */
/* the gold standard */
XList
gold
;
XList
gold
s
;
public
:
public
:
...
@@ -82,7 +82,7 @@ public:
...
@@ -82,7 +82,7 @@ public:
bool
AddJobRefresh
(
XModel
*
myModel
);
bool
AddJobRefresh
(
XModel
*
myModel
);
/* add a new job of neural network forward and backward computation (with the input) */
/* add a new job of neural network forward and backward computation (with the input) */
bool
AddJobNeuralNet
(
XModel
*
myModel
,
XList
*
inputs
,
XList
*
outputs
);
bool
AddJobNeuralNet
(
XModel
*
myModel
,
XList
*
inputs
,
XList
*
outputs
,
XList
*
golds
);
};
};
}
}
...
...
source/train/XWorkerUpdate.cpp
查看文件 @
e9d68683
...
@@ -101,7 +101,7 @@ wrapper of UpdateModel
...
@@ -101,7 +101,7 @@ wrapper of UpdateModel
*/
*/
void
XWorkerUpdate
::
Update
(
XList
*
args
)
void
XWorkerUpdate
::
Update
(
XList
*
args
)
{
{
CheckNTErrors
(
args
!=
NULL
&&
args
->
count
>
3
,
"Illegal argument list!"
);
CheckNTErrors
(
args
!=
NULL
&&
args
->
count
>
=
3
,
"Illegal argument list!"
);
XWorkerUpdate
*
updater
=
(
XWorkerUpdate
*
)
args
->
GetItem
(
0
);
XWorkerUpdate
*
updater
=
(
XWorkerUpdate
*
)
args
->
GetItem
(
0
);
XModel
*
model
=
(
XModel
*
)
args
->
GetItem
(
1
);
XModel
*
model
=
(
XModel
*
)
args
->
GetItem
(
1
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论