Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
98a9130d
Commit
98a9130d
authored
Feb 28, 2021
by
hello
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor class `TrainDataSet`
parent
4bbd6a27
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
376 行增加
和
268 行删除
+376
-268
source/sample/transformer/train/TrainDataSet.cpp
+196
-190
source/sample/transformer/train/TrainDataSet.h
+63
-38
source/sample/transformer/train/Trainer.cpp
+113
-39
source/sample/transformer/train/Trainer.h
+4
-1
没有找到文件。
source/sample/transformer/train/TrainDataSet.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -16,13 +16,10 @@
...
@@ -16,13 +16,10 @@
/*
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
* $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
*
TODO: refactor the data loader class and references
*
$Updated by: CAO Hang and Wu Siming 2020-12-13
*/
*/
#include <string>
#include <vector>
#include <cstdlib>
#include <cstdlib>
#include <fstream>
#include <algorithm>
#include <algorithm>
#include "TrainDataSet.h"
#include "TrainDataSet.h"
...
@@ -33,37 +30,56 @@ using namespace nmt;
...
@@ -33,37 +30,56 @@ using namespace nmt;
namespace
nts
{
namespace
nts
{
/* sort the dataset by length (in descending order) */
/* get the maximum source sentence length in a range */
void
TrainDataSet
::
SortByLength
()
{
int
TrainDataSet
::
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
sort
(
buffer
.
items
,
buffer
.
items
+
buffer
.
count
,
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
int
maxLen
=
0
;
return
(
a
->
srcSent
.
Size
()
+
a
->
tgtSent
.
Size
())
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
>
(
b
->
srcSent
.
Size
()
+
b
->
tgtSent
.
Size
());
IntList
*
srcSent
=
((
TrainExample
*
)
buf
->
Get
(
i
))
->
srcSent
;
});
maxLen
=
MAX
(
int
(
srcSent
->
Size
()),
maxLen
);
}
return
maxLen
;
}
}
/* sort buckets by key (in descending order) */
/* get the maximum target sentence length in a range */
void
TrainDataSet
::
SortBucket
()
{
int
TrainDataSet
::
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
sort
(
buffer
.
items
,
buffer
.
items
+
buffer
.
count
,
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
int
maxLen
=
0
;
return
a
->
bucketKey
>
b
->
bucketKey
;
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
});
IntList
*
tgtSent
=
((
TrainExample
*
)
buf
->
Get
(
i
))
->
tgtSent
;
maxLen
=
MAX
(
int
(
tgtSent
->
Size
()),
maxLen
);
}
return
maxLen
;
}
}
/*
/* sort the buffer by source sentence length (in descending order) */
sort the output by key in a range (in descending order)
void
TrainDataSet
::
SortBySrcLength
(
XList
*
buf
)
{
>> begin - the first index of the range
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
>> end - the last index of the range
[](
void
*
a
,
void
*
b
)
{
*/
return
((
TrainExample
*
)(
a
))
->
srcSent
->
Size
()
<
void
TrainDataSet
::
SortInBucket
(
int
begin
,
int
end
)
{
((
TrainExample
*
)(
b
))
->
srcSent
->
Size
();
sort
(
buffer
.
items
+
begin
,
buffer
.
items
+
end
,
});
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
}
return
(
a
->
key
>
b
->
key
);
/* sort the buffer by target sentence length (in descending order) */
});
void
TrainDataSet
::
SortByTgtLength
(
XList
*
buf
)
{
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
tgtSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
tgtSent
->
Size
();
});
}
/* sort buckets by key (in descending order) */
void
TrainDataSet
::
SortBuckets
(
XList
*
buf
)
{
sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
bucketKey
<
((
TrainExample
*
)(
b
))
->
bucketKey
;
});
}
}
/*
/*
load
all data from a file
to the buffer
load
samples from a file in
to the buffer
training data format (binary):
training data format (binary):
first 8 bit: number of sentence pairs
first 8 bit: number of sentence pairs
subsequent segements:
subsequent segements:
...
@@ -71,52 +87,63 @@ source sentence length (4 bit)
...
@@ -71,52 +87,63 @@ source sentence length (4 bit)
target sentence length (4 bit)
target sentence length (4 bit)
source tokens (4 bit per token)
source tokens (4 bit per token)
target tokens (4 bit per token)
target tokens (4 bit per token)
>> buf - the buffer (list) of samples
*/
*/
void
TrainDataSet
::
LoadDataToBuffer
(
)
bool
TrainDataSet
::
LoadBatchToBuf
(
XList
*
buf
)
{
{
buffer
.
Clear
();
ClearSamples
(
buf
);
curIdx
=
0
;
int
id
=
0
;
uint64_t
sentNum
=
0
;
int
srcVocabSize
=
0
;
int
sampleNum
=
0
;
int
tgtVocabSize
=
0
;
fread
(
&
srcVocabSize
,
sizeof
(
srcVocabSize
),
1
,
fp
);
fread
(
&
tgtVocabSize
,
sizeof
(
tgtVocabSize
),
1
,
fp
);
fread
(
&
sentNum
,
sizeof
(
uint64_t
),
1
,
fp
);
while
((
sampleNum
<
bufferSize
))
{
CheckNTErrors
(
sentNum
>
0
,
"Invalid sentence pairs number"
);
while
(
id
<
sentNum
)
{
int
srcLen
=
0
;
int
srcLen
=
0
;
int
tgtLen
=
0
;
int
tgtLen
=
0
;
fread
(
&
srcLen
,
sizeof
(
int
),
1
,
fp
);
size_t
n
=
fread
(
&
srcLen
,
sizeof
(
int
),
1
,
fp
);
if
(
n
==
0
)
break
;
fread
(
&
tgtLen
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
tgtLen
,
sizeof
(
int
),
1
,
fp
);
CheckNTErrors
(
srcLen
>
0
,
"Invalid source sentence length"
);
CheckNTErrors
(
srcLen
>
0
,
"Invalid source sentence length"
);
CheckNTErrors
(
tgtLen
>
0
,
"Invalid target sentence length"
);
CheckNTErrors
(
tgtLen
>
0
,
"Invalid target sentence length"
);
IntList
srcSent
;
IntList
*
srcSent
=
new
IntList
(
srcLen
);
IntList
tgtSent
;
IntList
*
tgtSent
=
new
IntList
(
tgtLen
);
srcSent
.
ReadFromFile
(
fp
,
srcLen
);
srcSent
->
ReadFromFile
(
fp
,
srcLen
);
tgtSent
.
ReadFromFile
(
fp
,
tgtLen
);
tgtSent
->
ReadFromFile
(
fp
,
tgtLen
);
TrainExample
*
example
=
new
TrainExample
(
sampleNum
++
,
0
,
srcSent
,
tgtSent
);
buf
->
Add
(
example
);
}
/* reset the file pointer to the begin */
if
(
feof
(
fp
)
&&
isTraining
)
{
TrainExample
*
example
=
new
TrainExample
;
rewind
(
fp
);
example
->
id
=
id
++
;
example
->
key
=
id
;
example
->
srcSent
=
srcSent
;
example
->
tgtSent
=
tgtSent
;
buffer
.
Add
(
example
);
int
srcVocabSize
=
0
;
int
tgtVocabSize
=
0
;
fread
(
&
srcVocabSize
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
tgtVocabSize
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
totalSampleNum
,
sizeof
(
totalSampleNum
),
1
,
fp
);
}
}
fclose
(
fp
);
/* group samples in the buffer into buckets */
SortByTgtLength
(
buf
);
SortBySrcLength
(
buf
);
if
(
isTraining
)
BuildBucket
(
buf
);
XPRINT1
(
0
,
stderr
,
"[INFO] loaded %d sentences
\n
"
,
id
)
;
return
true
;
}
}
/*
/*
load a mini-batch to the device (for training)
load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the index of the buffer
>> batchEnc - a tensor to store the batch of encoder input
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
>> batchDec - a tensor to store the batch of decoder input
...
@@ -125,57 +152,34 @@ load a mini-batch to the device (for training)
...
@@ -125,57 +152,34 @@ load a mini-batch to the device (for training)
>> minSentBatch - the minimum number of sentence batch
>> minSentBatch - the minimum number of sentence batch
>> batchSize - the maxium number of words in a batch
>> batchSize - the maxium number of words in a batch
>> devID - the device id, -1 for the CPU
>> devID - the device id, -1 for the CPU
<< return - number of target tokens and sentences
>> wc - number of target words in a batch
>> sc - number of samples in a batch
*/
*/
UInt64List
TrainDataSet
::
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
bool
TrainDataSet
::
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
)
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
)
{
{
UInt64List
info
;
int
srcTokenNum
=
0
;
size_t
srcTokenNum
=
0
;
int
tgtTokenNum
=
0
;
size_t
tgtTokenNum
=
0
;
int
realBatchSize
=
0
;
size_t
realBatchSize
=
1
;
/* dynamic batching for sentences */
if
(
!
isTraining
)
int
bucketKey
=
((
TrainExample
*
)(
buf
->
Get
(
curIdx
)))
->
bucketKey
;
realBatchSize
=
minSentBatch
;
while
((
realBatchSize
<
(
int
(
buf
->
Size
())
-
curIdx
))
&&
(((
TrainExample
*
)(
buf
->
Get
(
curIdx
+
realBatchSize
)))
->
bucketKey
==
bucketKey
))
{
/* get the maximum source sentence length in a mini-batch */
realBatchSize
++
;
size_t
maxSrcLen
=
buffer
[(
int
)
curIdx
]
->
srcSent
.
Size
();
/* max batch size */
const
int
MAX_BATCH_SIZE
=
512
;
/* dynamic batching for sentences, enabled when the dataset is used for training */
if
(
isTraining
)
{
while
((
realBatchSize
<
(
buffer
.
Size
()
-
curIdx
))
&&
(
realBatchSize
*
maxSrcLen
<
batchSize
)
&&
(
realBatchSize
<
MAX_BATCH_SIZE
)
&&
(
realBatchSize
*
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
()
<
batchSize
))
{
if
(
maxSrcLen
<
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
();
realBatchSize
++
;
}
}
/* real batch size */
if
((
buffer
.
Size
()
-
curIdx
)
<
realBatchSize
)
{
realBatchSize
=
buffer
.
Size
()
-
curIdx
;
}
}
realBatchSize
=
MIN
(
realBatchSize
,
(
int
(
buf
->
Size
())
-
curIdx
));
CheckNTErrors
(
realBatchSize
>
0
,
"Invalid batch size"
);
CheckNTErrors
(
realBatchSize
>
0
,
"Invalid batch size"
);
/* get the maximum target sentence length in a mini-batch */
/* get the maximum target sentence length in a mini-batch */
size_t
maxTgtLen
=
buffer
[(
int
)
curIdx
]
->
tgtSent
.
Size
();
int
maxSrcLen
=
MaxSrcLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
for
(
size_t
i
=
0
;
i
<
realBatchSize
;
i
++
)
{
int
maxTgtLen
=
MaxTgtLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
if
(
maxTgtLen
<
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
())
maxTgtLen
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
}
for
(
size_t
i
=
0
;
i
<
realBatchSize
;
i
++
)
{
if
(
maxSrcLen
<
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
}
CheckNTErrors
(
maxSrcLen
!=
0
,
"Invalid source length for batching"
);
CheckNTErrors
(
maxSrcLen
>
0
,
"Invalid source length for batching"
);
CheckNTErrors
(
maxTgtLen
>
0
,
"Invalid target length for batching"
);
int
*
batchEncValues
=
new
int
[
realBatchSize
*
maxSrcLen
];
int
*
batchEncValues
=
new
int
[
realBatchSize
*
maxSrcLen
];
float
*
paddingEncValues
=
new
float
[
realBatchSize
*
maxSrcLen
];
float
*
paddingEncValues
=
new
float
[
realBatchSize
*
maxSrcLen
];
...
@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
...
@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
float
*
paddingDecValues
=
new
float
[
realBatchSize
*
maxTgtLen
];
float
*
paddingDecValues
=
new
float
[
realBatchSize
*
maxTgtLen
];
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxSrcLen
;
i
++
)
{
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxSrcLen
;
i
++
)
{
batchEncValues
[
i
]
=
PAD
;
batchEncValues
[
i
]
=
1
;
paddingEncValues
[
i
]
=
1
;
paddingEncValues
[
i
]
=
1
.0
F
;
}
}
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxTgtLen
;
i
++
)
{
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxTgtLen
;
i
++
)
{
batchDecValues
[
i
]
=
PAD
;
batchDecValues
[
i
]
=
1
;
labelVaues
[
i
]
=
PAD
;
labelVaues
[
i
]
=
1
;
paddingDecValues
[
i
]
=
1.0
F
;
paddingDecValues
[
i
]
=
1.0
F
;
}
}
size_
t
curSrc
=
0
;
in
t
curSrc
=
0
;
size_
t
curTgt
=
0
;
in
t
curTgt
=
0
;
/*
/*
batchEnc: end with EOS (left padding)
batchEnc: end with EOS (left padding)
...
@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
...
@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
*/
*/
for
(
int
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
srcTokenNum
+=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
TrainExample
*
sample
=
(
TrainExample
*
)(
buf
->
Get
(
curIdx
+
i
));
tgtTokenNum
+=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
srcTokenNum
+=
int
(
sample
->
srcSent
->
Size
());
tgtTokenNum
+=
int
(
sample
->
tgtSent
->
Size
());
curSrc
=
maxSrcLen
*
i
;
curSrc
=
maxSrcLen
*
i
;
for
(
int
j
=
0
;
j
<
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
sample
->
srcSent
->
Size
();
j
++
)
{
batchEncValues
[
curSrc
++
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
[
j
]
;
batchEncValues
[
curSrc
++
]
=
sample
->
srcSent
->
Get
(
j
)
;
}
}
curTgt
=
maxTgtLen
*
i
;
curTgt
=
maxTgtLen
*
i
;
for
(
int
j
=
0
;
j
<
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
sample
->
tgtSent
->
Size
();
j
++
)
{
if
(
j
>
0
)
if
(
j
>
0
)
labelVaues
[
curTgt
-
1
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
[
j
]
;
labelVaues
[
curTgt
-
1
]
=
sample
->
tgtSent
->
Get
(
j
)
;
batchDecValues
[
curTgt
++
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
[
j
]
;
batchDecValues
[
curTgt
++
]
=
sample
->
tgtSent
->
Get
(
j
)
;
}
}
labelVaues
[
curTgt
-
1
]
=
EOS
;
labelVaues
[
curTgt
-
1
]
=
2
;
while
(
curSrc
<
maxSrcLen
*
(
i
+
1
))
while
(
curSrc
<
maxSrcLen
*
(
i
+
1
))
paddingEncValues
[
curSrc
++
]
=
0
;
paddingEncValues
[
curSrc
++
]
=
0
;
while
(
curTgt
<
maxTgtLen
*
(
i
+
1
))
while
(
curTgt
<
maxTgtLen
*
(
i
+
1
))
paddingDecValues
[
curTgt
++
]
=
0
;
paddingDecValues
[
curTgt
++
]
=
0
;
}
}
int
rbs
=
(
int
)
realBatchSize
;
InitTensor2D
(
batchEnc
,
realBatchSize
,
maxSrcLen
,
X_INT
,
devID
);
int
msl
=
(
int
)
maxSrcLen
;
InitTensor2D
(
paddingEnc
,
realBatchSize
,
maxSrcLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
batchEnc
,
rbs
,
msl
,
X_INT
,
devID
);
InitTensor2D
(
batchDec
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingEnc
,
rbs
,
msl
,
X_FLOAT
,
devID
);
InitTensor2D
(
paddingDec
,
realBatchSize
,
maxTgtLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
batchDec
,
rbs
,
msl
,
X_INT
,
devID
);
InitTensor2D
(
label
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingDec
,
rbs
,
msl
,
X_FLOAT
,
devID
);
InitTensor2D
(
label
,
rbs
,
msl
,
X_INT
,
devID
);
curIdx
+=
realBatchSize
;
curIdx
+=
realBatchSize
;
...
@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
...
@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
delete
[]
paddingDecValues
;
delete
[]
paddingDecValues
;
delete
[]
labelVaues
;
delete
[]
labelVaues
;
info
.
Add
(
tgtTokenNum
);
wc
=
tgtTokenNum
;
info
.
Add
(
realBatchSize
);
sc
=
realBatchSize
;
return
info
;
return
true
;
}
/*
clear the buffer
>> buf - the buffer (list) of samples
*/
void
TrainDataSet
::
ClearSamples
(
XList
*
buf
)
{
for
(
int
i
=
0
;
i
<
buf
->
count
;
i
++
)
{
TrainExample
*
sample
=
(
TrainExample
*
)
buf
->
Get
(
i
);
delete
sample
;
}
buf
->
Clear
();
}
}
/*
/*
...
@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
...
@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
{
{
fp
=
fopen
(
dataFile
,
"rb"
);
fp
=
fopen
(
dataFile
,
"rb"
);
CheckNTErrors
(
fp
,
"can not open the training file"
);
CheckNTErrors
(
fp
,
"can not open the training file"
);
curIdx
=
0
;
bucketSize
=
myBucketSize
;
isTraining
=
training
;
LoadDataToBuffer
();
SortByLength
();
if
(
isTraining
)
BuildBucket
();
}
/* check if the buffer is empty */
bool
TrainDataSet
::
IsEmpty
()
{
if
(
curIdx
<
buffer
.
Size
())
return
false
;
return
true
;
}
/* reset the buffer */
int
srcVocabSize
=
0
;
void
TrainDataSet
::
ClearBuf
()
int
tgtVocabSize
=
0
;
{
fread
(
&
srcVocabSize
,
sizeof
(
int
),
1
,
fp
);
curIdx
=
0
;
fread
(
&
tgtVocabSize
,
sizeof
(
int
),
1
,
fp
);
CheckNTErrors
(
srcVocabSize
>
0
,
"Invalid source vocabulary size"
);
CheckNTErrors
(
tgtVocabSize
>
0
,
"Invalid target vocabulary size"
);
/* make different batches in different epochs */
fread
(
&
totalSampleNum
,
sizeof
(
totalSampleNum
),
1
,
fp
);
SortByLength
(
);
CheckNTErrors
(
totalSampleNum
>
0
,
"Invalid sentence pairs number"
);
if
(
isTraining
)
bucketSize
=
myBucketSize
;
BuildBucket
()
;
isTraining
=
training
;
}
}
/* group data
into buckets with similar length
*/
/* group data
with similar length into buckets
*/
void
TrainDataSet
::
BuildBucket
()
void
TrainDataSet
::
BuildBucket
(
XList
*
buf
)
{
{
size_
t
idx
=
0
;
in
t
idx
=
0
;
/* build
and shuffle bucket
s */
/* build
buckets by the length of source and target sentence
s */
while
(
idx
<
buffer
.
Size
(
))
{
while
(
idx
<
int
(
buf
->
Size
()
))
{
/* sentence number in a bucket */
/* sentence number in a bucket */
size_
t
sentNum
=
1
;
in
t
sentNum
=
1
;
/* get the maximum source sentence length in a bucket */
/* get the maximum source sentence length in a bucket */
size_t
maxSrcLen
=
buffer
[(
int
)
idx
]
->
srcSent
.
Size
();
int
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
int
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
/* bucketing for sentences */
int
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
while
((
sentNum
<
(
buffer
.
Size
()
-
idx
))
&&
(
sentNum
*
maxSrcLen
<
bucketSize
)
/* the maximum sentence number in a bucket */
&&
(
sentNum
*
buffer
[(
int
)(
curIdx
+
sentNum
)]
->
srcSent
.
Size
()
<
bucketSize
))
{
const
int
MAX_SENT_NUM
=
5120
;
if
(
maxSrcLen
<
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
();
while
((
sentNum
<
(
buf
->
count
-
idx
))
&&
(
sentNum
<
MAX_SENT_NUM
)
&&
(
sentNum
*
maxLen
<=
bucketSize
))
{
sentNum
++
;
sentNum
++
;
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
}
}
/* make sure the number is valid */
/* make sure the number is valid */
if
((
buffer
.
Size
()
-
idx
)
<
sentNum
)
{
if
((
sentNum
)
*
maxLen
>
bucketSize
||
sentNum
>=
MAX_SENT_NUM
)
{
sentNum
=
buffer
.
Size
()
-
idx
;
sentNum
--
;
sentNum
=
max
(
8
*
(
sentNum
/
8
),
sentNum
%
8
);
}
}
if
((
int
(
buf
->
Size
())
-
idx
)
<
sentNum
)
sentNum
=
int
(
buf
->
Size
())
-
idx
;
/* assign the same key for items in a bucket */
int
randomKey
=
rand
();
int
randomKey
=
rand
();
for
(
int
i
=
0
;
i
<
sentNum
;
i
++
)
{
/* shuffle items in a bucket */
((
TrainExample
*
)(
buf
->
Get
(
idx
+
i
)))
->
bucketKey
=
randomKey
;
for
(
size_t
i
=
0
;
i
<
sentNum
;
i
++
)
{
buffer
[(
int
)(
idx
+
i
)]
->
bucketKey
=
randomKey
;
}
}
idx
+=
sentNum
;
idx
+=
sentNum
;
}
}
SortBucket
();
/* sort buckets by their keys */
/* sort items in a bucket */
SortBuckets
(
buf
);
idx
=
0
;
while
(
idx
<
buffer
.
Size
())
{
size_t
sentNum
=
0
;
int
bucketKey
=
buffer
[(
int
)(
idx
+
sentNum
)]
->
bucketKey
;
while
(
sentNum
<
(
buffer
.
Size
()
-
idx
)
&&
buffer
[(
int
)(
idx
+
sentNum
)]
->
bucketKey
==
bucketKey
)
{
buffer
[(
int
)(
idx
+
sentNum
)]
->
key
=
(
int
)
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
();
sentNum
++
;
}
SortInBucket
((
int
)
idx
,
(
int
)(
idx
+
sentNum
));
idx
+=
sentNum
;
}
}
}
/* de-constructor */
/* de-constructor */
TrainDataSet
::~
TrainDataSet
()
TrainDataSet
::~
TrainDataSet
()
{
{
fclose
(
fp
);
}
/* constructor */
TrainExample
::
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
)
{
id
=
myID
;
bucketKey
=
myKey
;
srcSent
=
s
;
tgtSent
=
t
;
}
/* release the buffer */
/* de-constructor */
for
(
int
i
=
0
;
i
<
buffer
.
Size
();
i
++
)
TrainExample
::~
TrainExample
()
delete
buffer
[
i
];
{
delete
srcSent
;
delete
tgtSent
;
}
}
}
}
\ No newline at end of file
source/sample/transformer/train/TrainDataSet.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -30,7 +30,6 @@
...
@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using
namespace
std
;
using
namespace
std
;
...
@@ -39,39 +38,54 @@ namespace nts {
...
@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
/* a class of sentence pairs for training */
struct
TrainExample
{
struct
TrainExample
{
public
:
/* id of the sentence pair */
/* id of the sentence pair */
int
id
;
int
id
;
/* source language setence (tokenized) */
/* source language setence (tokenized) */
IntList
srcSent
;
IntList
*
srcSent
;
/* target language setence (tokenized) */
/* target language setence (tokenized) */
IntList
tgtSent
;
IntList
*
tgtSent
;
/* the key used to shuffle items in a bucket */
int
key
;
/* the key used to shuffle buckets */
/* the key used to shuffle buckets */
int
bucketKey
;
int
bucketKey
;
public
:
/* constructor */
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
);
/* de-constructor */
~
TrainExample
();
};
struct
ReservedIDs
{
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
};
};
/* A `TrainDataSet` is associated with a file which contains training data. */
/* A `TrainDataSet` is associated with a file which contains training data. */
struct
TrainDataSet
{
struct
TrainDataSet
{
public
:
/* the data buffer */
TrainBufferType
buffer
;
/* a list of empty line number */
public
:
IntList
emptyLines
;
/* the pointer to file stream */
/* the pointer to file stream */
FILE
*
fp
;
FILE
*
fp
;
/*
current index in the buffer
*/
/*
number of training samples
*/
size_t
curIdx
;
size_t
totalSampleNum
;
/*
size of used data in the buffer
*/
/*
buffer size
*/
size_t
buffer
Used
;
size_t
buffer
Size
;
/* size of the bucket used for grouping sentences */
/* size of the bucket used for grouping sentences */
size_t
bucketSize
;
size_t
bucketSize
;
...
@@ -79,34 +93,51 @@ public:
...
@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
/* indicates whether it is used for training */
bool
isTraining
;
bool
isTraining
;
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
/* the maximum length for a source sentence */
int
maxSrcLen
;
/* the maximum length for a target sentence */
int
maxTgtLen
;
public
:
public
:
/* sort the input by length (in descending order) */
/* get the maximum source sentence length in a range */
void
SortByLength
();
static
int
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort buckets by key (in descending order) */
/* get the maximum target sentence length in a range */
void
SortBucket
();
static
int
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort the
output by key
(in descending order) */
/* sort the
input by source sentence length
(in descending order) */
void
Sort
InBucket
(
int
begin
,
int
end
);
void
Sort
BySrcLength
(
XList
*
buf
);
/*
load data from a file to the buffer
*/
/*
sort the input by target sentence length (in descending order)
*/
void
LoadDataToBuffer
(
);
void
SortByTgtLength
(
XList
*
buf
);
/* generate a mini-batch */
/* sort buckets by key (in descending order) */
UInt64List
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
void
SortBuckets
(
XList
*
buf
);
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
);
/* load the samples into the buffer (a list) */
/* load the samples into the buffer (a list) */
bool
LoadBatchToBuf
(
XList
*
buf
);
bool
LoadBatchToBuf
(
XList
*
buf
);
/* load the samples into tensors from the buffer */
/* load the samples into tensors from the buffer */
static
static
bool
LoadBatch
(
XList
*
buf
,
bool
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
);
int
&
wc
,
int
&
sc
);
/* release the samples in a buffer */
/* release the samples in a buffer */
...
@@ -116,14 +147,8 @@ public:
...
@@ -116,14 +147,8 @@ public:
/* initialization function */
/* initialization function */
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
/* check if the buffer is empty */
bool
IsEmpty
();
/* reset the buffer */
void
ClearBuf
();
/* group data into buckets with similar length */
/* group data into buckets with similar length */
void
BuildBucket
();
void
BuildBucket
(
XList
*
buf
);
/* de-constructor */
/* de-constructor */
~
TrainDataSet
();
~
TrainDataSet
();
...
...
source/sample/transformer/train/Trainer.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -39,6 +39,41 @@ namespace nmt
...
@@ -39,6 +39,41 @@ namespace nmt
Trainer
::
Trainer
()
Trainer
::
Trainer
()
{
{
cfg
=
NULL
;
cfg
=
NULL
;
lrate
=
0.0
F
;
lrbias
=
0.0
F
;
sBatchSize
=
0
;
wBatchSize
=
0
;
bucketSize
=
0
;
nstep
=
0
;
nepoch
=
0
;
logInterval
=
0
;
maxCheckpoint
=
0
;
d
=
0
;
nwarmup
=
0
;
vSize
=
0
;
vSizeTgt
=
0
;
useAdam
=
false
;
adamBeta1
=
0.0
F
;
adamBeta2
=
0.0
F
;
adamDelta
=
0.0
F
;
isShuffled
=
false
;
labelSmoothingP
=
0.0
F
;
nStepCheckpoint
=
0
;
useEpochCheckpoint
=
false
;
updateStep
=
0
;
isLenSorted
=
0
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
0
;
batchLoader
.
endID
=
0
;
batchLoader
.
unkID
=
0
;
batchLoader
.
padID
=
0
;
batchLoader
.
maxSrcLen
=
0
;
batchLoader
.
maxTgtLen
=
0
;
batchLoader
.
bufferSize
=
0
;
}
}
/* de-constructor */
/* de-constructor */
...
@@ -62,13 +97,15 @@ initialization
...
@@ -62,13 +97,15 @@ initialization
void
Trainer
::
Init
(
Config
&
config
)
void
Trainer
::
Init
(
Config
&
config
)
{
{
cfg
=
&
config
;
cfg
=
&
config
;
lrate
=
config
.
lrate
;
lrate
=
config
.
lrate
;
lrbias
=
config
.
lrbias
;
lrbias
=
config
.
lrbias
;
sBatchSize
=
config
.
sBatchSize
;
sBatchSize
=
config
.
sBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
bucketSize
=
config
.
bucketSize
;
bucketSize
=
config
.
bucketSize
;
nepoch
=
config
.
nepoch
;
nstep
=
config
.
nstep
;
nstep
=
config
.
nstep
;
nepoch
=
config
.
nepoch
;
logInterval
=
config
.
logInterval
;
maxCheckpoint
=
config
.
maxCheckpoint
;
maxCheckpoint
=
config
.
maxCheckpoint
;
d
=
config
.
modelSize
;
d
=
config
.
modelSize
;
nwarmup
=
config
.
nwarmup
;
nwarmup
=
config
.
nwarmup
;
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T
=
1.0
F
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
config
.
startID
;
batchLoader
.
endID
=
config
.
endID
;
batchLoader
.
unkID
=
config
.
unkID
;
batchLoader
.
padID
=
config
.
padID
;
batchLoader
.
maxSrcLen
=
config
.
maxSrcLen
;
batchLoader
.
maxTgtLen
=
config
.
maxTgtLen
;
batchLoader
.
bufferSize
=
config
.
bufSize
;
}
}
/*
/*
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
int
step
=
0
;
int
step
=
0
;
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
wordCountTotal
=
0
;
int
wordCountTotal
=
0
;
int
batchCountTotal
=
0
;
int
batchCountTotal
=
0
;
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double
startT
=
GetClockSec
();
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount
=
0
;
wordCount
=
0
;
loss
=
0
;
loss
=
0
;
/* reset the batch loader */
while
(
step
++
<
nstep
)
batchLoader
.
ClearBuf
();
while
(
!
batchLoader
.
IsEmpty
())
{
{
XNet
net
;
XNet
net
;
net
.
Clear
();
net
.
Clear
();
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor
paddingEnc
;
XTensor
paddingEnc
;
XTensor
paddingDec
;
XTensor
paddingDec
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
sBatchSize
,
wBatchSize
,
devID
);
curIdx
=
0
;
batchLoader
.
LoadBatchToBuf
(
buf
);
}
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
/* output probabilities */
/* output probabilities */
XTensor
output
;
XTensor
output
;
/* make the network */
/* make the network */
if
(
model
->
isLM
)
if
(
model
->
isLM
)
{
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
else
if
(
model
->
isMT
)
}
else
if
(
model
->
isMT
)
{
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
}
else
{
else
{
ShowNTErrors
(
"Illegal model type!"
);
ShowNTErrors
(
"Illegal model type!"
);
}
}
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE
lossLocal
=
lossBatch
/
wc
;
DTYPE
lossLocal
=
lossBatch
/
wc
;
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
net
.
isGradEfficient
=
true
;
bool
debug
(
false
);
if
(
debug
)
{
LOG
(
"after forward:"
);
batchEnc
.
mem
->
ShowMemUsage
(
stderr
);
exit
(
0
);
}
if
(
doUpdate
)
{
if
(
doUpdate
)
{
/* back-propagation */
net
.
Backward
(
lossTensor
);
net
.
Backward
(
lossTensor
);
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
gradStep
+=
1
;
gradStep
+=
1
;
loss
+=
lossBatch
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
wordCountTotal
+=
wc
;
batchCountTotal
+=
ws
;
batchCountTotal
+=
sc
;
/* update the parameters */
/* update the parameters */
if
(
gradStep
==
updateStep
)
{
if
(
gradStep
==
updateStep
)
{
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
else
nSkipped
++
;
nSkipped
++
;
if
(
++
step
>=
nstep
)
{
if
(
step
%
logInterval
==
0
)
{
isEnd
=
true
;
break
;
}
if
(
step
==
10
)
{
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if
(
step
%
100
==
0
)
{
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
}
}
if
(
isEnd
)
break
;
if
(
useEpochCheckpoint
)
if
(
useEpochCheckpoint
)
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
epoch
=
MIN
(
epoch
,
nepoch
);
epoch
=
MIN
(
epoch
,
nepoch
);
...
@@ -287,8 +340,12 @@ test the model
...
@@ -287,8 +340,12 @@ test the model
*/
*/
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
{
{
double
startT
=
GetClockSec
();
DISABLE_GRAD
;
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
sentCount
=
0
;
int
sentCount
=
0
;
float
loss
=
0
;
float
loss
=
0
;
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
/* data files */
batchLoader
.
Init
(
fn
,
0
,
false
);
batchLoader
.
Init
(
fn
,
0
,
false
);
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
/* set the buffer size to the size of valiadation set */
batchLoader
.
bufferSize
=
batchLoader
.
totalSampleNum
;
batchLoader
.
LoadBatchToBuf
(
buf
);
while
(
!
batchLoader
.
IsEmpty
()
)
while
(
curIdx
<
buf
->
count
)
{
{
/* batch of input sequences */
/* batch of input sequences */
XTensor
batchEnc
;
XTensor
batchEnc
;
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor
labelOnehot
;
XTensor
labelOnehot
;
XTensor
lossTensor
;
XTensor
lossTensor
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
);
sBatchSize
,
0
,
model
->
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
/* make the network */
/* make the network */
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int
length
=
output
.
GetDim
(
1
);
int
length
=
output
.
GetDim
(
1
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
loss
+=
lossBatch
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCount
+=
wc
;
sentCount
+=
bSize
;
sentCount
+=
bSize
;
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
}
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
ENABLE_GRAD
;
LOG
(
"validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
}
}
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
/* v2 = m / (sqrt(v) + delta) */
/* v2 = m / (sqrt(v) + delta) */
XTensor
*
v2
=
NewTensorBuf
(
v
,
v
->
devID
);
XTensor
*
v2
=
NewTensorBuf
V2
(
v
,
v
->
devID
,
v
->
mem
);
_Power
(
v
,
v2
,
0.5
F
);
_Power
(
v
,
v2
,
0.5
F
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_Div
(
m
,
v2
,
v2
);
_Div
(
m
,
v2
,
v2
);
...
...
source/sample/transformer/train/Trainer.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -70,6 +70,9 @@ public:
...
@@ -70,6 +70,9 @@ public:
/* traing step number */
/* traing step number */
int
nstep
;
int
nstep
;
/* interval step for logging */
int
logInterval
;
/* the maximum number of saved checkpoints */
/* the maximum number of saved checkpoints */
int
maxCheckpoint
;
int
maxCheckpoint
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论