Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
98a9130d
Commit
98a9130d
authored
Feb 28, 2021
by
hello
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor class `TrainDataSet`
parent
4bbd6a27
全部展开
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
180 行增加
和
78 行删除
+180
-78
source/sample/transformer/train/TrainDataSet.cpp
+0
-0
source/sample/transformer/train/TrainDataSet.h
+63
-38
source/sample/transformer/train/Trainer.cpp
+113
-39
source/sample/transformer/train/Trainer.h
+4
-1
没有找到文件。
source/sample/transformer/train/TrainDataSet.cpp
查看文件 @
98a9130d
差异被折叠。
点击展开。
source/sample/transformer/train/TrainDataSet.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using
namespace
std
;
...
...
@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
struct
TrainExample
{
public
:
/* id of the sentence pair */
int
id
;
/* source language setence (tokenized) */
IntList
srcSent
;
IntList
*
srcSent
;
/* target language setence (tokenized) */
IntList
tgtSent
;
/* the key used to shuffle items in a bucket */
int
key
;
IntList
*
tgtSent
;
/* the key used to shuffle buckets */
int
bucketKey
;
public
:
/* constructor */
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
);
/* de-constructor */
~
TrainExample
();
};
struct
ReservedIDs
{
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
};
/* A `TrainDataSet` is associated with a file which contains training data. */
struct
TrainDataSet
{
public
:
/* the data buffer */
TrainBufferType
buffer
;
/* a list of empty line number */
IntList
emptyLines
;
public
:
/* the pointer to file stream */
FILE
*
fp
;
/*
current index in the buffer
*/
size_t
curIdx
;
/*
number of training samples
*/
size_t
totalSampleNum
;
/*
size of used data in the buffer
*/
size_t
buffer
Used
;
/*
buffer size
*/
size_t
buffer
Size
;
/* size of the bucket used for grouping sentences */
size_t
bucketSize
;
...
...
@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
bool
isTraining
;
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
/* the maximum length for a source sentence */
int
maxSrcLen
;
/* the maximum length for a target sentence */
int
maxTgtLen
;
public
:
/* sort the input by length (in descending order) */
void
SortByLength
();
/* get the maximum source sentence length in a range */
static
int
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort buckets by key (in descending order) */
void
SortBucket
();
/* get the maximum target sentence length in a range */
static
int
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort the
output by key
(in descending order) */
void
Sort
InBucket
(
int
begin
,
int
end
);
/* sort the
input by source sentence length
(in descending order) */
void
Sort
BySrcLength
(
XList
*
buf
);
/*
load data from a file to the buffer
*/
void
LoadDataToBuffer
(
);
/*
sort the input by target sentence length (in descending order)
*/
void
SortByTgtLength
(
XList
*
buf
);
/* generate a mini-batch */
UInt64List
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
);
/* sort buckets by key (in descending order) */
void
SortBuckets
(
XList
*
buf
);
/* load the samples into the buffer (a list) */
bool
LoadBatchToBuf
(
XList
*
buf
);
/* load the samples into tensors from the buffer */
static
bool
LoadBatch
(
XList
*
buf
,
bool
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
);
/* release the samples in a buffer */
...
...
@@ -116,14 +147,8 @@ public:
/* initialization function */
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
/* check if the buffer is empty */
bool
IsEmpty
();
/* reset the buffer */
void
ClearBuf
();
/* group data into buckets with similar length */
void
BuildBucket
();
void
BuildBucket
(
XList
*
buf
);
/* de-constructor */
~
TrainDataSet
();
...
...
source/sample/transformer/train/Trainer.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -39,6 +39,41 @@ namespace nmt
Trainer
::
Trainer
()
{
cfg
=
NULL
;
lrate
=
0.0
F
;
lrbias
=
0.0
F
;
sBatchSize
=
0
;
wBatchSize
=
0
;
bucketSize
=
0
;
nstep
=
0
;
nepoch
=
0
;
logInterval
=
0
;
maxCheckpoint
=
0
;
d
=
0
;
nwarmup
=
0
;
vSize
=
0
;
vSizeTgt
=
0
;
useAdam
=
false
;
adamBeta1
=
0.0
F
;
adamBeta2
=
0.0
F
;
adamDelta
=
0.0
F
;
isShuffled
=
false
;
labelSmoothingP
=
0.0
F
;
nStepCheckpoint
=
0
;
useEpochCheckpoint
=
false
;
updateStep
=
0
;
isLenSorted
=
0
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
0
;
batchLoader
.
endID
=
0
;
batchLoader
.
unkID
=
0
;
batchLoader
.
padID
=
0
;
batchLoader
.
maxSrcLen
=
0
;
batchLoader
.
maxTgtLen
=
0
;
batchLoader
.
bufferSize
=
0
;
}
/* de-constructor */
...
...
@@ -62,13 +97,15 @@ initialization
void
Trainer
::
Init
(
Config
&
config
)
{
cfg
=
&
config
;
lrate
=
config
.
lrate
;
lrbias
=
config
.
lrbias
;
sBatchSize
=
config
.
sBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
bucketSize
=
config
.
bucketSize
;
nepoch
=
config
.
nepoch
;
nstep
=
config
.
nstep
;
nepoch
=
config
.
nepoch
;
logInterval
=
config
.
logInterval
;
maxCheckpoint
=
config
.
maxCheckpoint
;
d
=
config
.
modelSize
;
nwarmup
=
config
.
nwarmup
;
...
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
config
.
startID
;
batchLoader
.
endID
=
config
.
endID
;
batchLoader
.
unkID
=
config
.
unkID
;
batchLoader
.
padID
=
config
.
padID
;
batchLoader
.
maxSrcLen
=
config
.
maxSrcLen
;
batchLoader
.
maxTgtLen
=
config
.
maxTgtLen
;
batchLoader
.
bufferSize
=
config
.
bufSize
;
}
/*
...
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
int
step
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCountTotal
=
0
;
int
batchCountTotal
=
0
;
...
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
...
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount
=
0
;
loss
=
0
;
/* reset the batch loader */
batchLoader
.
ClearBuf
();
while
(
!
batchLoader
.
IsEmpty
())
while
(
step
++
<
nstep
)
{
XNet
net
;
net
.
Clear
();
...
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor
paddingEnc
;
XTensor
paddingDec
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
);
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
curIdx
=
0
;
batchLoader
.
LoadBatchToBuf
(
buf
);
}
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
/* output probabilities */
XTensor
output
;
/* make the network */
if
(
model
->
isLM
)
if
(
model
->
isLM
)
{
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
else
if
(
model
->
isMT
)
}
else
if
(
model
->
isMT
)
{
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
}
else
{
ShowNTErrors
(
"Illegal model type!"
);
}
...
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE
lossLocal
=
lossBatch
/
wc
;
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
net
.
isGradEfficient
=
true
;
bool
debug
(
false
);
if
(
debug
)
{
LOG
(
"after forward:"
);
batchEnc
.
mem
->
ShowMemUsage
(
stderr
);
exit
(
0
);
}
if
(
doUpdate
)
{
/* back-propagation */
net
.
Backward
(
lossTensor
);
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
gradStep
+=
1
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
batchCountTotal
+=
ws
;
batchCountTotal
+=
sc
;
/* update the parameters */
if
(
gradStep
==
updateStep
)
{
...
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
nSkipped
++
;
if
(
++
step
>=
nstep
)
{
isEnd
=
true
;
break
;
}
if
(
step
==
10
)
{
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if
(
step
%
100
==
0
)
{
if
(
step
%
logInterval
==
0
)
{
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
...
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
if
(
isEnd
)
break
;
if
(
useEpochCheckpoint
)
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
epoch
=
MIN
(
epoch
,
nepoch
);
...
...
@@ -287,8 +340,12 @@ test the model
*/
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
{
double
startT
=
GetClockSec
();
DISABLE_GRAD
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
sentCount
=
0
;
float
loss
=
0
;
...
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
batchLoader
.
Init
(
fn
,
0
,
false
);
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
/* set the buffer size to the size of valiadation set */
batchLoader
.
bufferSize
=
batchLoader
.
totalSampleNum
;
batchLoader
.
LoadBatchToBuf
(
buf
);
while
(
!
batchLoader
.
IsEmpty
()
)
while
(
curIdx
<
buf
->
count
)
{
/* batch of input sequences */
XTensor
batchEnc
;
...
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor
labelOnehot
;
XTensor
lossTensor
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
,
wc
,
sc
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
/* make the network */
...
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int
length
=
output
.
GetDim
(
1
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
loss
+=
lossBatch
;
wordCount
+=
wc
;
sentCount
+=
bSize
;
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
ENABLE_GRAD
;
LOG
(
"validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
}
...
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
/* v2 = m / (sqrt(v) + delta) */
XTensor
*
v2
=
NewTensorBuf
(
v
,
v
->
devID
);
XTensor
*
v2
=
NewTensorBuf
V2
(
v
,
v
->
devID
,
v
->
mem
);
_Power
(
v
,
v2
,
0.5
F
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_Div
(
m
,
v2
,
v2
);
...
...
source/sample/transformer/train/Trainer.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -70,6 +70,9 @@ public:
/* traing step number */
int
nstep
;
/* interval step for logging */
int
logInterval
;
/* the maximum number of saved checkpoints */
int
maxCheckpoint
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论