Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
8b0e06ab
Commit
8b0e06ab
authored
Sep 19, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve the t2t implementation
parent
6fb9ad1b
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
173 行增加
和
73 行删除
+173
-73
source/sample/transformer/T2TAttention.cpp
+1
-1
source/sample/transformer/T2TAttention.h
+1
-1
source/sample/transformer/T2TDecoder.h
+1
-1
source/sample/transformer/T2TEmbedding.cpp
+1
-1
source/sample/transformer/T2TEmbedding.h
+1
-1
source/sample/transformer/T2TEncoder.cpp
+1
-1
source/sample/transformer/T2TEncoder.h
+1
-1
source/sample/transformer/T2TFNN.cpp
+1
-1
source/sample/transformer/T2TFNN.h
+1
-1
source/sample/transformer/T2TLayerNormal.cpp
+1
-1
source/sample/transformer/T2TLayerNormal.h
+1
-1
source/sample/transformer/T2TModel.cpp
+3
-3
source/sample/transformer/T2TModel.h
+1
-1
source/sample/transformer/T2TOutput.cpp
+1
-1
source/sample/transformer/T2TOutput.h
+1
-1
source/sample/transformer/T2TTrainer.cpp
+101
-35
source/sample/transformer/T2TTrainer.h
+23
-2
source/sample/transformer/T2TUtility.cpp
+5
-5
source/sample/transformer/T2TUtility.h
+5
-5
source/sample/transformer/Transformer.cpp
+22
-9
没有找到文件。
source/sample/transformer/T2TAttention.cpp
查看文件 @
8b0e06ab
...
...
@@ -53,7 +53,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TAttention
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
void
T2TAttention
::
InitModel
(
int
argc
,
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
,
XMem
*
myMem
)
{
...
...
source/sample/transformer/T2TAttention.h
查看文件 @
8b0e06ab
...
...
@@ -84,7 +84,7 @@ public:
~
T2TAttention
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
void
InitModel
(
int
argc
,
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
...
...
source/sample/transformer/T2TDecoder.h
查看文件 @
8b0e06ab
...
...
@@ -34,7 +34,7 @@ class AttDecoder : T2TDecoder
{
public
:
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
);
void
InitModel
(
int
argc
,
char
**
argv
);
};
}
...
...
source/sample/transformer/T2TEmbedding.cpp
查看文件 @
8b0e06ab
...
...
@@ -48,7 +48,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TEmbedder
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
,
XMem
*
myMem
)
void
T2TEmbedder
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
...
...
source/sample/transformer/T2TEmbedding.h
查看文件 @
8b0e06ab
...
...
@@ -71,7 +71,7 @@ public:
~
T2TEmbedder
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make positional embeddings */
void
MakePosEmbedding
(
int
eSize
,
int
d
,
int
length
);
...
...
source/sample/transformer/T2TEncoder.cpp
查看文件 @
8b0e06ab
...
...
@@ -51,7 +51,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
AttEncoder
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
void
AttEncoder
::
InitModel
(
int
argc
,
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
,
XMem
*
myMem
)
{
...
...
source/sample/transformer/T2TEncoder.h
查看文件 @
8b0e06ab
...
...
@@ -113,7 +113,7 @@ public:
~
AttEncoder
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
void
InitModel
(
int
argc
,
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
...
...
source/sample/transformer/T2TFNN.cpp
查看文件 @
8b0e06ab
...
...
@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TFNN
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
,
XMem
*
myMem
)
void
T2TFNN
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
...
...
source/sample/transformer/T2TFNN.h
查看文件 @
8b0e06ab
...
...
@@ -69,7 +69,7 @@ public:
~
T2TFNN
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the network */
XTensor
Make
(
XTensor
&
input
);
...
...
source/sample/transformer/T2TLayerNormal.cpp
查看文件 @
8b0e06ab
...
...
@@ -47,7 +47,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TLN
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
,
XMem
*
myMem
)
void
T2TLN
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
...
...
source/sample/transformer/T2TLayerNormal.h
查看文件 @
8b0e06ab
...
...
@@ -54,7 +54,7 @@ public:
~
T2TLN
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the network */
XTensor
Make
(
XTensor
&
input
);
...
...
source/sample/transformer/T2TModel.cpp
查看文件 @
8b0e06ab
...
...
@@ -48,7 +48,7 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void
T2TModel
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
)
void
T2TModel
::
InitModel
(
int
argc
,
char
**
argv
)
{
bool
useMem
=
false
;
int
memSize
=
0
;
...
...
@@ -64,7 +64,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
if
(
useMem
){
delete
mem
;
mem
=
new
XMem
(
devID
,
isMemFreeOTF
?
FREE_ON_THE_FLY
:
UNI_FREE
,
(
MTYPE
)
MILLION
*
256
,
1024
,
MILLION
*
128
);
mem
=
new
XMem
(
devID
,
FREE_ON_THE_FLY
,
(
MTYPE
)
MILLION
*
256
,
1024
,
MILLION
*
128
);
mem
->
SetDesiredSize
(
devID
,
0
,
(
MTYPE
)
memSize
*
MILLION
);
}
...
...
@@ -144,7 +144,7 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
//_Sum(&mask, padding3, &mask);
encoding
=
MakeEncoding
(
input
,
mask
,
tru
e
,
isTraining
);
encoding
=
MakeEncoding
(
input
,
mask
,
fals
e
,
isTraining
);
outputLayer
.
Make
(
encoding
,
output
);
delete
[]
dims
;
...
...
source/sample/transformer/T2TModel.h
查看文件 @
8b0e06ab
...
...
@@ -66,7 +66,7 @@ public:
~
T2TModel
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
);
void
InitModel
(
int
argc
,
char
**
argv
);
/* make the encoding network */
XTensor
MakeEncoding
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
,
bool
isTraining
);
...
...
source/sample/transformer/T2TOutput.cpp
查看文件 @
8b0e06ab
...
...
@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TOutput
::
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
,
XMem
*
myMem
)
void
T2TOutput
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
...
...
source/sample/transformer/T2TOutput.h
查看文件 @
8b0e06ab
...
...
@@ -59,7 +59,7 @@ public:
~
T2TOutput
();
/* initialize the model */
void
InitModel
(
int
argc
,
c
onst
c
har
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the network */
XTensor
Make
(
XTensor
&
input
);
...
...
source/sample/transformer/T2TTrainer.cpp
查看文件 @
8b0e06ab
...
...
@@ -26,6 +26,11 @@
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
#ifndef WIN32
#include <sys/time.h>
#include <unistd.h>
#endif
namespace
transformer
{
...
...
@@ -33,8 +38,16 @@ namespace transformer
T2TTrainer
::
T2TTrainer
()
{
seqLen
=
NULL
;
seqLen2
=
NULL
;
nseqBuf
=
0
;
nextSeq
=
-
1
;
argNum
=
0
;
argArray
=
NULL
;
buf
=
NULL
;
buf2
=
NULL
;
bufSize
=
0
;
seqOffset
=
NULL
;
}
/* de-constructor */
...
...
@@ -55,6 +68,11 @@ T2TTrainer::~T2TTrainer()
XTensor
*
m
=
(
XTensor
*
)
moments2nd
.
Get
(
i
);
delete
m
;
}
for
(
int
i
=
0
;
i
<
argNum
;
i
++
)
delete
[]
argArray
[
i
];
delete
[]
argArray
;
}
/*
...
...
@@ -62,8 +80,15 @@ initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void
T2TTrainer
::
Init
(
int
argc
,
c
onst
c
har
**
argv
)
void
T2TTrainer
::
Init
(
int
argc
,
char
**
argv
)
{
argNum
=
argc
;
argArray
=
new
char
*
[
argc
];
for
(
int
i
=
0
;
i
<
argNum
;
i
++
){
argArray
[
i
]
=
new
char
[
strlen
(
argv
[
i
])
+
1
];
strcpy
(
argArray
[
i
],
argv
[
i
]);
}
bool
useMem
=
false
;
LoadParamBool
(
argc
,
argv
,
"mem"
,
&
useMem
,
useMem
);
...
...
@@ -82,6 +107,9 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamFloat
(
argc
,
argv
,
"adambeta1"
,
&
adamBeta1
,
0.9
F
);
LoadParamFloat
(
argc
,
argv
,
"adambeta2"
,
&
adamBeta2
,
0.999
F
);
LoadParamFloat
(
argc
,
argv
,
"adamdelta"
,
&
adamDelta
,
1e-8
F
);
LoadParamBool
(
argc
,
argv
,
"shuffled"
,
&
isShuffled
,
false
);
LoadParamInt
(
argc
,
argv
,
"nstepcheckpoint"
,
&
nStepCheckpoint
,
-
1
);
LoadParamBool
(
argc
,
argv
,
"epochcheckpoint"
,
&
useEpochCheckpoint
,
false
);
buf
=
new
int
[
bufSize
];
buf2
=
new
int
[
bufSize
];
...
...
@@ -91,7 +119,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
adamBeta1T
=
1.0
F
;
adamBeta2T
=
1.0
F
;
}
int
tc
=
0
;
...
...
@@ -99,9 +126,11 @@ int tc = 0;
/*
train the model
>> fn - training data file
>> validFN - validation data file
>> modelFN - where we keep the model
>> model - model to train
*/
void
T2TTrainer
::
Train
(
const
char
*
fn
,
T2TModel
*
model
)
void
T2TTrainer
::
Train
(
const
char
*
fn
,
const
char
*
validFN
,
const
char
*
modelFN
,
T2TModel
*
model
)
{
int
epoch
=
0
;
int
step
=
0
;
...
...
@@ -111,32 +140,36 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
bool
isEnd
=
false
;
float
loss
=
0
;
float
lr
=
0
;
int
nStepCheck
=
0
;
int
nCheckpoint
=
0
;
char
*
trainFN
=
new
char
[(
int
)
strlen
(
fn
)
+
10
];
strcpy
(
trainFN
,
fn
);
#ifndef WIN32
if
(
isShuffled
)
sprintf
(
trainFN
,
"%s.random"
,
fn
);
#endif
PrepareModel
(
model
);
int
devID
=
model
->
devID
;
XMem
*
mem
=
model
->
mem
;
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
SetPin
();
XNet
net
;
tf
=
fopen
(
"tmp.xx.txt"
,
"wb"
);
tc
=
0
;
double
startT
=
GetClockSec
();
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
){
#ifndef WIN32
if
(
isShuffled
)
Shuffle
(
fn
,
trainFN
);
#endif
FILE
*
file
=
fopen
(
fn
,
"rb"
);
FILE
*
file
=
fopen
(
trainFN
,
"rb"
);
CheckNTErrors
(
file
,
"cannot open training file!"
);
wordCount
=
0
;
loss
=
0
;
if
(
mem
!=
NULL
)
mem
->
BackToPin
();
/* batch of input sequences */
XTensor
batch
;
...
...
@@ -186,22 +219,23 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr
,
elapsed
,
step
,
epoch
,
wordCountTotal
,
exp
(
loss
/
wordCount
),
exp
(
-
prob
/
wc
));
}
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
BackToPin
();
if
(
nStepCheckpoint
>
0
&&
++
nStepCheck
>=
nStepCheckpoint
){
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"step"
,
step
);
nStepCheck
=
0
;
nCheckpoint
++
;
}
}
fclose
(
file
);
if
(
isEnd
)
break
;
if
(
useEpochCheckpoint
)
MakeCheckpoint
(
model
,
validFN
,
modelFN
,
"epoch"
,
epoch
);
}
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
BackToPin
();
double
elapsed
=
GetClockSec
()
-
startT
;
fclose
(
tf
);
epoch
=
MIN
(
epoch
,
nepoch
);
...
...
@@ -209,6 +243,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr
,
elapsed
,
step
,
epoch
,
wordCountTotal
,
exp
(
loss
/
wordCount
));
XPRINT3
(
0
,
stderr
,
"[INFO] training finished (took %.1fs, step=%d and epoch=%d)
\n
"
,
elapsed
,
step
,
epoch
);
delete
[]
trainFN
;
}
/*
...
...
@@ -234,16 +270,10 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XMem
*
mem
=
model
->
mem
;
XNet
net
;
tf
=
fopen
(
"tmp.xx.txt"
,
"wb"
);
tc
=
0
;
double
startT
=
GetClockSec
();
wordCount
=
0
;
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
BackToPin
();
/* batch of input sequences */
XTensor
batch
;
...
...
@@ -306,13 +336,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
loss
+=
-
prob
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
BackToPin
();
}
if
(
mem
!=
NULL
&&
mem
->
mode
==
UNI_FREE
)
mem
->
BackToPin
();
fclose
(
file
);
fclose
(
ofile
);
...
...
@@ -320,13 +344,37 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
delete
[]
seqs
;
double
elapsed
=
GetClockSec
()
-
startT
;
fclose
(
tf
);
XPRINT3
(
0
,
stderr
,
"[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)
\n
"
,
elapsed
,
wordCountTotal
,
exp
(
loss
/
wordCount
));
}
/*
make a checkpoint
>> model - the model
>> validFN - validation data file
>> modelFN - model data file
>> label - label of the model
>> id - id of the checkpoint
*/
void
T2TTrainer
::
MakeCheckpoint
(
T2TModel
*
model
,
const
char
*
validFN
,
const
char
*
modelFN
,
const
char
*
label
,
int
id
)
{
char
*
fn
=
new
char
[
MAX_LINE_LENGTH
];
char
*
fn2
=
new
char
[
MAX_LINE_LENGTH
];
sprintf
(
fn
,
"%s.%s.%3d"
,
modelFN
,
label
,
id
);
sprintf
(
fn2
,
"%s.%s.%3d.output"
,
modelFN
,
label
,
id
);
//model->Dump(fn);
if
(
validFN
!=
NULL
){
T2TTrainer
trainer
;
trainer
.
Init
(
argNum
,
argArray
);
trainer
.
Test
(
validFN
,
fn2
,
model
);
}
delete
[]
fn
;
delete
[]
fn2
;
}
char
line
[
MAX_SEQUENCE_LENGTH
];
struct
SampleNode
...
...
@@ -583,6 +631,24 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
return
sc
;
}
/*
shuffle lines of the file
>> srcFile - the source file to shuffle
>> tgtFile - the resulting file
*/
void
T2TTrainer
::
Shuffle
(
const
char
*
srcFile
,
const
char
*
tgtFile
)
{
char
*
line
=
new
char
[
MAX_LINE_LENGTH
];
#ifndef WIN32
sprintf
(
line
,
"shuf %s > %s"
,
srcFile
,
tgtFile
);
system
(
line
);
#else
ShowNTErrors
(
"Cannot shuffle the file on WINDOWS systems!"
);
#endif
delete
[]
line
;
}
/*
get word probabilities for a batch of sequences
...
...
source/sample/transformer/T2TTrainer.h
查看文件 @
8b0e06ab
...
...
@@ -37,6 +37,12 @@ namespace transformer
class
T2TTrainer
{
public
:
/* paramter number */
int
argNum
;
/* parameter array */
char
**
argArray
;
/* buffer for loading words */
int
*
buf
;
...
...
@@ -107,6 +113,15 @@ public:
/* list of the 2nd order moment of the parameter matrics */
XList
moments2nd
;
/* indicates whether the data file is shuffled for training */
bool
isShuffled
;
/* number of steps after which we make a checkpoint */
int
nStepCheckpoint
;
/* indicates whether we make a checkpoint after each traing epoch */
bool
useEpochCheckpoint
;
public
:
/* constructor */
T2TTrainer
();
...
...
@@ -115,14 +130,17 @@ public:
~
T2TTrainer
();
/* initialize the trainer */
void
Init
(
int
argc
,
c
onst
c
har
**
argv
);
void
Init
(
int
argc
,
char
**
argv
);
/* train the model */
void
Train
(
const
char
*
fn
,
T2TModel
*
model
);
void
Train
(
const
char
*
fn
,
const
char
*
validFN
,
const
char
*
modelFN
,
T2TModel
*
model
);
/* test the model */
void
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
);
/* make a checkpoint */
void
MakeCheckpoint
(
T2TModel
*
model
,
const
char
*
validFN
,
const
char
*
modelFN
,
const
char
*
label
,
int
id
);
/* load data to buffer */
int
LoadBuf
(
FILE
*
file
,
bool
isSorted
,
int
step
);
...
...
@@ -136,6 +154,9 @@ public:
int
step
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
);
/* shuffle the data file */
void
Shuffle
(
const
char
*
srcFile
,
const
char
*
tgtFile
);
/* get word probabilities for a batch of sequences */
float
GetProb
(
XTensor
*
output
,
XTensor
*
gold
,
XTensor
*
wordProbs
);
...
...
source/sample/transformer/T2TUtility.cpp
查看文件 @
8b0e06ab
...
...
@@ -30,7 +30,7 @@ FILE * tmpFILE;
int
llnum
=
0
;
FILE
*
tf
=
NULL
;
void
LoadParamString
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
char
*
p
,
const
char
*
defaultP
)
void
LoadParamString
(
int
argc
,
char
**
argv
,
const
char
*
name
,
char
*
p
,
const
char
*
defaultP
)
{
char
vname
[
128
];
vname
[
0
]
=
'-'
;
...
...
@@ -47,7 +47,7 @@ void LoadParamString(int argc, const char ** argv, const char * name, char * p,
strcpy
(
p
,
defaultP
);
}
void
LoadParamInt
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
int
*
p
,
int
defaultP
)
void
LoadParamInt
(
int
argc
,
char
**
argv
,
const
char
*
name
,
int
*
p
,
int
defaultP
)
{
char
vname
[
128
];
vname
[
0
]
=
'-'
;
...
...
@@ -64,7 +64,7 @@ void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int
*
p
=
defaultP
;
}
void
LoadParamBool
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
bool
*
p
,
bool
defaultP
)
void
LoadParamBool
(
int
argc
,
char
**
argv
,
const
char
*
name
,
bool
*
p
,
bool
defaultP
)
{
char
vname
[
128
];
vname
[
0
]
=
'-'
;
...
...
@@ -81,7 +81,7 @@ void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bo
*
p
=
defaultP
;
}
void
LoadParamFloat
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
float
*
p
,
float
defaultP
)
void
LoadParamFloat
(
int
argc
,
char
**
argv
,
const
char
*
name
,
float
*
p
,
float
defaultP
)
{
char
vname
[
128
];
vname
[
0
]
=
'-'
;
...
...
@@ -98,7 +98,7 @@ void LoadParamFloat(int argc, const char ** argv, const char * name, float * p,
*
p
=
defaultP
;
}
void
ShowParams
(
int
argc
,
c
onst
c
har
**
argv
)
void
ShowParams
(
int
argc
,
char
**
argv
)
{
fprintf
(
stderr
,
"args:
\n
"
);
for
(
int
i
=
0
;
i
<
argc
;
i
++
){
...
...
source/sample/transformer/T2TUtility.h
查看文件 @
8b0e06ab
...
...
@@ -30,13 +30,13 @@ namespace transformer
extern
FILE
*
tmpFILE
;
/* load arguments */
void
LoadParamString
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
char
*
p
,
const
char
*
defaultP
);
void
LoadParamInt
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
int
*
p
,
int
defaultP
);
void
LoadParamBool
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
bool
*
p
,
bool
defaultP
);
void
LoadParamFloat
(
int
argc
,
c
onst
c
har
**
argv
,
const
char
*
name
,
float
*
p
,
float
defaultP
);
void
LoadParamString
(
int
argc
,
char
**
argv
,
const
char
*
name
,
char
*
p
,
const
char
*
defaultP
);
void
LoadParamInt
(
int
argc
,
char
**
argv
,
const
char
*
name
,
int
*
p
,
int
defaultP
);
void
LoadParamBool
(
int
argc
,
char
**
argv
,
const
char
*
name
,
bool
*
p
,
bool
defaultP
);
void
LoadParamFloat
(
int
argc
,
char
**
argv
,
const
char
*
name
,
float
*
p
,
float
defaultP
);
/* show arguments */
void
ShowParams
(
int
argc
,
c
onst
c
har
**
argv
);
void
ShowParams
(
int
argc
,
char
**
argv
);
extern
int
llnum
;
extern
FILE
*
tf
;
...
...
source/sample/transformer/Transformer.cpp
查看文件 @
8b0e06ab
...
...
@@ -33,30 +33,36 @@ int TransformerMain(int argc, const char ** argv)
if
(
argc
==
0
)
return
1
;
char
**
args
=
new
char
*
[
argc
];
for
(
int
i
=
0
;
i
<
argc
;
i
++
){
args
[
i
]
=
new
char
[
strlen
(
argv
[
i
])
+
1
];
strcpy
(
args
[
i
],
argv
[
i
]);
}
tmpFILE
=
fopen
(
"tmp.txt"
,
"wb"
);
ShowParams
(
argc
,
arg
v
);
ShowParams
(
argc
,
arg
s
);
char
*
trainFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
modelFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
testFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
outputFN
=
new
char
[
MAX_LINE_LENGTH
];
LoadParamString
(
argc
,
arg
v
,
"train"
,
trainFN
,
""
);
LoadParamString
(
argc
,
arg
v
,
"model"
,
modelFN
,
""
);
LoadParamString
(
argc
,
arg
v
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
arg
v
,
"output"
,
outputFN
,
""
);
LoadParamString
(
argc
,
arg
s
,
"train"
,
trainFN
,
""
);
LoadParamString
(
argc
,
arg
s
,
"model"
,
modelFN
,
""
);
LoadParamString
(
argc
,
arg
s
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
arg
s
,
"output"
,
outputFN
,
""
);
T2TTrainer
trainer
;
trainer
.
Init
(
argc
,
arg
v
);
trainer
.
Init
(
argc
,
arg
s
);
T2TModel
model
;
model
.
InitModel
(
argc
,
arg
v
);
model
.
InitModel
(
argc
,
arg
s
);
/* learn model parameters */
if
(
strcmp
(
trainFN
,
""
))
trainer
.
Train
(
trainFN
,
&
model
);
trainer
.
Train
(
trainFN
,
testFN
,
strcmp
(
modelFN
,
""
)
?
modelFN
:
"checkpoint.model"
,
&
model
);
/* save the final model */
if
(
strcmp
(
modelFN
,
""
)
&&
strcmp
(
trainFN
,
""
))
...
...
@@ -66,15 +72,22 @@ int TransformerMain(int argc, const char ** argv)
if
(
strcmp
(
modelFN
,
""
))
model
.
Read
(
modelFN
);
T2TTrainer
tester
;
tester
.
Init
(
argc
,
args
);
/* test the model on the new data */
if
(
strcmp
(
testFN
,
""
)
&&
strcmp
(
outputFN
,
""
))
t
rain
er
.
Test
(
testFN
,
outputFN
,
&
model
);
t
est
er
.
Test
(
testFN
,
outputFN
,
&
model
);
delete
[]
trainFN
;
delete
[]
modelFN
;
delete
[]
testFN
;
delete
[]
outputFN
;
for
(
int
i
=
0
;
i
<
argc
;
i
++
)
delete
[]
args
[
i
];
delete
[]
args
;
fclose
(
tmpFILE
);
return
0
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论