Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
59d8b4bf
Commit
59d8b4bf
authored
6 years ago
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fix
parent
78bdfb45
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
46 行增加
和
12 行删除
+46
-12
source/sample/transformer/T2TAttention.cpp
+3
-3
source/sample/transformer/T2TEmbedding.cpp
+2
-1
source/sample/transformer/T2TModel.cpp
+3
-3
source/sample/transformer/T2TTrainer.cpp
+0
-0
source/sample/transformer/T2TTrainer.h
+25
-1
source/sample/transformer/Transformer.cpp
+12
-4
source/tensor/XGlobal.h
+1
-0
没有找到文件。
source/sample/transformer/T2TAttention.cpp
查看文件 @
59d8b4bf
...
...
@@ -121,10 +121,10 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
dot
=
BMMul
(
qheads
,
X_NOTRANS
,
kheads
,
X_TRANS
);
if
(
isMasked
)
dot
=
dot
+
mask
;
scalar
=
Softmax
(
Linear
(
dot
,
1
/
(
float
)
sqrt
((
float
)
dk
)),
-
1
);
scalar
=
Softmax
(
Linear
(
dot
,
1
.0
F
/
(
float
)
sqrt
((
float
)
dk
)),
-
1
);
if
(
ignored
>
0
)
_SetDataDim
(
&
scalar
,
0
,
ignored
,
scalar
.
order
-
2
,
1e-9
F
);
//
if(ignored > 0)
//
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att
=
BMMul
(
scalar
,
vheads
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TEmbedding.cpp
查看文件 @
59d8b4bf
...
...
@@ -123,7 +123,8 @@ XTensor T2TEmbedder::Make(XTensor &input)
}
/* we make positional embeddings first */
if
(
!
match
){
//if(!match){
if
(
true
){
InitTensor
(
&
posEmbedding
,
input
.
order
,
dims
,
X_FLOAT
,
1.0
F
,
devID
,
mem
);
XTensor
*
posTMP
=
NewTensorBuf
(
2
,
dims
+
1
,
X_FLOAT
,
1.0
F
,
devID
,
mem
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TModel.cpp
查看文件 @
59d8b4bf
...
...
@@ -55,7 +55,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamInt
(
argc
,
argv
,
"dev"
,
&
devID
,
-
1
);
LoadParamBool
(
argc
,
argv
,
"mem"
,
&
useMem
,
useMem
);
LoadParamInt
(
argc
,
argv
,
"memsize"
,
&
memSize
,
256
);
LoadParamInt
(
argc
,
argv
,
"memsize"
,
&
memSize
,
1024
);
LoadParamBool
(
argc
,
argv
,
"lm"
,
&
isLM
,
true
);
LoadParamBool
(
argc
,
argv
,
"mt"
,
&
isMT
,
false
);
LoadParamInt
(
argc
,
argv
,
"nhead"
,
&
nhead
,
8
);
...
...
@@ -66,7 +66,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
mem
->
SetDesiredSize
(
devID
,
0
,
(
MTYPE
)
memSize
*
MILLION
);
}
encoder
.
InitModel
(
argc
,
argv
,
isLM
,
isLM
?
1
:
0
,
devID
,
mem
);
encoder
.
InitModel
(
argc
,
argv
,
isLM
,
0
,
devID
,
mem
);
outputLayer
.
InitModel
(
argc
,
argv
,
devID
,
mem
);
}
...
...
@@ -104,7 +104,7 @@ void T2TModel::Make(XTensor &input, XTensor &output)
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri
(
&
mask
,
1e9
F
,
-
1
);
_SetDataLowTri
(
&
mask
,
1e9
F
,
0
);
_ScaleAndShiftMe
(
&
mask
,
1.0
F
,
-
1e9
F
);
encoding
=
MakeEncoding
(
input
,
mask
,
true
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTrainer.cpp
查看文件 @
59d8b4bf
差异被折叠。
点击展开。
source/sample/transformer/T2TTrainer.h
查看文件 @
59d8b4bf
...
...
@@ -85,6 +85,22 @@ public:
/* traing step number */
int
nstep
;
/* indicates whether we use adam */
bool
useAdam
;
/* hyper parameters of adam*/
float
adamBeta1
;
float
adamBeta2
;
float
adamDelta
;
float
adamBeta1T
;
float
adamBeta2T
;
/* list of the moment of the parameter matrics */
XList
moments
;
/* list of the 2nd order moment of the parameter matrics */
XList
moments2nd
;
public
:
/* constructor */
T2TTrainer
();
...
...
@@ -98,11 +114,19 @@ public:
/* train the model */
void
Train
(
const
char
*
fn
,
T2TModel
*
model
);
/* test the model */
void
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
);
/* load data to buffer */
int
LoadBuf
(
FILE
*
file
);
/* clear data buffer */
void
ClearBuf
();
/* load a batch of sequences */
int
LoadBatch
(
FILE
*
file
,
XTensor
*
batch
,
XTensor
*
padding
,
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batch
,
XTensor
*
padding
,
XTensor
*
output
,
int
*
seqs
,
int
step
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/Transformer.cpp
查看文件 @
59d8b4bf
...
...
@@ -40,20 +40,23 @@ int TransformerMain(int argc, const char ** argv)
char
*
trainFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
modelFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
testFN
=
new
char
[
MAX_LINE_LENGTH
];
char
*
outputFN
=
new
char
[
MAX_LINE_LENGTH
];
LoadParamString
(
argc
,
argv
,
"train"
,
trainFN
,
""
);
LoadParamString
(
argc
,
argv
,
"model"
,
modelFN
,
""
);
LoadParamString
(
argc
,
argv
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
argv
,
"output"
,
outputFN
,
""
);
T2TTrainer
trainer
;
trainer
.
Init
(
argc
,
argv
);
T2TModel
model
;
model
.
InitModel
(
argc
,
argv
);
/* learn model parameters */
if
(
strcmp
(
trainFN
,
""
)){
T2TTrainer
trainer
;
trainer
.
Init
(
argc
,
argv
);
if
(
strcmp
(
trainFN
,
""
))
trainer
.
Train
(
trainFN
,
&
model
);
}
/* save the final model */
if
(
strcmp
(
modelFN
,
""
)
&&
strcmp
(
trainFN
,
""
))
...
...
@@ -63,9 +66,14 @@ int TransformerMain(int argc, const char ** argv)
if
(
strcmp
(
modelFN
,
""
))
model
.
Read
(
modelFN
);
/* test the model on the new data */
if
(
strcmp
(
testFN
,
""
)
&&
strcmp
(
outputFN
,
""
))
trainer
.
Test
(
testFN
,
outputFN
,
&
model
);
delete
[]
trainFN
;
delete
[]
modelFN
;
delete
[]
testFN
;
delete
[]
outputFN
;
fclose
(
tmpFILE
);
...
...
This diff is collapsed.
Click to expand it.
source/tensor/XGlobal.h
查看文件 @
59d8b4bf
...
...
@@ -147,6 +147,7 @@ extern bool useCUDA;
#define XPRINT4(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4);FFLUSH(FILEH);}}
#define XPRINT5(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5);FFLUSH(FILEH);}}
#define XPRINT6(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6);FFLUSH(FILEH);}}
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论