Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
7b6840d4
Commit
7b6840d4
authored
Mar 20, 2021
by
huchi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor the dataloader in transformer
parent
3852f15a
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
166 行增加
和
120 行删除
+166
-120
source/sample/transformer/train/TrainDataSet.cpp
+102
-77
source/sample/transformer/train/TrainDataSet.h
+41
-22
source/sample/transformer/train/Trainer.cpp
+23
-21
没有找到文件。
source/sample/transformer/train/TrainDataSet.cpp
查看文件 @
7b6840d4
...
...
@@ -31,7 +31,7 @@ using namespace nmt;
namespace
nts
{
/* get the maximum source sentence length in a range */
int
TrainDataSet
::
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
int
TrainDataSet
::
MaxSrcLen
(
int
begin
,
int
end
)
{
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
int
maxLen
=
0
;
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
...
...
@@ -42,7 +42,7 @@ int TrainDataSet::MaxSrcLen(XList* buf, int begin, int end) {
}
/* get the maximum target sentence length in a range */
int
TrainDataSet
::
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
)
{
int
TrainDataSet
::
MaxTgtLen
(
int
begin
,
int
end
)
{
CheckNTErrors
((
end
>
begin
)
&&
(
begin
>=
0
)
&&
(
end
<=
buf
->
count
),
"Invalid range"
);
int
maxLen
=
0
;
for
(
int
i
=
begin
;
i
<
end
;
i
++
)
{
...
...
@@ -53,28 +53,28 @@ int TrainDataSet::MaxTgtLen(XList* buf, int begin, int end) {
}
/* sort the buffer by source sentence length (in descending order) */
void
TrainDataSet
::
SortBySrcLength
(
XList
*
buf
)
{
void
TrainDataSet
::
SortBySrcLength
()
{
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
srcSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
srcSent
->
Size
();
});
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
srcSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
srcSent
->
Size
();
});
}
/* sort the buffer by target sentence length (in descending order) */
void
TrainDataSet
::
SortByTgtLength
(
XList
*
buf
)
{
void
TrainDataSet
::
SortByTgtLength
()
{
stable_sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
tgtSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
tgtSent
->
Size
();
});
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
tgtSent
->
Size
()
<
((
TrainExample
*
)(
b
))
->
tgtSent
->
Size
();
});
}
/* sort buckets by key (in descending order) */
void
TrainDataSet
::
SortBuckets
(
XList
*
buf
)
{
void
TrainDataSet
::
SortBuckets
()
{
sort
(
buf
->
items
,
buf
->
items
+
buf
->
count
,
[](
void
*
a
,
void
*
b
)
{
return
((
TrainExample
*
)(
a
))
->
bucketKey
<
((
TrainExample
*
)(
b
))
->
bucketKey
;
return
((
TrainExample
*
)(
a
))
->
bucketKey
<
((
TrainExample
*
)(
b
))
->
bucketKey
;
});
}
...
...
@@ -87,11 +87,13 @@ source sentence length (4 bit)
target sentence length (4 bit)
source tokens (4 bit per token)
target tokens (4 bit per token)
>> buf - the buffer (list) of samples
*/
bool
TrainDataSet
::
LoadBatchToBuf
(
XList
*
buf
)
bool
TrainDataSet
::
LoadBatchToBuf
()
{
ClearSamples
(
buf
);
/* reset the buffer and index */
curIdx
=
0
;
ClearSamples
();
int
sampleNum
=
0
;
...
...
@@ -130,69 +132,50 @@ bool TrainDataSet::LoadBatchToBuf(XList* buf)
}
/* group samples in the buffer into buckets */
SortByTgtLength
(
buf
);
SortByTgtLength
();
SortBySrcLength
(
buf
);
SortBySrcLength
();
if
(
isTraining
)
BuildBucket
(
buf
);
BuildBucket
();
return
true
;
}
/*
load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the index of the buffer
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
>> paddingDec - a tensor to store the batch of decoder paddings
>> label - a tensor to store the label of input
>> minSentBatch - the minimum number of sentence batch
>> batchSize - the maxium number of words in a batch
>> devID - the device id, -1 for the CPU
>> wc - number of target words in a batch
>> sc - number of samples in a batch
>> inputs - the list to store input tensors
>> golds - the list to store gold tensors
*/
bool
TrainDataSet
::
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
)
bool
TrainDataSet
::
GetBatchSimple
(
XList
*
inputs
,
XList
*
golds
)
{
int
srcTokenNum
=
0
;
int
tgtTokenNum
=
0
;
int
realBatchSize
=
0
;
/* dynamic batching for sentences */
int
bucketKey
=
((
TrainExample
*
)(
buf
->
Get
(
curIdx
)))
->
bucketKey
;
while
((
realBatchSize
<
(
int
(
buf
->
Size
())
-
curIdx
))
&&
(((
TrainExample
*
)(
buf
->
Get
(
curIdx
+
realBatchSize
)))
->
bucketKey
==
bucketKey
))
{
realBatchSize
++
;
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
LoadBatchToBuf
();
}
realBatchSize
=
MIN
(
realBatchSize
,
(
int
(
buf
->
Size
())
-
curIdx
))
;
CheckNTErrors
(
realBatchSize
>
0
,
"Invalid batch size"
);
wc
=
0
;
GetSentNum
(
);
/* get the maximum target sentence length in a mini-batch */
int
maxSrcLen
=
MaxSrcLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
int
maxTgtLen
=
MaxTgtLen
(
buf
,
curIdx
,
curIdx
+
realBatchSize
);
int
maxSrcLen
=
MaxSrcLen
(
curIdx
,
curIdx
+
sc
);
int
maxTgtLen
=
MaxTgtLen
(
curIdx
,
curIdx
+
sc
);
CheckNTErrors
(
maxSrcLen
>
0
,
"Invalid source length for batching"
);
CheckNTErrors
(
maxTgtLen
>
0
,
"Invalid target length for batching"
);
int
*
batchEncValues
=
new
int
[
realBatchSize
*
maxSrcLen
];
float
*
paddingEncValues
=
new
float
[
realBatchSize
*
maxSrcLen
];
int
*
batchEncValues
=
new
int
[
sc
*
maxSrcLen
];
float
*
paddingEncValues
=
new
float
[
sc
*
maxSrcLen
];
int
*
labelVaues
=
new
int
[
realBatchSize
*
maxTgtLen
];
int
*
batchDecValues
=
new
int
[
realBatchSize
*
maxTgtLen
];
float
*
paddingDecValues
=
new
float
[
realBatchSize
*
maxTgtLen
];
int
*
labelVaues
=
new
int
[
sc
*
maxTgtLen
];
int
*
batchDecValues
=
new
int
[
sc
*
maxTgtLen
];
float
*
paddingDecValues
=
new
float
[
sc
*
maxTgtLen
];
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxSrcLen
;
i
++
)
{
for
(
int
i
=
0
;
i
<
sc
*
maxSrcLen
;
i
++
)
{
batchEncValues
[
i
]
=
1
;
paddingEncValues
[
i
]
=
1.0
F
;
}
for
(
int
i
=
0
;
i
<
realBatchSize
*
maxTgtLen
;
i
++
)
{
for
(
int
i
=
0
;
i
<
sc
*
maxTgtLen
;
i
++
)
{
batchDecValues
[
i
]
=
1
;
labelVaues
[
i
]
=
1
;
paddingDecValues
[
i
]
=
1.0
F
;
...
...
@@ -206,11 +189,10 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
batchDec: begin with SOS (right padding)
label: end with EOS (right padding)
*/
for
(
int
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
sc
;
++
i
)
{
TrainExample
*
sample
=
(
TrainExample
*
)(
buf
->
Get
(
curIdx
+
i
));
srcTokenNum
+=
int
(
sample
->
srcSent
->
Size
());
tgtTokenNum
+=
int
(
sample
->
tgtSent
->
Size
());
wc
+=
int
(
sample
->
tgtSent
->
Size
());
curSrc
=
maxSrcLen
*
i
;
for
(
int
j
=
0
;
j
<
sample
->
srcSent
->
Size
();
j
++
)
{
...
...
@@ -230,13 +212,25 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
paddingDecValues
[
curTgt
++
]
=
0
;
}
InitTensor2D
(
batchEnc
,
realBatchSize
,
maxSrcLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingEnc
,
realBatchSize
,
maxSrcLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
batchDec
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
InitTensor2D
(
paddingDec
,
realBatchSize
,
maxTgtLen
,
X_FLOAT
,
devID
);
InitTensor2D
(
label
,
realBatchSize
,
maxTgtLen
,
X_INT
,
devID
);
XTensor
*
batchEnc
=
((
TensorList
*
)(
inputs
))
->
Get
(
0
);
XTensor
*
paddingEnc
=
((
TensorList
*
)(
inputs
))
->
Get
(
1
);
XTensor
*
batchDec
=
((
TensorList
*
)(
golds
))
->
Get
(
0
);
XTensor
*
paddingDec
=
((
TensorList
*
)(
golds
))
->
Get
(
1
);
XTensor
*
label
=
((
TensorList
*
)(
golds
))
->
Get
(
2
);
curIdx
+=
realBatchSize
;
InitTensor2D
(
batchEnc
,
sc
,
maxSrcLen
,
X_INT
);
InitTensor2D
(
paddingEnc
,
sc
,
maxSrcLen
,
X_FLOAT
);
InitTensor2D
(
batchDec
,
sc
,
maxTgtLen
,
X_INT
);
InitTensor2D
(
paddingDec
,
sc
,
maxTgtLen
,
X_FLOAT
);
InitTensor2D
(
label
,
sc
,
maxTgtLen
,
X_INT
);
inputs
->
Add
(
batchEnc
);
inputs
->
Add
(
paddingEnc
);
golds
->
Add
(
batchDec
);
golds
->
Add
(
paddingDec
);
golds
->
Add
(
label
);
curIdx
+=
sc
;
batchEnc
->
SetData
(
batchEncValues
,
batchEnc
->
unitNum
);
paddingEnc
->
SetData
(
paddingEncValues
,
paddingEnc
->
unitNum
);
...
...
@@ -250,8 +244,6 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
delete
[]
paddingDecValues
;
delete
[]
labelVaues
;
wc
=
tgtTokenNum
;
sc
=
realBatchSize
;
return
true
;
}
...
...
@@ -259,7 +251,7 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
clear the buffer
>> buf - the buffer (list) of samples
*/
void
TrainDataSet
::
ClearSamples
(
XList
*
buf
)
void
TrainDataSet
::
ClearSamples
()
{
for
(
int
i
=
0
;
i
<
buf
->
count
;
i
++
)
{
TrainExample
*
sample
=
(
TrainExample
*
)
buf
->
Get
(
i
);
...
...
@@ -274,7 +266,7 @@ the constructor of DataSet
>> bucketSize - size of the bucket to keep similar length sentence pairs
>> training - indicates whether it is used for training
*/
void
TrainDataSet
::
Init
(
const
char
*
dataFile
,
int
myBucketSize
,
bool
training
)
void
TrainDataSet
::
Init
(
const
char
*
dataFile
,
int
myB
atchSize
,
int
myB
ucketSize
,
bool
training
)
{
fp
=
fopen
(
dataFile
,
"rb"
);
CheckNTErrors
(
fp
,
"can not open the training file"
);
...
...
@@ -289,12 +281,15 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
fread
(
&
totalSampleNum
,
sizeof
(
totalSampleNum
),
1
,
fp
);
CheckNTErrors
(
totalSampleNum
>
0
,
"Invalid sentence pairs number"
);
batchSize
=
myBatchSize
;
bucketSize
=
myBucketSize
;
isTraining
=
training
;
buf
=
new
XList
;
}
/* group
data
with similar length into buckets */
void
TrainDataSet
::
BuildBucket
(
XList
*
buf
)
/* group
samples
with similar length into buckets */
void
TrainDataSet
::
BuildBucket
()
{
int
idx
=
0
;
...
...
@@ -305,8 +300,8 @@ void TrainDataSet::BuildBucket(XList * buf)
int
sentNum
=
1
;
/* get the maximum source sentence length in a bucket */
int
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
int
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
int
maxSrcLen
=
MaxSrcLen
(
idx
,
idx
+
sentNum
);
int
maxTgtLen
=
MaxTgtLen
(
idx
,
idx
+
sentNum
);
int
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
/* the maximum sentence number in a bucket */
...
...
@@ -316,8 +311,8 @@ void TrainDataSet::BuildBucket(XList * buf)
&&
(
sentNum
<
MAX_SENT_NUM
)
&&
(
sentNum
*
maxLen
<=
bucketSize
))
{
sentNum
++
;
maxSrcLen
=
MaxSrcLen
(
buf
,
idx
,
idx
+
sentNum
);
maxTgtLen
=
MaxTgtLen
(
buf
,
idx
,
idx
+
sentNum
);
maxSrcLen
=
MaxSrcLen
(
idx
,
idx
+
sentNum
);
maxTgtLen
=
MaxTgtLen
(
idx
,
idx
+
sentNum
);
maxLen
=
MAX
(
maxSrcLen
,
maxTgtLen
);
}
...
...
@@ -339,12 +334,42 @@ void TrainDataSet::BuildBucket(XList * buf)
}
/* sort buckets by their keys */
SortBuckets
(
buf
);
SortBuckets
();
}
/* get the number of sentences in a mini-batch */
inline
int
TrainDataSet
::
GetSentNum
()
{
sc
=
0
;
/* dynamic batching for sentences */
int
bucketKey
=
((
TrainExample
*
)(
buf
->
Get
(
curIdx
)))
->
bucketKey
;
while
((
sc
<
(
int
(
buf
->
Size
())
-
curIdx
))
&&
(((
TrainExample
*
)(
buf
->
Get
(
curIdx
+
sc
)))
->
bucketKey
==
bucketKey
))
{
sc
++
;
}
sc
=
MIN
(
sc
,
(
int
(
buf
->
Size
())
-
curIdx
));
CheckNTErrors
(
sc
>
0
,
"Invalid batch size"
);
}
/* start the process */
bool
TrainDataSet
::
Start
()
{
return
false
;
}
/* end the process */
bool
TrainDataSet
::
End
()
{
return
true
;
}
/* de-constructor */
TrainDataSet
::~
TrainDataSet
()
{
ClearSamples
();
delete
buf
;
fclose
(
fp
);
}
...
...
source/sample/transformer/train/TrainDataSet.h
查看文件 @
7b6840d4
...
...
@@ -83,13 +83,13 @@ public:
FILE
*
fp
;
/* number of training samples */
size_
t
totalSampleNum
;
in
t
totalSampleNum
;
/* buffer size */
size_
t
bufferSize
;
in
t
bufferSize
;
/* size of the bucket used for grouping sentences */
size_
t
bucketSize
;
in
t
bucketSize
;
/* indicates whether it is used for training */
bool
isTraining
;
...
...
@@ -112,44 +112,63 @@ public:
/* the maximum length for a target sentence */
int
maxTgtLen
;
/* batch size (number of words) */
int
batchSize
;
/* word-counter */
int
wc
;
/* sentence-counter */
int
sc
;
/* current index of the buffer */
int
curIdx
;
/* the buffer (a list) of samples */
XList
*
buf
;
public
:
/* get the maximum source sentence length in a range */
static
int
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
);
int
MaxSrcLen
(
int
begin
,
int
end
);
/* get the maximum target sentence length in a range */
static
int
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
);
int
MaxTgtLen
(
int
begin
,
int
end
);
/* sort the input by source sentence length (in descending order) */
void
SortBySrcLength
(
XList
*
buf
);
void
SortBySrcLength
();
/* sort the input by target sentence length (in descending order) */
void
SortByTgtLength
(
XList
*
buf
);
void
SortByTgtLength
();
/* sort buckets by key (in descending order) */
void
SortBuckets
(
XList
*
buf
);
void
SortBuckets
();
/* load the samples into the buffer (a list) */
bool
LoadBatchToBuf
(
XList
*
buf
);
/* load the samples into tensors from the buffer */
static
bool
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
);
bool
LoadBatchToBuf
();
/* release the samples in a buffer */
static
void
ClearSamples
(
XList
*
buf
);
void
ClearSamples
();
/* initialization function */
void
Init
(
const
char
*
dataFile
,
int
b
ucketSize
,
bool
training
);
void
Init
(
const
char
*
dataFile
,
int
myBatchSize
,
int
myB
ucketSize
,
bool
training
);
/* group data into buckets with similar length */
void
BuildBucket
(
XList
*
buf
);
void
BuildBucket
();
/* get the number of sentences in a mini-batch */
int
GetSentNum
();
public
:
/* start the process */
bool
Start
();
/* end the process */
bool
End
();
/* load the samples into tensors from the buffer */
bool
GetBatchSimple
(
XList
*
inputs
,
XList
*
golds
);
/* de-constructor */
~
TrainDataSet
();
...
...
source/sample/transformer/train/Trainer.cpp
查看文件 @
7b6840d4
...
...
@@ -179,9 +179,8 @@ void Trainer::Train(const char* fn, const char* validFN,
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
batchLoader
.
Init
(
fn
,
wBatchSize
,
bucketSize
,
true
);
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
...
...
@@ -204,13 +203,16 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor
paddingEnc
;
XTensor
paddingDec
;
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
curIdx
=
0
;
batchLoader
.
LoadBatchToBuf
(
buf
);
}
TensorList
inputs
;
TensorList
golds
;
inputs
.
Add
(
&
batchEnc
);
inputs
.
Add
(
&
paddingEnc
);
golds
.
Add
(
&
batchDec
);
golds
.
Add
(
&
paddingDec
);
golds
.
Add
(
&
label
);
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
,
wc
,
sc
);
batchLoader
.
GetBatchSimple
((
XList
*
)(
&
inputs
),
(
XList
*
)(
&
golds
));
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
...
...
@@ -303,9 +305,6 @@ void Trainer::Train(const char* fn, const char* validFN,
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
epoch
=
MIN
(
epoch
,
nepoch
);
...
...
@@ -341,16 +340,15 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
float
loss
=
0
;
/* data files */
batchLoader
.
Init
(
fn
,
0
,
false
);
batchLoader
.
Init
(
fn
,
wBatchSize
,
0
,
false
);
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
/* set the buffer size to the size of valiadation set */
batchLoader
.
bufferSize
=
batchLoader
.
totalSampleNum
;
batchLoader
.
LoadBatchToBuf
(
buf
);
batchLoader
.
LoadBatchToBuf
();
while
(
curIdx
<
buf
->
count
)
while
(
curIdx
<
b
atchLoader
.
b
uf
->
count
)
{
/* batch of input sequences */
XTensor
batchEnc
;
...
...
@@ -370,8 +368,16 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor
labelOnehot
;
XTensor
lossTensor
;
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
,
wc
,
sc
);
TensorList
inputs
;
TensorList
golds
;
inputs
.
Add
(
&
batchEnc
);
inputs
.
Add
(
&
paddingEnc
);
golds
.
Add
(
&
batchDec
);
golds
.
Add
(
&
paddingDec
);
golds
.
Add
(
&
label
);
batchLoader
.
GetBatchSimple
((
XList
*
)(
&
inputs
),
(
XList
*
)(
&
golds
));
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
...
...
@@ -404,10 +410,6 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
ENABLE_GRAD
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论