Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
98a9130d
Commit
98a9130d
authored
Feb 28, 2021
by
hello
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor class `TrainDataSet`
parent
4bbd6a27
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
180 行增加
和
78 行删除
+180
-78
source/sample/transformer/train/TrainDataSet.cpp
+0
-0
source/sample/transformer/train/TrainDataSet.h
+63
-38
source/sample/transformer/train/Trainer.cpp
+113
-39
source/sample/transformer/train/Trainer.h
+4
-1
没有找到文件。
source/sample/transformer/train/TrainDataSet.cpp
查看文件 @
98a9130d
差异被折叠。
点击展开。
source/sample/transformer/train/TrainDataSet.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -30,7 +30,6 @@
...
@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using
namespace
std
;
using
namespace
std
;
...
@@ -39,39 +38,54 @@ namespace nts {
...
@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
/* a class of sentence pairs for training */
struct
TrainExample
{
struct
TrainExample
{
public
:
/* id of the sentence pair */
/* id of the sentence pair */
int
id
;
int
id
;
/* source language sentence (tokenized) */
/* source language sentence (tokenized) */
IntList
srcSent
;
IntList
*
srcSent
;
/* target language sentence (tokenized) */
/* target language sentence (tokenized) */
IntList
tgtSent
;
IntList
*
tgtSent
;
/* the key used to shuffle items in a bucket */
int
key
;
/* the key used to shuffle buckets */
/* the key used to shuffle buckets */
int
bucketKey
;
int
bucketKey
;
public
:
/* constructor */
TrainExample
(
int
myID
,
int
myKey
,
IntList
*
s
,
IntList
*
t
);
/* de-constructor */
~
TrainExample
();
};
struct
ReservedIDs
{
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
};
};
/* A `TrainDataSet` is associated with a file which contains training data. */
/* A `TrainDataSet` is associated with a file which contains training data. */
struct
TrainDataSet
{
struct
TrainDataSet
{
public
:
/* the data buffer */
TrainBufferType
buffer
;
/* a list of empty line number */
public
:
IntList
emptyLines
;
/* the pointer to file stream */
/* the pointer to file stream */
FILE
*
fp
;
FILE
*
fp
;
/*
current index in the buffer
*/
/*
number of training samples
*/
size_t
curIdx
;
size_t
totalSampleNum
;
/*
size of used data in the buffer
*/
/*
buffer size
*/
size_t
buffer
Used
;
size_t
buffer
Size
;
/* size of the bucket used for grouping sentences */
/* size of the bucket used for grouping sentences */
size_t
bucketSize
;
size_t
bucketSize
;
...
@@ -79,34 +93,51 @@ public:
...
@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
/* indicates whether it is used for training */
bool
isTraining
;
bool
isTraining
;
/* the padding id */
int
padID
;
/* the unk id */
int
unkID
;
/* start symbol */
int
startID
;
/* end symbol */
int
endID
;
/* the maximum length for a source sentence */
int
maxSrcLen
;
/* the maximum length for a target sentence */
int
maxTgtLen
;
public
:
public
:
/* sort the input by length (in descending order) */
/* get the maximum source sentence length in a range */
void
SortByLength
();
static
int
MaxSrcLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort buckets by key (in descending order) */
/* get the maximum target sentence length in a range */
void
SortBucket
();
static
int
MaxTgtLen
(
XList
*
buf
,
int
begin
,
int
end
);
/* sort the
output by key
(in descending order) */
/* sort the
input by source sentence length
(in descending order) */
void
Sort
InBucket
(
int
begin
,
int
end
);
void
Sort
BySrcLength
(
XList
*
buf
);
/*
load data from a file to the buffer
*/
/*
sort the input by target sentence length (in descending order)
*/
void
LoadDataToBuffer
(
);
void
SortByTgtLength
(
XList
*
buf
);
/* generate a mini-batch */
/* sort buckets by key (in descending order) */
UInt64List
LoadBatch
(
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
void
SortBuckets
(
XList
*
buf
);
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
);
/* load the samples into the buffer (a list) */
/* load the samples into the buffer (a list) */
bool
LoadBatchToBuf
(
XList
*
buf
);
bool
LoadBatchToBuf
(
XList
*
buf
);
/* load the samples into tensors from the buffer */
/* load the samples into tensors from the buffer */
static
static
bool
LoadBatch
(
XList
*
buf
,
bool
LoadBatch
(
XList
*
buf
,
int
&
curIdx
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
label
,
int
minSentBatch
,
int
batchSize
,
int
devID
,
size_t
minSentBatch
,
size_t
batchSize
,
int
devID
,
int
&
wc
,
int
&
sc
);
int
&
wc
,
int
&
sc
);
/* release the samples in a buffer */
/* release the samples in a buffer */
...
@@ -116,14 +147,8 @@ public:
...
@@ -116,14 +147,8 @@ public:
/* initialization function */
/* initialization function */
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
void
Init
(
const
char
*
dataFile
,
int
bucketSize
,
bool
training
);
/* check if the buffer is empty */
bool
IsEmpty
();
/* reset the buffer */
void
ClearBuf
();
/* group data into buckets with similar length */
/* group data into buckets with similar length */
void
BuildBucket
();
void
BuildBucket
(
XList
*
buf
);
/* de-constructor */
/* de-constructor */
~
TrainDataSet
();
~
TrainDataSet
();
...
...
source/sample/transformer/train/Trainer.cpp
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -39,6 +39,41 @@ namespace nmt
...
@@ -39,6 +39,41 @@ namespace nmt
Trainer
::
Trainer
()
Trainer
::
Trainer
()
{
{
cfg
=
NULL
;
cfg
=
NULL
;
lrate
=
0.0
F
;
lrbias
=
0.0
F
;
sBatchSize
=
0
;
wBatchSize
=
0
;
bucketSize
=
0
;
nstep
=
0
;
nepoch
=
0
;
logInterval
=
0
;
maxCheckpoint
=
0
;
d
=
0
;
nwarmup
=
0
;
vSize
=
0
;
vSizeTgt
=
0
;
useAdam
=
false
;
adamBeta1
=
0.0
F
;
adamBeta2
=
0.0
F
;
adamDelta
=
0.0
F
;
isShuffled
=
false
;
labelSmoothingP
=
0.0
F
;
nStepCheckpoint
=
0
;
useEpochCheckpoint
=
false
;
updateStep
=
0
;
isLenSorted
=
0
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
0
;
batchLoader
.
endID
=
0
;
batchLoader
.
unkID
=
0
;
batchLoader
.
padID
=
0
;
batchLoader
.
maxSrcLen
=
0
;
batchLoader
.
maxTgtLen
=
0
;
batchLoader
.
bufferSize
=
0
;
}
}
/* de-constructor */
/* de-constructor */
...
@@ -62,13 +97,15 @@ initialization
...
@@ -62,13 +97,15 @@ initialization
void
Trainer
::
Init
(
Config
&
config
)
void
Trainer
::
Init
(
Config
&
config
)
{
{
cfg
=
&
config
;
cfg
=
&
config
;
lrate
=
config
.
lrate
;
lrate
=
config
.
lrate
;
lrbias
=
config
.
lrbias
;
lrbias
=
config
.
lrbias
;
sBatchSize
=
config
.
sBatchSize
;
sBatchSize
=
config
.
sBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
wBatchSize
=
config
.
wBatchSize
;
bucketSize
=
config
.
bucketSize
;
bucketSize
=
config
.
bucketSize
;
nepoch
=
config
.
nepoch
;
nstep
=
config
.
nstep
;
nstep
=
config
.
nstep
;
nepoch
=
config
.
nepoch
;
logInterval
=
config
.
logInterval
;
maxCheckpoint
=
config
.
maxCheckpoint
;
maxCheckpoint
=
config
.
maxCheckpoint
;
d
=
config
.
modelSize
;
d
=
config
.
modelSize
;
nwarmup
=
config
.
nwarmup
;
nwarmup
=
config
.
nwarmup
;
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
...
@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T
=
1.0
F
;
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
batchLoader
.
startID
=
config
.
startID
;
batchLoader
.
endID
=
config
.
endID
;
batchLoader
.
unkID
=
config
.
unkID
;
batchLoader
.
padID
=
config
.
padID
;
batchLoader
.
maxSrcLen
=
config
.
maxSrcLen
;
batchLoader
.
maxTgtLen
=
config
.
maxTgtLen
;
batchLoader
.
bufferSize
=
config
.
bufSize
;
}
}
/*
/*
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
int
step
=
0
;
int
step
=
0
;
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
wordCountTotal
=
0
;
int
wordCountTotal
=
0
;
int
batchCountTotal
=
0
;
int
batchCountTotal
=
0
;
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double
startT
=
GetClockSec
();
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
batchLoader
.
Init
(
fn
,
bucketSize
,
true
);
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
)
{
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount
=
0
;
wordCount
=
0
;
loss
=
0
;
loss
=
0
;
/* reset the batch loader */
while
(
step
++
<
nstep
)
batchLoader
.
ClearBuf
();
while
(
!
batchLoader
.
IsEmpty
())
{
{
XNet
net
;
XNet
net
;
net
.
Clear
();
net
.
Clear
();
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor
paddingEnc
;
XTensor
paddingEnc
;
XTensor
paddingDec
;
XTensor
paddingDec
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
if
(
curIdx
==
0
||
curIdx
==
buf
->
Size
())
{
sBatchSize
,
wBatchSize
,
devID
);
curIdx
=
0
;
batchLoader
.
LoadBatchToBuf
(
buf
);
}
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
wBatchSize
,
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
/* output probabilities */
/* output probabilities */
XTensor
output
;
XTensor
output
;
/* make the network */
/* make the network */
if
(
model
->
isLM
)
if
(
model
->
isLM
)
{
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
model
->
MakeLM
(
batchEnc
,
output
,
paddingEnc
,
true
);
else
if
(
model
->
isMT
)
}
else
if
(
model
->
isMT
)
{
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
model
->
MakeMT
(
batchEnc
,
batchDec
,
output
,
paddingEnc
,
paddingDec
,
true
);
}
else
{
else
{
ShowNTErrors
(
"Illegal model type!"
);
ShowNTErrors
(
"Illegal model type!"
);
}
}
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE
lossLocal
=
lossBatch
/
wc
;
DTYPE
lossLocal
=
lossBatch
/
wc
;
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
net
.
isGradEfficient
=
true
;
bool
debug
(
false
);
if
(
debug
)
{
LOG
(
"after forward:"
);
batchEnc
.
mem
->
ShowMemUsage
(
stderr
);
exit
(
0
);
}
if
(
doUpdate
)
{
if
(
doUpdate
)
{
/* back-propagation */
net
.
Backward
(
lossTensor
);
net
.
Backward
(
lossTensor
);
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
gradStep
+=
1
;
gradStep
+=
1
;
loss
+=
lossBatch
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
wordCountTotal
+=
wc
;
batchCountTotal
+=
ws
;
batchCountTotal
+=
sc
;
/* update the parameters */
/* update the parameters */
if
(
gradStep
==
updateStep
)
{
if
(
gradStep
==
updateStep
)
{
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
else
nSkipped
++
;
nSkipped
++
;
if
(
++
step
>=
nstep
)
{
if
(
step
%
logInterval
==
0
)
{
isEnd
=
true
;
break
;
}
if
(
step
==
10
)
{
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if
(
step
%
100
==
0
)
{
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
LOG
(
"elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e"
,
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
...
@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
}
}
if
(
isEnd
)
break
;
if
(
useEpochCheckpoint
)
if
(
useEpochCheckpoint
)
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
epoch
=
MIN
(
epoch
,
nepoch
);
epoch
=
MIN
(
epoch
,
nepoch
);
...
@@ -287,8 +340,12 @@ test the model
...
@@ -287,8 +340,12 @@ test the model
*/
*/
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
void
Trainer
::
Validate
(
const
char
*
fn
,
const
char
*
ofn
,
Model
*
model
)
{
{
double
startT
=
GetClockSec
();
DISABLE_GRAD
;
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
sc
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
sentCount
=
0
;
int
sentCount
=
0
;
float
loss
=
0
;
float
loss
=
0
;
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
/* data files */
batchLoader
.
Init
(
fn
,
0
,
false
);
batchLoader
.
Init
(
fn
,
0
,
false
);
double
startT
=
GetClockSec
();
int
curIdx
=
0
;
XList
*
buf
=
new
XList
;
/* set the buffer size to the size of validation set */
batchLoader
.
bufferSize
=
batchLoader
.
totalSampleNum
;
batchLoader
.
LoadBatchToBuf
(
buf
);
while
(
!
batchLoader
.
IsEmpty
()
)
while
(
curIdx
<
buf
->
count
)
{
{
/* batch of input sequences */
/* batch of input sequences */
XTensor
batchEnc
;
XTensor
batchEnc
;
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor
labelOnehot
;
XTensor
labelOnehot
;
XTensor
lossTensor
;
XTensor
lossTensor
;
UInt64List
info
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
batchLoader
.
LoadBatch
(
buf
,
curIdx
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
label
,
sBatchSize
,
0
,
model
->
devID
);
sBatchSize
,
0
,
model
->
devID
,
wc
,
sc
);
wc
=
(
int
)
info
[
0
];
ws
=
(
int
)
info
[
1
];
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"Wrong tensor order of the sequence batch"
);
/* make the network */
/* make the network */
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
...
@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int
length
=
output
.
GetDim
(
1
);
int
length
=
output
.
GetDim
(
1
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
0
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
lossTensor
=
CrossEntropy
(
output
,
labelOnehot
,
paddingDec
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
float
lossBatch
=
ReduceSumAllValue
(
lossTensor
);
loss
+=
lossBatch
;
loss
+=
lossBatch
;
wordCount
+=
wc
;
wordCount
+=
wc
;
sentCount
+=
bSize
;
sentCount
+=
bSize
;
if
(
model
->
encoder
->
useHistory
)
model
->
encoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
if
(
model
->
decoder
->
useHistory
)
model
->
decoder
->
history
->
ClearHistory
(
/*reset=*/
false
);
}
}
batchLoader
.
ClearSamples
(
buf
);
delete
buf
;
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
LOG
(
"test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
ENABLE_GRAD
;
LOG
(
"validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)"
,
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
elapsed
,
sentCount
,
wordCount
,
loss
/
wordCount
/
log
(
2.0
),
exp
(
loss
/
wordCount
));
}
}
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
...
@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
_ScaleAndShiftMe
(
v
,
(
1.0
F
-
adamBeta2
),
0
);
/* v2 = m / (sqrt(v) + delta) */
/* v2 = m / (sqrt(v) + delta) */
XTensor
*
v2
=
NewTensorBuf
(
v
,
v
->
devID
);
XTensor
*
v2
=
NewTensorBuf
V2
(
v
,
v
->
devID
,
v
->
mem
);
_Power
(
v
,
v2
,
0.5
F
);
_Power
(
v
,
v2
,
0.5
F
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_ScaleAndShiftMe
(
v2
,
1.0
F
,
d
);
_Div
(
m
,
v2
,
v2
);
_Div
(
m
,
v2
,
v2
);
...
...
source/sample/transformer/train/Trainer.h
查看文件 @
98a9130d
/* NiuTrans.
NMT - an open-source neural machine translation system.
/* NiuTrans.
Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -70,6 +70,9 @@ public:
...
@@ -70,6 +70,9 @@ public:
/* training step number */
/* training step number */
int
nstep
;
int
nstep
;
/* interval step for logging */
int
logInterval
;
/* the maximum number of saved checkpoints */
/* the maximum number of saved checkpoints */
int
maxCheckpoint
;
int
maxCheckpoint
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论