Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
98a9130d
Commit
98a9130d
authored
Feb 28, 2021
by
hello
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor class `TrainDataSet`
parent
4bbd6a27
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
376 行增加
和
268 行删除
+376
-268
source/sample/transformer/train/TrainDataSet.cpp
+196
-190
source/sample/transformer/train/TrainDataSet.h
+63
-38
source/sample/transformer/train/Trainer.cpp
+113
-39
source/sample/transformer/train/Trainer.h
+4
-1
没有找到文件。
source/sample/transformer/train/TrainDataSet.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -16,13 +16,10 @@
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
*
TODO: refactor the data loader class and references
*
$Updated by: CAO Hang and Wu Siming 2020-12-13
*/
#include <string>
#include <vector>
#include <cstdlib>
#include <fstream>
#include <algorithm>
#include "TrainDataSet.h"
...
...
@@ -33,37 +30,56 @@ using namespace nmt;
namespace
nts
{
/* sort the dataset by length (in descending order) */
void
TrainDataSet
::
SortByLength
()
{
sort
(
buffer
.
items
,
buffer
.
items
+
buffer
.
count
,
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
return
(
a
->
srcSent
.
Size
()
+
a
->
tgtSent
.
Size
())
>
(
b
->
srcSent
.
Size
()
+
b
->
tgtSent
.
Size
());
});
/* get the maximum source sentence length in a range */
int
TrainDataSet
::
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
int
maxLen
=
0
;
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
IntList
*
srcSent
=
((
TrainExample
*
)
buf
->
Get
(
i
))
->
srcSent
;
maxLen
=
MAX
(
int
(
srcSent
->
Size
()),
maxLen
);
}
return
maxLen
;
}
/* sort buckets by key (in descending order) */
void
TrainDataSet
::
SortBucket
()
{
sort
(
buffer
.
items
,
buffer
.
items
+
buffer
.
count
,
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
return
a
->
bucketKey
>
b
->
bucketKey
;
});
/* get the maximum target sentence length in a range */
int
TrainDataSet
::
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
int
maxLen
=
0
;
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
IntList
*
tgtSent
=
((
TrainExample
*
)
buf
->
Get
(
i
))
->
tgtSent
;
maxLen
=
MAX
(
int
(
tgtSent
->
Size
()),
maxLen
);
}
return
maxLen
;
}
/*
sort the output by key in a range (in descending order)
>> begin - the first index of the range
>> end - the last index of the range
*/
void
TrainDataSet
::
SortInBucket
(
int
begin
,
int
end
)
{
sort
(
buffer
.
items
+
begin
,
buffer
.
items
+
end
,
[](
TrainExample
*
a
,
TrainExample
*
b
)
{
return
(
a
->
key
>
b
->
key
);
});
/* sort the buffer by source sentence length (in descending order) */
void
TrainDataSet
::
SortBySrcLength
(
XList
*
buf
)
{
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
srcSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
srcSent
->
Size
();
});
}
/* sort the buffer by target sentence length (in descending order) */
void
TrainDataSet
::
SortByTgtLength
(
XList
*
buf
)
{
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
tgtSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
tgtSent
->
Size
();
});
}
/* sort buckets by key (in descending order) */
void
TrainDataSet
::
SortBuckets
(
XList
*
buf
)
{
sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
bucketKey
<
((
TrainExample
*
)(
b
))
->
bucketKey
;
});
}
/*
load
all data from a file
to the buffer
load
samples from a file in
to the buffer
training data format (binary):
first 8 bit: number of sentence pairs
subsequent segements:
...
...
@@ -71,52 +87,63 @@ source sentence length (4 bit)
target sentence length (4 bit)
source tokens (4 bit per token)
target tokens (4 bit per token)
>> buf - the buffer (list) of samples
*/
void
TrainDataSet
::
LoadDataToBuffer
(
)
bool
TrainDataSet
::
LoadBatchToBuf
(
XList
*
buf
)
{
buffer
.
Clear
();
curIdx
=
0
;
int
id
=
0
;
uint64_t
sentNum
=
0
;
ClearSamples
(
buf
);
int
srcVocabSize
=
0
;
int
tgtVocabSize
=
0
;
fread
(
&
srcVocabSize
,
sizeof
(
srcVocabSize
),
1
,
fp
);
fread
(
&
tgtVocabSize
,
sizeof
(
tgtVocabSize
),
1
,
fp
);
int
sampleNum
=
0
;
fread
(
&
sentNum
,
sizeof
(
uint64_t
),
1
,
fp
);
CheckNTErrors
(
sentNum
>
0
,
"Invalid sentence pairs number"
);
while
((
sampleNum
<
bufferSize
))
{
while
(
id
<
sentNum
)
{
int
srcLen
=
0
;
int
tgtLen
=
0
;
fread
(
&
srcLen
,
sizeof
(
int
),
1
,
fp
);
size_t
n
=
fread
(
&
srcLen
,
sizeof
(
int
),
1
,
fp
);
if
(
n
==
0
)
break
;
fread
(
&
tgtLen
,
sizeof
(
int
),
1
,
fp
);
CheckNTErrors
(
srcLen
>
0
,
"Invalid source sentence length"
);
CheckNTErrors
(
tgtLen
>
0
,
"Invalid target sentence length"
);
IntList
srcSent
;
IntList
tgtSent
;
srcSent
.
ReadFromFile
(
fp
,
srcLen
);
tgtSent
.
ReadFromFile
(
fp
,
tgtLen
);
IntList
*
srcSent
=
new
IntList
(
srcLen
);
IntList
*
tgtSent
=
new
IntList
(
tgtLen
);
srcSent
->
ReadFromFile
(
fp
,
srcLen
);
tgtSent
->
ReadFromFile
(
fp
,
tgtLen
);
TrainExample
*
example
=
new
TrainExample
(
sampleNum
++
,
0
,
srcSent
,
tgtSent
);
buf
->
Add
(
example
);
}
/* reset the file pointer to the begin */
if
(
feof
(
fp
)
&&
isTraining
)
{
TrainExample
*
example
=
new
TrainExample
;
example
->
id
=
id
++
;
example
->
key
=
id
;
example
->
srcSent
=
srcSent
;
example
->
tgtSent
=
tgtSent
;
rewind
(
fp
);
buffer
.
Add
(
example
);
int
srcVocabSize
=
0
;
int
tgtVocabSize
=
0
;
fread
(
&
srcVocabSize
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
tgtVocabSize
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
totalSampleNum
,
sizeof
(
totalSampleNum
),
1
,
fp
);
}
fclose
(
fp
);
/* group samples in the buffer into buckets */
SortByTgtLength
(
buf
);
SortBySrcLength
(
buf
);
if
(
isTraining
)
BuildBucket
(
buf
);
XPRINT1
(
0
,
stderr
,
"[INFO] loaded %d sentences
\n
"
,
id
)
;
return
true
;
}
/*
load a mini-batch to the device (for training)
load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the index of the buffer
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
...
...
@@ -125,57 +152,34 @@ load a mini-batch to the device (for training)
>> minSentBatch - the minimum number of sentence batch
>> batchSize - the maxium number of words in a batch
>> devID - the device id, -1 for the CPU
<< return - number of target tokens and sentences
>> wc - number of target words in a batch
>> sc - number of samples in a batch
*/
UInt64List
TrainDataSet
::
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
)
bool
TrainDataSet
::
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
)
{
UInt64List
info
;
size_t
srcTokenNum
=
0
;
size_t
tgtTokenNum
=
0
;
size_t
realBatchSize
=
1
;
if
(
!
isTraining
)
realBatchSize
=
minSentBatch
;
/* get the maximum source sentence length in a mini-batch */
size_t
maxSrcLen
=
buffer
[(
int
)
curIdx
]
->
srcSent
.
Size
();
/* max batch size */
const
int
MAX_BATCH_SIZE
=
512
;
/* dynamic batching for sentences, enabled when the dataset is used for training */
if
(
isTraining
)
{
while
((
realBatchSize
<
(
buffer
.
Size
()
-
curIdx
))
&&
(
realBatchSize
*
maxSrcLen
<
batchSize
)
&&
(
realBatchSize
<
MAX_BATCH_SIZE
)
&&
(
realBatchSize
*
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
()
<
batchSize
))
{
if
(
maxSrcLen
<
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
curIdx
+
realBatchSize
)]
->
srcSent
.
Size
();
realBatchSize
++
;
}
}
/* real batch size */
if
((
buffer
.
Size
()
-
curIdx
)
<
realBatchSize
)
{
realBatchSize
=
buffer
.
Size
()
-
curIdx
;
int
srcTokenNum
=
0
;
int
tgtTokenNum
=
0
;
int
realBatchSize
=
0
;
/* dynamic batching for sentences */
int
bucketKey
=
((
TrainExample
*
)(
buf
->
Get
(
curIdx
)))
->
bucketKey
;
while
((
realBatchSize
<
(
int
(
buf
->
Size
())
-
curIdx
))
&&
(((
TrainExample
*
)(
buf
->
Get
(
curIdx
+
realBatchSize
)))
->
bucketKey
==
bucketKey
))
{
realBatchSize
++
;
}
realBatchSize
=
MIN
(
realBatchSize
,
(
int
(
buf
->
Size
())
-
curIdx
));
CheckNTErrors
(
realBatchSize
>
0
,
"Invalid batch size"
);
/* get the maximum target sentence length in a mini-batch */
size_t
maxTgtLen
=
buffer
[(
int
)
curIdx
]
->
tgtSent
.
Size
();
for
(
size_t
i
=
0
;
i
<
realBatchSize
;
i
++
)
{
if
(
maxTgtLen
<
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
())
maxTgtLen
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
}
for
(
size_t
i
=
0
;
i
<
realBatchSize
;
i
++
)
{
if
(
maxSrcLen
<
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
}
int
maxSrcLen
=
MaxSrcLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
int
maxTgtLen
=
MaxTgtLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
CheckNTErrors
(
maxSrcLen
!=
0
,
"Invalid source length for batching"
);
CheckNTErrors
(
maxSrcLen
>
0
,
"Invalid source length for batching"
);
CheckNTErrors
(
maxTgtLen
>
0
,
"Invalid target length for batching"
);
int
*
batchEncValues
=
new
int
[
realBatchSize
*
maxSrcLen
];
float
*
paddingEncValues
=
new
float
[
realBatchSize
*
maxSrcLen
];
...
...
@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
float
*
paddingDecValues
=
new
float
[
realBatchSize
*
maxTgtLen
];
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxSrcLen
;
i
++
)
{
batchEncValues
[
i
]
=
PAD
;
paddingEncValues
[
i
]
=
1
;
batchEncValues
[
i
]
=
1
;
paddingEncValues
[
i
]
=
1
.0
F
;
}
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxTgtLen
;
i
++
)
{
batchDecValues
[
i
]
=
PAD
;
labelVaues
[
i
]
=
PAD
;
batchDecValues
[
i
]
=
1
;
labelVaues
[
i
]
=
1
;
paddingDecValues
[
i
]
=
1.0
F
;
}
size_
t
curSrc
=
0
;
size_
t
curTgt
=
0
;
in
t
curSrc
=
0
;
in
t
curTgt
=
0
;
/*
batchEnc: end with EOS (left padding)
...
...
@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
*/
for
(
int
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
srcTokenNum
+=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
tgtTokenNum
+=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
TrainExample
*
sample
=
(
TrainExample
*
)(
buf
->
Get
(
curIdx
+
i
));
srcTokenNum
+=
int
(
sample
->
srcSent
->
Size
());
tgtTokenNum
+=
int
(
sample
->
tgtSent
->
Size
());
curSrc
=
maxSrcLen
*
i
;
for
(
int
j
=
0
;
j
<
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
.
Size
();
j
++
)
{
batchEncValues
[
curSrc
++
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
srcSent
[
j
]
;
for
(
int
j
=
0
;
j
<
sample
->
srcSent
->
Size
();
j
++
)
{
batchEncValues
[
curSrc
++
]
=
sample
->
srcSent
->
Get
(
j
)
;
}
curTgt
=
maxTgtLen
*
i
;
for
(
int
j
=
0
;
j
<
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
.
Size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
sample
->
tgtSent
->
Size
();
j
++
)
{
if
(
j
>
0
)
labelVaues
[
curTgt
-
1
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
[
j
]
;
batchDecValues
[
curTgt
++
]
=
buffer
[(
int
)(
curIdx
+
i
)]
->
tgtSent
[
j
]
;
labelVaues
[
curTgt
-
1
]
=
sample
->
tgtSent
->
Get
(
j
)
;
batchDecValues
[
curTgt
++
]
=
sample
->
tgtSent
->
Get
(
j
)
;
}
labelVaues
[
curTgt
-
1
]
=
EOS
;
labelVaues
[
curTgt
-
1
]
=
2
;
while
(
curSrc
<
maxSrcLen
*
(
i
+
1
))
paddingEncValues
[
curSrc
++
]
=
0
;
while
(
curTgt
<
maxTgtLen
*
(
i
+
1
))
paddingDecValues
[
curTgt
++
]
=
0
;
}
int
rbs
=
(
int
)
realBatchSize
;
int
msl
=
(
int
)
maxSrcLen
;
InitTensor2D
(
batchEnc
,
rbs
,
msl
,
X_INT
,
devID
);
InitTensor2D
(
paddingEnc
,
rbs
,
msl
,
X_FLOAT
,
devID
);
InitTensor2D
(
batchDec
,
rbs
,
msl
,
X_INT
,
devID
);
InitTensor2D
(
paddingDec
,
rbs
,
msl
,
X_FLOAT
,
devID
);
InitTensor2D
(
label
,
rbs
,
msl
,
X_INT
,
devID
);
InitTensor2D
(
batchEnc
,
realBatchSize
,
maxSrcLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingEnc
,
realBatchSize
,
maxSrcLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
batchDec
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingDec
,
realBatchSize
,
maxTgtLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
label
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
curIdx
+=
realBatchSize
;
...
...
@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
delete
[]
paddingDecValues
;
delete
[]
labelVaues
;
info
.
Add
(
tgtTokenNum
);
info
.
Add
(
realBatchSize
);
return
info
;
wc
=
tgtTokenNum
;
sc
=
realBatchSize
;
return
true
;
}
/*
clear the buffer
>> buf - the buffer (list) of samples
*/
void
TrainDataSet
::
ClearSamples
(
XList
*
buf
)
{
for
(
int
i
=
0
;
i
<
buf
->
count
;
i
++
)
{
TrainExample
*
sample
=
(
TrainExample
*
)
buf
->
Get
(
i
);
delete
sample
;
}
buf
->
Clear
();
}
/*
...
...
@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
{
fp
=
fopen
(
dataFile
,
"rb"
);
CheckNTErrors
(
fp
,
"can not open the training file"
);
curIdx
=
0
;
bucketSize
=
myBucketSize
;
isTraining
=
training
;
LoadDataToBuffer
();
SortByLength
();
if
(
isTraining
)
BuildBucket
();
}
/* check if the buffer is empty */
bool
TrainDataSet
::
IsEmpty
()
{
if
(
curIdx
<
buffer
.
Size
())
return
false
;
return
true
;
}
/* reset the buffer */
void
TrainDataSet
::
ClearBuf
()
{
curIdx
=
0
;
int
srcVocabSize
=
0
;
int
tgtVocabSize
=
0
;
fread
(
&
srcVocabSize
,
sizeof
(
int
),
1
,
fp
);
fread
(
&
tgtVocabSize
,
sizeof
(
int
),
1
,
fp
);
CheckNTErrors
(
srcVocabSize
>
0
,
"Invalid source vocabulary size"
);
CheckNTErrors
(
tgtVocabSize
>
0
,
"Invalid target vocabulary size"
);
/* make different batches in different epochs */
SortByLength
(
);
fread
(
&
totalSampleNum
,
sizeof
(
totalSampleNum
),
1
,
fp
);
CheckNTErrors
(
totalSampleNum
>
0
,
"Invalid sentence pairs number"
);
if
(
isTraining
)
BuildBucket
()
;
bucketSize
=
myBucketSize
;
isTraining
=
training
;
}
/* group data
into buckets with similar length
*/
void
TrainDataSet
::
BuildBucket
()
/* group data
with similar length into buckets
*/
void
TrainDataSet
::
BuildBucket
(
XList
*
buf
)
{
size_
t
idx
=
0
;
in
t
idx
=
0
;
/* build
and shuffle bucket
s */
while
(
idx
<
buffer
.
Size
(
))
{
/* build
buckets by the length of source and target sentence
s */
while
(
idx
<
int
(
buf
->
Size
()
))
{
/* sentence number in a bucket */
size_
t
sentNum
=
1
;
in
t
sentNum
=
1
;
/* get the maximum source sentence length in a bucket */
size_t
maxSrcLen
=
buffer
[(
int
)
idx
]
->
srcSent
.
Size
();
/* bucketing for sentences */
while
((
sentNum
<
(
buffer
.
Size
()
-
idx
))
&&
(
sentNum
*
maxSrcLen
<
bucketSize
)
&&
(
sentNum
*
buffer
[(
int
)(
curIdx
+
sentNum
)]
->
srcSent
.
Size
()
<
bucketSize
))
{
if
(
maxSrcLen
<
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
())
maxSrcLen
=
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
();
int
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
int
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
int
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
/* the maximum sentence number in a bucket */
const
int
MAX_SENT_NUM
=
5120
;
while
((
sentNum
<
(
buf
->
count
-
idx
))
&&
(
sentNum
<
MAX_SENT_NUM
)
&&
(
sentNum
*
maxLen
<=
bucketSize
))
{
sentNum
++
;
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
}
/* make sure the number is valid */
if
((
buffer
.
Size
()
-
idx
)
<
sentNum
)
{
sentNum
=
buffer
.
Size
()
-
idx
;
if
((
sentNum
)
*
maxLen
>
bucketSize
||
sentNum
>=
MAX_SENT_NUM
)
{
sentNum
--
;
sentNum
=
max
(
8
*
(
sentNum
/
8
),
sentNum
%
8
);
}
if
((
int
(
buf
->
Size
())
-
idx
)
<
sentNum
)
sentNum
=
int
(
buf
->
Size
())
-
idx
;
/* assign the same key for items in a bucket */
int
randomKey
=
rand
();
/* shuffle items in a bucket */
for
(
size_t
i
=
0
;
i
<
sentNum
;
i
++
)
{
buffer
[(
int
)(
idx
+
i
)]
->
bucketKey
=
randomKey
;
for
(
int
i
=
0
;
i
<
sentNum
;
i
++
)
{
((
TrainExample
*
)(
buf
->
Get
(
idx
+
i
)))
->
bucketKey
=
randomKey
;
}
idx
+=
sentNum
;
}
SortBucket
();
/* sort items in a bucket */
idx
=
0
;
while
(
idx
<
buffer
.
Size
())
{
size_t
sentNum
=
0
;
int
bucketKey
=
buffer
[(
int
)(
idx
+
sentNum
)]
->
bucketKey
;
while
(
sentNum
<
(
buffer
.
Size
()
-
idx
)
&&
buffer
[(
int
)(
idx
+
sentNum
)]
->
bucketKey
==
bucketKey
)
{
buffer
[(
int
)(
idx
+
sentNum
)]
->
key
=
(
int
)
buffer
[(
int
)(
idx
+
sentNum
)]
->
srcSent
.
Size
();
sentNum
++
;
}
SortInBucket
((
int
)
idx
,
(
int
)(
idx
+
sentNum
));
idx
+=
sentNum
;
}
/* sort buckets by their keys */
SortBuckets
(
buf
);
}
/* de-constructor */
TrainDataSet
::~
TrainDataSet
()
{
fclose
(
fp
);
}
/* constructor */
TrainExample
::
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
)
{
id
=
myID
;
bucketKey
=
myKey
;
srcSent
=
s
;
tgtSent
=
t
;
}
/* release the buffer */
for
(
int
i
=
0
;
i
<
buffer
.
Size
();
i
++
)
delete
buffer
[
i
];
/* de-constructor */
TrainExample
::~
TrainExample
()
{
delete
srcSent
;
delete
tgtSent
;
}
}
\ No newline at end of file
}
source/sample/transformer/train/TrainDataSet.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using
namespace
std
;
...
...
@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
struct
TrainExample
{
public
:
/* id of the sentence pair */
int
id
;
/* source language setence (tokenized) */
IntList
srcSent
;
IntList
*
srcSent
;
/* target language setence (tokenized) */
IntList
tgtSent
;
/* the key used to shuffle items in a bucket */
int
key
;
IntList
*
tgtSent
;
/* the key used to shuffle buckets */
int
bucketKey
;
public
:
/* constructor */
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
);
/* de-constructor */
~
TrainExample
();
};
struct
ReservedIDs
{
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
};
/* A `TrainDataSet` is associated with a file which contains training data. */
struct
TrainDataSet
{
public
:
/* the data buffer */
TrainBufferType
buffer
;
/* a list of empty line number */
IntList
emptyLines
;
public
:
/* the pointer to file stream */
FILE
*
fp
;
/*
current index in the buffer
*/
size_t
curIdx
;
/*
number of training samples
*/
size_t
totalSampleNum
;
/*
size of used data in the buffer
*/
size_t
buffer
Used
;
/*
buffer size
*/
size_t
buffer
Size
;
/* size of the bucket used for grouping sentences */
size_t
bucketSize
;
...
...
@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
bool
isTraining
;
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
/* the maximum length for a source sentence */
int
maxSrcLen
;
/* the maximum length for a target sentence */
int
maxTgtLen
;
public
:
/* sort the input by length (in descending order) */
void
SortByLength
();
/* get the maximum source sentence length in a range */
static
int
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort buckets by key (in descending order) */
void
SortBucket
();
/* get the maximum target sentence length in a range */
static
int
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort the
output by key
(in descending order) */
void
Sort
InBucket
(
int
begin
,
int
end
);
/* sort the
input by source sentence length
(in descending order) */
void
Sort
BySrcLength
(
XList
*
buf
);
/*
load data from a file to the buffer
*/
void
LoadDataToBuffer
(
);
/*
sort the input by target sentence length (in descending order)
*/
void
SortByTgtLength
(
XList
*
buf
);
/* generate a mini-batch */
UInt64List
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
);
/* sort buckets by key (in descending order) */
void
SortBuckets
(
XList
*
buf
);
/* load the samples into the buffer (a list) */
bool
LoadBatchToBuf
(
XList
*
buf
);
/* load the samples into tensors from the buffer */
static
bool
LoadBatch
(
XList
*
buf
,
bool
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
);
/* release the samples in a buffer */
...
...
@@ -116,14 +147,8 @@ public:
/* initialization function */
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
/* check if the buffer is empty */
bool
IsEmpty
();
/* reset the buffer */
void
ClearBuf
();
/* group data into buckets with similar length */
void
BuildBucket
();
void
BuildBucket
(
XList
*
buf
);
/* de-constructor */
~
TrainDataSet
();
...
...
source/sample/transformer/train/Trainer.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -39,6 +39,41 @@ namespace nmt
Trainer
::
Trainer
()
{
cfg
=
NULL
;
lrate
=
0.0
F
;
lrbias
=
0.0
F
;
sBatchSize
=
0
;
wBatchSize
=
0
;
bucketSize
=
0
;
nstep
=
0
;
nepoch
=
0
;
logInterval
=
0
;
maxCheckpoint
=
0
;
d
=
0
;
nwarmup
=
0
;
vSize
=
0
;
vSizeTgt
=
0
;
useAdam
=
false
;
adamBeta1
=
0.0
F
;
adamBeta2
=
0.0
F
;
adamDelta
=
0.0
F
;
isShuffled
=
false
;
labelSmoothingP
=
0.0
F
;
nStepCheckpoint
=
0
;
useEpochCheckpoint
=
false
;
updateStep
=
0
;
isLenSorted
=
0
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
0
;
batchLoader
.
endID
=
0
;
batchLoader
.
unkID
=
0
;
batchLoader
.
padID
=
0
;
batchLoader
.
maxSrcLen
=
0
;
batchLoader
.
maxTgtLen
=
0
;
batchLoader
.
bufferSize
=
0
;
}
/* de-constructor */
...
...
@@ -62,13 +97,15 @@ initialization
void
Trainer
::
Init
(
Config
&
config
)
{
cfg
=
&
config
;
lrate
=
config
.
lrate
;
lrbias
=
config
.
lrbias
;
sBatchSize
=
config
.
sBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
bucketSize
=
config
.
bucketSize
;
nepoch
=
config
.
nepoch
;
nstep
=
config
.
nstep
;
nepoch
=
config
.
nepoch
;
logInterval
=
config
.
logInterval
;
maxCheckpoint
=
config
.
maxCheckpoint
;
d
=
config
.
modelSize
;
nwarmup
=
config
.
nwarmup
;
...
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
config
.
startID
;
batchLoader
.
endID
=
config
.
endID
;
batchLoader
.
unkID
=
config
.
unkID
;
batchLoader
.
padID
=
config
.
padID
;
batchLoader
.
maxSrcLen
=
config
.
maxSrcLen
;
batchLoader
.
maxTgtLen
=
config
.
maxTgtLen
;
batchLoader
.
bufferSize
=
config
.
bufSize
;
}
/*
...
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
int
step
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCountTotal
=
0
;
int
batchCountTotal
=
0
;
...
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
...
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount
=
0
;
loss
=
0
;
/* reset the batch loader */
batchLoader
.
ClearBuf
();
while
(
!
batchLoader
.
IsEmpty
())
while
(
step
++
<
nstep
)
{
XNet
net
;
net
.
Clear
();
...
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor
paddingEnc
;
XTensor
paddingDec
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
);
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
curIdx
=
0
;
batchLoader
.
LoadBatchToBuf
(
buf
);
}
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
/* output probabilities */
XTensor
output
;
/* make the network */
if
(
model
->
isLM
)
if
(
model
->
isLM
)
{
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
else
if
(
model
->
isMT
)
}
else
if
(
model
->
isMT
)
{
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
}
else
{
ShowNTErrors
(
"Illegal model type!"
);
}
...
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE
lossLocal
=
lossBatch
/
wc
;
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
net
.
isGradEfficient
=
true
;
bool
debug
(
false
);
if
(
debug
)
{
LOG
(
"after forward:"
);
batchEnc
.
mem
->
ShowMemUsage
(
stderr
);
exit
(
0
);
}
if
(
doUpdate
)
{
/* back-propagation */
net
.
Backward
(
lossTensor
);
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
gradStep
+=
1
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
batchCountTotal
+=
ws
;
batchCountTotal
+=
sc
;
/* update the parameters */
if
(
gradStep
==
updateStep
)
{
...
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
nSkipped
++
;
if
(
++
step
>=
nstep
)
{
isEnd
=
true
;
break
;
}
if
(
step
==
10
)
{
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if
(
step
%
100
==
0
)
{
if
(
step
%
logInterval
==
0
)
{
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
...
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
if
(
isEnd
)
break
;
if
(
useEpochCheckpoint
)
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
epoch
=
MIN
(
epoch
,
nepoch
);
...
...
@@ -287,8 +340,12 @@ test the model
*/
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
{
double
startT
=
GetClockSec
();
DISABLE_GRAD
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
sentCount
=
0
;
float
loss
=
0
;
...
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
batchLoader
.
Init
(
fn
,
0
,
false
);
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
/* set the buffer size to the size of valiadation set */
batchLoader
.
bufferSize
=
batchLoader
.
totalSampleNum
;
batchLoader
.
LoadBatchToBuf
(
buf
);
while
(
!
batchLoader
.
IsEmpty
()
)
while
(
curIdx
<
buf
->
count
)
{
/* batch of input sequences */
XTensor
batchEnc
;
...
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor
labelOnehot
;
XTensor
lossTensor
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
,
wc
,
sc
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
/* make the network */
...
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int
length
=
output
.
GetDim
(
1
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
loss
+=
lossBatch
;
wordCount
+=
wc
;
sentCount
+=
bSize
;
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
ENABLE_GRAD
;
LOG
(
"validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
}
...
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
/* v2 = m / (sqrt(v) + delta) */
XTensor
*
v2
=
NewTensorBuf
(
v
,
v
->
devID
);
XTensor
*
v2
=
NewTensorBuf
V2
(
v
,
v
->
devID
,
v
->
mem
);
_Power
(
v
,
v2
,
0.5
F
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_Div
(
m
,
v2
,
v2
);
...
...
source/sample/transformer/train/Trainer.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -70,6 +70,9 @@ public:
/* traing step number */
int
nstep
;
/* interval step for logging */
int
logInterval
;
/* the maximum number of saved checkpoints */
int
maxCheckpoint
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论