Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
Emmay
NiuTrans.Tensor
Commits
3bf085db
Commit
3bf085db
authored
Mar 13, 2019
by
姜雨帆
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
implement MulAndShift and bug fixs
parent
411faffa
隐藏空白字符变更
内嵌
并排
正在显示
26 个修改的文件
包含
636 行增加
和
116 行删除
+636
-116
source/network/XBackwardMath.cpp
+124
-0
source/network/XBackwardMath.h
+4
-0
source/sample/transformer/T2TAttention.cpp
+33
-6
source/sample/transformer/T2TAttention.h
+4
-2
source/sample/transformer/T2TDecoder.cpp
+31
-5
source/sample/transformer/T2TDecoder.h
+48
-1
source/sample/transformer/T2TEmbedding.cpp
+8
-2
source/sample/transformer/T2TEmbedding.h
+1
-1
source/sample/transformer/T2TEncoder.cpp
+1
-1
source/sample/transformer/T2TFNN.cpp
+4
-2
source/sample/transformer/T2TModel.cpp
+13
-11
source/sample/transformer/T2TOutput.cpp
+1
-1
source/sample/transformer/T2TTrainer.cpp
+117
-56
source/sample/transformer/T2TTrainer.h
+5
-5
source/sample/transformer/Transformer.cpp
+6
-4
source/tensor/XLink.cpp
+21
-0
source/tensor/XLink.h
+4
-0
source/tensor/XName.cpp
+2
-0
source/tensor/XName.h
+2
-1
source/tensor/core/CHeader.h
+1
-0
source/tensor/core/arithmetic/MulAndShift.cpp
+139
-0
source/tensor/core/arithmetic/MulAndShift.h
+37
-0
source/tensor/core/getandset/OnehotAndIndex.cpp
+14
-9
source/tensor/core/getandset/OnehotAndIndex.cu
+13
-6
source/tensor/core/getandset/OnehotAndIndex.cuh
+1
-1
source/tensor/core/getandset/OnehotAndIndex.h
+2
-2
没有找到文件。
source/network/XBackwardMath.cpp
查看文件 @
3bf085db
...
@@ -99,6 +99,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
...
@@ -99,6 +99,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradReduceSumSquared
(
node
,
isEfficient
);
GradReduceSumSquared
(
node
,
isEfficient
);
else
if
(
operID
==
REDUCE_REDUCEVARIANCE
)
else
if
(
operID
==
REDUCE_REDUCEVARIANCE
)
GradReduceVariance
(
node
,
isEfficient
);
GradReduceVariance
(
node
,
isEfficient
);
else
if
(
operID
==
MATH_MULANDSHIFT
)
GradMulAndShift
(
node
,
isEfficient
);
else
{
else
{
ShowNTErrors
(
"TODO!"
);
ShowNTErrors
(
"TODO!"
);
}
}
...
@@ -1487,4 +1489,126 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
...
@@ -1487,4 +1489,126 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
node
->
visitMark
=
NODE_FINISHED
;
node
->
visitMark
=
NODE_FINISHED
;
}
}
/*
gradient for operation MulAndShift
for c = matmul(x, w) + b
we have
dE/dx = dE/dc * w^T
dE/dw = x^T * dE/dc
dE/db = dE/dc.reduce(0,...,n-1,n+1,...)
i.e., dE/db is dE/dc summed over every dimension except dimension n
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
                 an efficient manner. If it is true, gradients are only
                 prepared for the input tensors marked by isGrad
*/
void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum == 3, "wrong input tensor number");

    /* the three inputs of c = matmul(x, w) + b */
    XTensor * x = income.tails[0];
    XTensor * w = income.tails[1];
    XTensor * b = income.tails[2];

    /* parameters kept on the link: the dimension of c that b is
       aligned with, and the transposition flags of w and x */
    int n = income.GetParamInt(0);
    MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
    MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);

    bool needGradW = !isEfficient || w->isGrad;
    bool needGradX = !isEfficient || x->isGrad;
    bool needGradB = !isEfficient || b->isGrad;

    if (needGradW)
        XNoder::MakeGrad(w);
    if (needGradX)
        XNoder::MakeGrad(x);
    if (needGradB)
        XNoder::MakeGrad(b);

    /* back up the shape of c (and so of dE/dc) because dE/dc is
       reshaped temporarily below */
    int order = node->order;
    int dimSize[MAX_TENSOR_DIM_NUM];
    memcpy(dimSize, node->dimSize, sizeof(int) * node->order);

    /* compute dE/db. This is done only if the gradient of b has been
       requested: b->grad is prepared by MakeGrad(b) under the same
       condition, so it must not be touched otherwise. */
    if (needGradB) {
        if (n == order - 1) {
            int reshapedSize[MAX_TENSOR_DIM_NUM];
            reshapedSize[0] = node->unitNum / dimSize[order - 1];
            reshapedSize[1] = dimSize[order - 1];

            /* we reshape dE/dc to a matrix whose column number is equal to the
               size of b. Then we can reduce the matrix into a row vector. */
            node->grad->Reshape(2, reshapedSize);

            XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);

            _ReduceSum(node->grad, bGradTMP, 0);

            /* add the reduced result to what is already in dE/db */
            _Sum(bGradTMP, b->grad, b->grad);

            DelTensorBuf(bGradTMP);

            /* restore the shape of dE/dc */
            node->grad->Reshape(order, dimSize);
        }
        else {
            int reshapedSize[MAX_TENSOR_DIM_NUM];
            reshapedSize[0] = 1;
            reshapedSize[1] = dimSize[n];
            reshapedSize[2] = 1;

            /* the leading size is the product of all dimensions before n */
            for (int i = 0; i < order; i++) {
                if (i < n)
                    reshapedSize[0] *= dimSize[i];
            }

            /* the trailing size is whatever remains after dimension n */
            reshapedSize[2] = node->unitNum / (reshapedSize[0] * reshapedSize[1]);

            /* we reshape dE/dc to a 3D tensor of size (x, y, z) where y = |b|.
               Then reduce along with z and x to obtain dE/db. */
            node->grad->Reshape(3, reshapedSize);

            /* the (x, y) result of reducing along z */
            XTensor * interGrad = NewTensorBuf(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);

            _ReduceSum(node->grad, interGrad, 2);

            XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);

            _ReduceSum(interGrad, bGradTMP, 0);

            /* add the reduced result to what is already in dE/db */
            _Sum(bGradTMP, b->grad, b->grad);

            /* the buffers are released in the reverse order of their creation */
            DelTensorBuf(bGradTMP);

            /* restore the shape of dE/dc */
            node->grad->Reshape(order, dimSize);

            DelTensorBuf(interGrad);
        }
    }

    /* compute dE/dx, dE/dw */
    XTensor * c = node;
    XTensor * dedc = node->grad;
    XTensor * dedw = w->grad;
    XTensor * dedx = x->grad;

    if (x->order == 2 && w->order == 2) {
        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
    }
    else if (transX == X_NOTRANS && x->order > 2 && w->order == 2) {
        /* x has more than two dimensions: we flatten x, c and their
           gradients into matrices, run the 2D case and then restore
           the original shapes */
        int orderBackupX = x->order;
        int orderBackupC = c->order;
        int dimsBackupX[MAX_TENSOR_DIM_NUM];
        int dimsBackupC[MAX_TENSOR_DIM_NUM];
        memcpy(dimsBackupX, x->dimSize, sizeof(int) * x->order);
        memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);

        x->Reshape(x->unitNum / x->GetDim(-1), x->GetDim(-1));
        c->Reshape(c->unitNum / c->GetDim(-1), c->GetDim(-1));
        if (needGradX)
            dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
        dedc->Reshape(dedc->unitNum / dedc->GetDim(-1), dedc->GetDim(-1));

        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);

        x->Reshape(orderBackupX, dimsBackupX);
        c->Reshape(orderBackupC, dimsBackupC);
        if (needGradX)
            dedx->Reshape(orderBackupX, dimsBackupX);
        dedc->Reshape(orderBackupC, dimsBackupC);
    }
    /* NOTE(review): in any other case (e.g., w->order != 2, or x->order > 2
       with a transposed x) no gradient is computed for x and w, but the node
       is still marked as finished below. Confirm that such inputs cannot
       reach this function, or report an error here. */

    node->visitMark = NODE_FINISHED;
}
}
}
source/network/XBackwardMath.h
查看文件 @
3bf085db
...
@@ -168,6 +168,10 @@ private:
...
@@ -168,6 +168,10 @@ private:
/* gradient for reduceVariance */
/* gradient for reduceVariance */
static
static
void
GradReduceVariance
(
XTensor
*
node
,
bool
isEfficient
);
void
GradReduceVariance
(
XTensor
*
node
,
bool
isEfficient
);
/* gradient for MulAndShift, i.e., c = matmul(x, w) + b */
static
void
GradMulAndShift
(
XTensor
*
node
,
bool
isEfficient
);
};
};
}
}
...
...
source/sample/transformer/T2TAttention.cpp
查看文件 @
3bf085db
...
@@ -75,16 +75,19 @@ void T2TAttention::InitModel(int argc, char ** argv,
...
@@ -75,16 +75,19 @@ void T2TAttention::InitModel(int argc, char ** argv,
InitTensor2D
(
&
wq
,
d
,
dk
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wq
,
d
,
dk
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wv
,
d
,
dv
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wv
,
d
,
dv
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wa
,
d
,
d
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wa
,
d
,
d
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
&
wbig
,
d
,
3
*
d
,
X_FLOAT
,
devID
,
mem
);
float
scale
=
1.0
F
;
float
scale
=
1.0
F
;
float
finfoutk
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
dk
));
float
finfoutk
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
dk
));
float
finfoutv
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
dv
));
float
finfoutv
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
dv
));
float
finfouta
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
d
));
float
finfouta
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
d
));
float
finfoutbig
=
(
float
)
sqrt
(
6.0
F
*
scale
/
(
d
+
3
*
d
));
wk
.
SetDataRand
(
-
finfoutk
,
finfoutk
);
wk
.
SetDataRand
(
-
finfoutk
,
finfoutk
);
wq
.
SetDataRand
(
-
finfoutk
,
finfoutk
);
wq
.
SetDataRand
(
-
finfoutk
,
finfoutk
);
wv
.
SetDataRand
(
-
finfoutv
,
finfoutv
);
wv
.
SetDataRand
(
-
finfoutv
,
finfoutv
);
wa
.
SetDataRand
(
-
finfouta
,
finfouta
);
wa
.
SetDataRand
(
-
finfouta
,
finfouta
);
wbig
.
SetDataRand
(
-
finfoutbig
,
finfoutbig
);
}
}
/*
/*
...
@@ -98,16 +101,40 @@ make the network
...
@@ -98,16 +101,40 @@ make the network
>> isTraining - indicates whether the model is used for training
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
<< return - multi-attention result
*/
*/
XTensor
T2TAttention
::
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
&
mask
,
bool
isTraining
)
XTensor
T2TAttention
::
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
&
mask
,
bool
isTraining
,
bool
selfatt
)
{
{
XTensor
k2
;
XTensor
k2
;
XTensor
q2
;
XTensor
q2
;
XTensor
v2
;
XTensor
v2
;
/* linear transformation before self-attention */
if
(
selfatt
){
k2
=
MMul
(
k
,
wk
);
q2
=
MMul
(
q
,
wq
);
XTensor
con
;
v2
=
MMul
(
v
,
wv
);
XList
split
;
con
=
MMul
(
k
,
wbig
);
int
d1
=
con
.
GetDim
(
0
);
int
d2
=
con
.
GetDim
(
1
);
int
d3
=
con
.
GetDim
(
2
)
/
3
;
InitTensor3D
(
&
k2
,
d1
,
d2
,
d3
,
X_FLOAT
,
devID
,
mem
);
InitTensor3D
(
&
q2
,
d1
,
d2
,
d3
,
X_FLOAT
,
devID
,
mem
);
InitTensor3D
(
&
v2
,
d1
,
d2
,
d3
,
X_FLOAT
,
devID
,
mem
);
split
.
Add
(
&
q2
);
split
.
Add
(
&
k2
);
split
.
Add
(
&
v2
);
Split
(
con
,
split
,
2
,
3
);
}
else
{
/* linear transformation before self-attention */
k2
=
MMul
(
k
,
wk
);
q2
=
MMul
(
q
,
wq
);
v2
=
MMul
(
v
,
wv
);
}
XTensor
kheads
;
XTensor
kheads
;
XTensor
qheads
;
XTensor
qheads
;
...
...
source/sample/transformer/T2TAttention.h
查看文件 @
3bf085db
...
@@ -59,7 +59,9 @@ public:
...
@@ -59,7 +59,9 @@ public:
/* transformation after dot-product attention */
/* transformation after dot-product attention */
XTensor
wa
;
XTensor
wa
;
XTensor
wbig
;
/* size of transformed Q and K */
/* size of transformed Q and K */
int
dk
;
int
dk
;
...
@@ -95,7 +97,7 @@ public:
...
@@ -95,7 +97,7 @@ public:
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the network */
/* make the network */
XTensor
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
&
mask
,
bool
isTraining
);
XTensor
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
&
mask
,
bool
isTraining
,
bool
selfatt
);
};
};
}
}
...
...
source/sample/transformer/T2TDecoder.cpp
查看文件 @
3bf085db
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
#include <math.h>
#include <math.h>
#include "T2TDecoder.h"
#include "T2TDecoder.h"
#include "T2TUtility.h"
#include "T2TLayerNormal.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/core/CHeader.h"
namespace
transformer
namespace
transformer
...
@@ -53,14 +55,38 @@ void AttDecoder::InitModel(int argc, char ** argv,
...
@@ -53,14 +55,38 @@ void AttDecoder::InitModel(int argc, char ** argv,
bool
myIsMasked
,
int
myIgnored
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
,
XMem
*
myMem
)
int
myDevID
,
XMem
*
myMem
)
{
{
AttEncoder
::
InitModel
(
argc
,
argv
,
myIsMasked
,
myIgnored
,
myDevID
,
myMem
);
//
AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
devID
=
myDevID
;
mem
=
myMem
;
ignored
=
myIgnored
;
LoadParamInt
(
argc
,
argv
,
"nlayer"
,
&
nlayer
,
6
);
LoadParamInt
(
argc
,
argv
,
"hsize"
,
&
hSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"esize"
,
&
eSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"vsizetgt"
,
&
vSize
,
-
1
);
LoadParamFloat
(
argc
,
argv
,
"dropout"
,
&
dropoutP
,
0
);
CheckNTErrors
(
nlayer
>=
1
,
"We have one encoding layer at least!"
);
CheckNTErrors
(
vSize
>
1
,
"set vocabulary size by
\"
-vsize
\"
"
);
/* embedding model */
embedder
.
InitModel
(
argc
,
argv
,
devID
,
mem
,
false
);
attentions
=
new
T2TAttention
[
nlayer
];
fnns
=
new
T2TFNN
[
nlayer
];
attLayerNorms
=
new
T2TLN
[
nlayer
];
fnnLayerNorms
=
new
T2TLN
[
nlayer
];
attentionsEnde
=
new
T2TAttention
[
nlayer
];
attentionsEnde
=
new
T2TAttention
[
nlayer
];
attEndeLayerNorms
=
new
T2TLN
[
nlayer
];
attEndeLayerNorms
=
new
T2TLN
[
nlayer
];
/* initialize the stacked layers */
/* initialize the stacked layers */
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
)
{
attentionsEnde
[
i
].
InitModel
(
argc
,
argv
,
myIsMasked
,
myIgnored
,
myDevID
,
myMem
);
attentions
[
i
].
InitModel
(
argc
,
argv
,
myIsMasked
,
myIgnored
,
myDevID
,
myMem
);
fnns
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
attLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
fnnLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
attentionsEnde
[
i
].
InitModel
(
argc
,
argv
,
true
,
myIgnored
,
myDevID
,
myMem
);
attEndeLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
attEndeLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
}
}
}
}
...
@@ -93,7 +119,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -93,7 +119,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/******************/
/******************/
/* self attention */
/* self attention */
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
,
isTraining
);
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
,
isTraining
,
true
);
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
...
@@ -107,7 +133,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -107,7 +133,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/*****************************/
/*****************************/
/* encoder-decoder attention */
/* encoder-decoder attention */
ende
=
attentionsEnde
[
i
].
Make
(
outputEnc
,
x
,
outputEnc
,
maskEncDec
,
isTraining
);
ende
=
attentionsEnde
[
i
].
Make
(
outputEnc
,
x
,
outputEnc
,
maskEncDec
,
isTraining
,
false
);
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
...
...
source/sample/transformer/T2TDecoder.h
查看文件 @
3bf085db
...
@@ -27,9 +27,56 @@
...
@@ -27,9 +27,56 @@
namespace
transformer
namespace
transformer
{
{
class
AttDecoder
:
public
AttEncoder
class
AttDecoder
{
{
public
:
public
:
/* device id */
int
devID
;
/* memory pool */
XMem
*
mem
;
/* layer number */
int
nlayer
;
/* hidden layer size of the FNN layer */
int
hSize
;
/* embedding size */
int
eSize
;
/* vocabulary size */
int
vSize
;
/* dropout probability */
DTYPE
dropoutP
;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
* special design for the attention model. */
int
ignored
;
/* embedding of word at each position */
T2TEmbedder
embedder
;
/* FNN model of each layer */
T2TFNN
*
fnns
;
/* attention model of each layer */
T2TAttention
*
attentions
;
/* layer normalization for fnn */
T2TLN
*
fnnLayerNorms
;
/* layer normalization for attention */
T2TLN
*
attLayerNorms
;
/* input tensor of the encoder */
XTensor
*
input
;
/* output tensor of the encoder */
XTensor
*
output
;
/* encoder-decoder attention model of each layer */
/* encoder-decoder attention model of each layer */
T2TAttention
*
attentionsEnde
;
T2TAttention
*
attentionsEnde
;
...
...
source/sample/transformer/T2TEmbedding.cpp
查看文件 @
3bf085db
...
@@ -48,12 +48,18 @@ initialize the model
...
@@ -48,12 +48,18 @@ initialize the model
>> myDevID - device id
>> myDevID - device id
>> myMem - the memory pool
>> myMem - the memory pool
*/
*/
void
T2TEmbedder
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
)
void
T2TEmbedder
::
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
,
XMem
*
myMem
,
bool
isEnc
)
{
{
devID
=
myDevID
;
devID
=
myDevID
;
mem
=
myMem
;
mem
=
myMem
;
LoadParamInt
(
argc
,
argv
,
"vsize"
,
&
vSize
,
-
1
);
if
(
isEnc
){
LoadParamInt
(
argc
,
argv
,
"vsize"
,
&
vSize
,
-
1
);
}
else
{
LoadParamInt
(
argc
,
argv
,
"vsizetgt"
,
&
vSize
,
-
1
);
}
//LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt
(
argc
,
argv
,
"maxlen"
,
&
maxLength
,
512
);
LoadParamInt
(
argc
,
argv
,
"maxlen"
,
&
maxLength
,
512
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
eSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
eSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
d
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
d
,
DEFAULT_EMBEDDING_SIZE
);
...
...
source/sample/transformer/T2TEmbedding.h
查看文件 @
3bf085db
...
@@ -71,7 +71,7 @@ public:
...
@@ -71,7 +71,7 @@ public:
~
T2TEmbedder
();
~
T2TEmbedder
();
/* initialize the model */
/* initialize the model */
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
char
**
argv
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
,
bool
isEnc
=
true
);
/* make positional embeddings */
/* make positional embeddings */
void
MakePosEmbedding
(
int
eSize
,
int
d
,
int
length
);
void
MakePosEmbedding
(
int
eSize
,
int
d
,
int
length
);
...
...
source/sample/transformer/T2TEncoder.cpp
查看文件 @
3bf085db
...
@@ -114,7 +114,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
...
@@ -114,7 +114,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
XTensor
res
;
XTensor
res
;
/* self attention */
/* self attention */
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
,
isTraining
);
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
,
isTraining
,
true
);
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
...
...
source/sample/transformer/T2TFNN.cpp
查看文件 @
3bf085db
...
@@ -89,13 +89,15 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining)
...
@@ -89,13 +89,15 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining)
XTensor
t1
;
XTensor
t1
;
/* t1 = max(0, x * w1 + b1) */
/* t1 = max(0, x * w1 + b1) */
t1
=
Rectify
(
MMul
(
input
,
w1
)
+
b1
);
//t1 = Rectify(MMul(input, w1) + b1);
t1
=
Rectify
(
MulAndShift
(
input
,
w1
,
b1
));
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
t1
=
Dropout
(
t1
,
dropoutP
);
t1
=
Dropout
(
t1
,
dropoutP
);
/* result = t1 * w2 + b2 */
/* result = t1 * w2 + b2 */
return
MMul
(
t1
,
w2
)
+
b2
;
//return MMul(t1, w2) + b2;
return
MulAndShift
(
t1
,
w2
,
b2
);
}
}
...
...
source/sample/transformer/T2TModel.cpp
查看文件 @
3bf085db
...
@@ -219,7 +219,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
...
@@ -219,7 +219,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
dims
[
i
+
1
]
=
inputDec
.
GetDim
(
i
);
dims
[
i
+
1
]
=
inputDec
.
GetDim
(
i
);
dims
[
0
]
=
nhead
;
dims
[
0
]
=
nhead
;
dims
[
inputDec
.
order
+
1
]
=
len
;
dims
[
inputDec
.
order
+
1
]
=
len
;
InitTensor
(
&
maskDec
,
inputDec
.
order
+
2
,
dims
,
X_FLOAT
,
1.0
F
,
padding
Enc
.
devID
,
paddingEn
c
.
mem
);
InitTensor
(
&
maskDec
,
inputDec
.
order
+
2
,
dims
,
X_FLOAT
,
1.0
F
,
padding
Dec
.
devID
,
paddingDe
c
.
mem
);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
this matrix can be used to prevent the attention to current or following words in
...
@@ -236,10 +236,10 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
...
@@ -236,10 +236,10 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor
*
maskEncDecTMPDec
=
NewTensorBuf
(
maskEncDecTMPEnc
,
paddingEnc
.
devID
,
paddingEnc
.
mem
);
XTensor
*
maskEncDecTMPDec
=
NewTensorBuf
(
maskEncDecTMPEnc
,
paddingEnc
.
devID
,
paddingEnc
.
mem
);
_Unsqueeze
(
&
paddingEnc
,
maskEncDecTMPEnc
,
paddingEnc
.
order
-
1
,
paddingDec
.
GetDim
(
-
1
));
_Unsqueeze
(
&
paddingEnc
,
maskEncDecTMPEnc
,
paddingEnc
.
order
-
1
,
paddingDec
.
GetDim
(
-
1
));
_Unsqueeze
(
&
paddingDec
,
maskEncDecTMPDec
,
paddingEnc
.
order
,
paddingEnc
.
GetDim
(
-
1
));
//
_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
_Multiply
(
maskEncDecTMPDec
,
maskEncDecTMPEnc
,
maskEncDecTMPDec
);
//
_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
_ScaleAndShiftMe
(
maskEncDecTMP
De
c
,
1e9
F
,
-
1e9
F
);
_ScaleAndShiftMe
(
maskEncDecTMP
En
c
,
1e9
F
,
-
1e9
F
);
_Unsqueeze
(
maskEncDecTMP
De
c
,
&
maskEncDec
,
0
,
dims
[
0
]);
_Unsqueeze
(
maskEncDecTMP
En
c
,
&
maskEncDec
,
0
,
dims
[
0
]);
DelTensorBuf
(
maskEncDecTMPDec
);
DelTensorBuf
(
maskEncDecTMPDec
);
DelTensorBuf
(
maskEncDecTMPEnc
);
DelTensorBuf
(
maskEncDecTMPEnc
);
...
@@ -300,9 +300,10 @@ void T2TModel::GetParams(XList &list)
...
@@ -300,9 +300,10 @@ void T2TModel::GetParams(XList &list)
list
.
Add
(
&
encoder
->
fnns
[
i
].
b1
);
list
.
Add
(
&
encoder
->
fnns
[
i
].
b1
);
list
.
Add
(
&
encoder
->
fnns
[
i
].
w2
);
list
.
Add
(
&
encoder
->
fnns
[
i
].
w2
);
list
.
Add
(
&
encoder
->
fnns
[
i
].
b2
);
list
.
Add
(
&
encoder
->
fnns
[
i
].
b2
);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wk
);
//list.Add(&encoder->attentions[i].wk);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wq
);
//list.Add(&encoder->attentions[i].wq);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wv
);
//list.Add(&encoder->attentions[i].wv);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wbig
);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wa
);
list
.
Add
(
&
encoder
->
attentions
[
i
].
wa
);
list
.
Add
(
&
encoder
->
fnnLayerNorms
[
i
].
w
);
list
.
Add
(
&
encoder
->
fnnLayerNorms
[
i
].
w
);
list
.
Add
(
&
encoder
->
fnnLayerNorms
[
i
].
b
);
list
.
Add
(
&
encoder
->
fnnLayerNorms
[
i
].
b
);
...
@@ -324,9 +325,10 @@ void T2TModel::GetParams(XList &list)
...
@@ -324,9 +325,10 @@ void T2TModel::GetParams(XList &list)
list
.
Add
(
&
decoder
->
attentionsEnde
[
i
].
wa
);
list
.
Add
(
&
decoder
->
attentionsEnde
[
i
].
wa
);
list
.
Add
(
&
decoder
->
attEndeLayerNorms
[
i
].
w
);
list
.
Add
(
&
decoder
->
attEndeLayerNorms
[
i
].
w
);
list
.
Add
(
&
decoder
->
attEndeLayerNorms
[
i
].
b
);
list
.
Add
(
&
decoder
->
attEndeLayerNorms
[
i
].
b
);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wk
);
//list.Add(&decoder->attentions[i].wk);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wq
);
//list.Add(&decoder->attentions[i].wq);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wv
);
//list.Add(&decoder->attentions[i].wv);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wbig
);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wa
);
list
.
Add
(
&
decoder
->
attentions
[
i
].
wa
);
list
.
Add
(
&
decoder
->
fnnLayerNorms
[
i
].
w
);
list
.
Add
(
&
decoder
->
fnnLayerNorms
[
i
].
w
);
list
.
Add
(
&
decoder
->
fnnLayerNorms
[
i
].
b
);
list
.
Add
(
&
decoder
->
fnnLayerNorms
[
i
].
b
);
...
...
source/sample/transformer/T2TOutput.cpp
查看文件 @
3bf085db
...
@@ -56,7 +56,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
...
@@ -56,7 +56,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
float
minmax
=
0
;
float
minmax
=
0
;
LoadParamInt
(
argc
,
argv
,
"vsize"
,
&
vSize
,
-
1
);
LoadParamInt
(
argc
,
argv
,
"vsize
tgt
"
,
&
vSize
,
-
1
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
inSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
inSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
hSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamInt
(
argc
,
argv
,
"d"
,
&
hSize
,
DEFAULT_EMBEDDING_SIZE
);
LoadParamFloat
(
argc
,
argv
,
"outputminmax"
,
&
minmax
,
0.08
F
);
LoadParamFloat
(
argc
,
argv
,
"outputminmax"
,
&
minmax
,
0.08
F
);
...
...
source/sample/transformer/T2TTrainer.cpp
查看文件 @
3bf085db
...
@@ -148,8 +148,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
...
@@ -148,8 +148,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
{
{
int
step
=
0
;
int
step
=
0
;
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
totalW
;
int
wordCountTotal
=
0
;
int
wordCountTotal
=
0
;
int
wordCountBatch
=
0
;
bool
isEnd
=
false
;
bool
isEnd
=
false
;
float
loss
=
0
;
float
loss
=
0
;
float
lr
=
0
;
float
lr
=
0
;
...
@@ -195,6 +198,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
...
@@ -195,6 +198,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor
batchEnc
;
XTensor
batchEnc
;
XTensor
batchDec
;
XTensor
batchDec
;
/* labels */
XTensor
label
;
/* padding */
/* padding */
XTensor
paddingEnc
;
XTensor
paddingEnc
;
XTensor
paddingDec
;
XTensor
paddingDec
;
...
@@ -205,9 +211,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
...
@@ -205,9 +211,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
/* label smoothed gold standard (if needed) */
XTensor
goldSmoothed
;
XTensor
goldSmoothed
;
while
(
LoadBatch
(
file
,
model
->
isLM
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
gold
,
while
(
LoadBatch
(
file
,
model
->
isLM
,
&
batchEnc
,
&
paddingEnc
,
&
batchDec
,
&
paddingDec
,
&
gold
,
&
label
,
NULL
,
vSize
,
vSizeTgt
,
NULL
,
vSize
,
vSizeTgt
,
sBatchSize
,
wBatchSize
,
isLenSorted
,
wc
,
devID
,
mem
,
true
))
sBatchSize
,
wBatchSize
,
isLenSorted
,
w
s
,
w
c
,
devID
,
mem
,
true
))
{
{
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
...
@@ -225,34 +231,41 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
...
@@ -225,34 +231,41 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
}
}
/* back-propagation for obtaining gradients */
/* back-propagation for obtaining gradients */
if
(
labelSmoothingP
>
0
)
//if (labelSmoothingP > 0)
LabelSmooth
(
&
gold
,
&
goldSmoothed
,
labelSmoothingP
);
// LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
XTensor
labelOnehot
;
labelOnehot
=
IndexToOnehot
(
label
,
vSizeTgt
,
labelSmoothingP
);
/* make paddings for the output */
/* make paddings for the output */
if
(
output
.
GetDim
(
0
)
>
1
)
if
(
output
.
GetDim
(
0
)
>
0
)
PadOutput
(
&
output
,
&
gold
,
&
paddingDec
);
PadOutput
(
&
output
,
&
labelOnehot
,
&
paddingDec
);
/* get probabilities */
/* get probabilities */
float
prob
=
GetProb
(
&
output
,
&
gold
,
NULL
);
float
prob
=
GetProb
(
&
output
,
&
labelOnehot
,
NULL
);
DTYPE
lossLocal
=
-
prob
/
wc
;
DTYPE
lossLocal
=
-
prob
/
wc
;
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
bool
doUpdate
=
(
!
IsNAN
(
lossLocal
)
&&
!
IsINF
(
lossLocal
)
&&
lossLocal
<
1e3
F
);
XTensor
&
g
=
labelSmoothingP
>
0
?
goldSmoothed
:
gold
;
//XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
if
(
doUpdate
)
{
if
(
doUpdate
)
{
/* rescale the output for normalized loss */
/* rescale the output for normalized loss */
RescaleOutput
(
&
output
,
&
g
,
&
paddingDec
);
RescaleOutput
(
&
output
,
&
labelOnehot
,
&
paddingDec
);
/* back-propagation */
/* back-propagation */
net
.
Backward
(
output
,
g
,
paddingDec
,
CROSSENTROPY
);
net
.
Backward
(
output
,
labelOnehot
,
paddingDec
,
CROSSENTROPY
);
//net.Backward(output, label, labelSmoothingP, CROSSENTROPY);
gradStep
+=
1
;
gradStep
+=
1
;
loss
+=
-
prob
;
loss
+=
-
prob
;
wordCount
+=
wc
;
wordCount
+=
wc
;
wordCountTotal
+=
wc
;
wordCountTotal
+=
wc
;
//totalW = wc + ws;
wordCountBatch
+=
ws
;
/* update the parameters */
/* update the parameters */
if
(
gradStep
==
updateStep
){
if
(
gradStep
==
updateStep
){
...
@@ -276,8 +289,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
...
@@ -276,8 +289,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if
(
step
%
100
==
0
)
{
if
(
step
%
100
==
0
)
{
double
elapsed
=
GetClockSec
()
-
startT
;
double
elapsed
=
GetClockSec
()
-
startT
;
XPRINT8
(
0
,
stderr
,
"[INFO]
lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d,
word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f"
,
XPRINT8
(
0
,
stderr
,
"[INFO]
elapsed=%.1fs, step=%d, epoch=%d, tword=%d, s
word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f"
,
lr
,
elapsed
,
step
,
epoch
,
wordCountTotal
,
loss
/
wordCount
,
exp
(
loss
/
wordCount
),
exp
(
-
prob
/
wc
));
elapsed
,
step
,
epoch
,
wordCountTotal
,
wordCountBatch
,
loss
/
wordCount
,
exp
(
loss
/
wordCount
),
exp
(
-
prob
/
wc
));
if
(
!
doUpdate
)
if
(
!
doUpdate
)
XPRINT
(
0
,
stderr
,
" (no update)"
);
XPRINT
(
0
,
stderr
,
" (no update)"
);
XPRINT
(
0
,
stderr
,
"
\n
"
);
XPRINT
(
0
,
stderr
,
"
\n
"
);
...
@@ -320,6 +333,7 @@ test the model
...
@@ -320,6 +333,7 @@ test the model
void
T2TTrainer
::
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
)
void
T2TTrainer
::
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
)
{
{
int
wc
=
0
;
int
wc
=
0
;
int
ws
=
0
;
int
wordCount
=
0
;
int
wordCount
=
0
;
int
wordCountTotal
=
0
;
int
wordCountTotal
=
0
;
int
sentCount
=
0
;
int
sentCount
=
0
;
...
@@ -344,6 +358,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
...
@@ -344,6 +358,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor
batchEnc
;
XTensor
batchEnc
;
XTensor
batchDec
;
XTensor
batchDec
;
/* label */
XTensor
label
;
/* padding */
/* padding */
XTensor
paddingEnc
;
XTensor
paddingEnc
;
XTensor
paddingDec
;
XTensor
paddingDec
;
...
@@ -356,9 +373,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
...
@@ -356,9 +373,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf
();
ClearBuf
();
while
(
LoadBatch
(
file
,
model
->
isLM
,
&
batchEnc
,
&
paddingEnc
,
&
paddingDec
,
&
paddingDec
,
&
gold
,
while
(
LoadBatch
(
file
,
model
->
isLM
,
&
batchEnc
,
&
paddingEnc
,
&
paddingDec
,
&
paddingDec
,
&
gold
,
&
label
,
seqs
,
vSize
,
vSizeTgt
,
seqs
,
vSize
,
vSizeTgt
,
1
,
1
,
false
,
wc
,
devID
,
mem
,
false
))
1
,
1
,
false
,
w
s
,
w
c
,
devID
,
mem
,
false
))
{
{
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
CheckNTErrors
(
batchEnc
.
order
==
2
,
"wrong tensor order of the sequence batch"
);
...
@@ -441,11 +458,11 @@ void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const ch
...
@@ -441,11 +458,11 @@ void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const ch
sprintf
(
fn2
,
"%s.%s.%03d.output"
,
modelFN
,
label
,
id
);
sprintf
(
fn2
,
"%s.%s.%03d.output"
,
modelFN
,
label
,
id
);
model
->
Dump
(
fn
);
model
->
Dump
(
fn
);
if
(
validFN
!=
NULL
){
//
if(validFN != NULL){
T2TTrainer
trainer
;
//
T2TTrainer trainer;
trainer
.
Init
(
argNum
,
argArray
);
//
trainer.Init(argNum, argArray);
trainer
.
Test
(
validFN
,
fn2
,
model
);
//
trainer.Test(validFN, fn2, model);
}
//
}
delete
[]
fn
;
delete
[]
fn
;
delete
[]
fn2
;
delete
[]
fn2
;
...
@@ -460,7 +477,7 @@ struct SampleNode
...
@@ -460,7 +477,7 @@ struct SampleNode
int
*
p
;
int
*
p
;
int
size
;
int
size
;
int
value
;
int
value
;
int
key
;
int
key
;
};
};
int
CompareSampleNode
(
const
void
*
a
,
const
void
*
b
)
int
CompareSampleNode
(
const
void
*
a
,
const
void
*
b
)
...
@@ -650,22 +667,22 @@ load a batch of sequences
...
@@ -650,22 +667,22 @@ load a batch of sequences
int
T2TTrainer
::
LoadBatch
(
FILE
*
file
,
bool
isLM
,
int
T2TTrainer
::
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
w
s
,
int
&
w
Count
,
int
devID
,
XMem
*
mem
,
int
devID
,
XMem
*
mem
,
bool
isTraining
)
bool
isTraining
)
{
{
if
(
isLM
){
if
(
isLM
){
return
LoadBatchLM
(
file
,
batchEnc
,
paddingEnc
,
batchDec
,
paddingDec
,
gold
,
return
LoadBatchLM
(
file
,
batchEnc
,
paddingEnc
,
batchDec
,
paddingDec
,
gold
,
label
,
seqs
,
vsEnc
,
sBatch
,
wBatch
,
seqs
,
vsEnc
,
sBatch
,
wBatch
,
isSorted
,
wCount
,
devID
,
mem
,
isTraining
);
isSorted
,
wCount
,
devID
,
mem
,
isTraining
);
}
}
else
{
else
{
return
LoadBatchMT
(
file
,
batchEnc
,
paddingEnc
,
batchDec
,
paddingDec
,
gold
,
return
LoadBatchMT
(
file
,
batchEnc
,
paddingEnc
,
batchDec
,
paddingDec
,
gold
,
label
,
seqs
,
vsEnc
,
vsDec
,
sBatch
,
wBatch
,
seqs
,
vsEnc
,
vsDec
,
sBatch
,
wBatch
,
isSorted
,
wCount
,
devID
,
mem
,
isTraining
);
isSorted
,
w
s
,
w
Count
,
devID
,
mem
,
isTraining
);
}
}
}
}
...
@@ -691,7 +708,7 @@ load a batch of sequences (for LM)
...
@@ -691,7 +708,7 @@ load a batch of sequences (for LM)
int
T2TTrainer
::
LoadBatchLM
(
FILE
*
file
,
int
T2TTrainer
::
LoadBatchLM
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
wCount
,
...
@@ -733,22 +750,29 @@ int T2TTrainer::LoadBatchLM(FILE * file,
...
@@ -733,22 +750,29 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims
[
2
]
=
vs
;
dims
[
2
]
=
vs
;
InitTensor2D
(
batchEnc
,
sc
,
max
,
X_INT
,
devID
,
mem
);
InitTensor2D
(
batchEnc
,
sc
,
max
,
X_INT
,
devID
,
mem
);
InitTensor2D
(
label
,
sc
,
max
,
X_INT
,
devID
,
mem
);
InitTensor
(
gold
,
3
,
dims
,
X_FLOAT
,
1.0
F
,
devID
,
mem
);
InitTensor
(
gold
,
3
,
dims
,
X_FLOAT
,
1.0
F
,
devID
,
mem
);
InitTensor2D
(
paddingEnc
,
sc
,
max
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
paddingEnc
,
sc
,
max
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
paddingDec
,
sc
,
max
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
paddingDec
,
sc
,
max
,
X_FLOAT
,
devID
,
mem
);
batchEnc
->
SetZeroAll
();
batchEnc
->
SetZeroAll
();
label
->
SetZeroAll
();
gold
->
SetZeroAll
();
gold
->
SetZeroAll
();
paddingEnc
->
SetZeroAll
();
paddingEnc
->
SetZeroAll
();
paddingDec
->
SetZeroAll
();
paddingDec
->
SetZeroAll
();
int
seqSize
=
0
;
int
seqSize
=
0
;
int
wGold
=
0
;
int
*
batchEncValues
=
new
int
[
batchEnc
->
unitNum
];
int
*
batchEncValues
=
new
int
[
batchEnc
->
unitNum
];
int
*
labelValues
=
new
int
[
label
->
unitNum
];
MTYPE
*
goldOffsets
=
new
MTYPE
[
gold
->
unitNum
];
MTYPE
*
goldOffsets
=
new
MTYPE
[
gold
->
unitNum
];
MTYPE
*
paddingEncOffsets
=
new
MTYPE
[
paddingEnc
->
unitNum
];
MTYPE
*
paddingDecOffsets
=
new
MTYPE
[
paddingDec
->
unitNum
];
int
wGold
=
0
;
memset
(
batchEncValues
,
0
,
sizeof
(
int
)
*
batchEnc
->
unitNum
);
memset
(
batchEncValues
,
0
,
sizeof
(
int
)
*
batchEnc
->
unitNum
);
memset
(
labelValues
,
0
,
sizeof
(
int
)
*
label
->
unitNum
);
for
(
int
s
=
seq
;
s
<
seq
+
sc
;
s
++
){
for
(
int
s
=
seq
;
s
<
seq
+
sc
;
s
++
){
int
len
=
isDoubledEnd
?
seqLen
[
s
]
:
seqLen
[
s
]
-
1
;
int
len
=
isDoubledEnd
?
seqLen
[
s
]
:
seqLen
[
s
]
-
1
;
...
@@ -756,15 +780,23 @@ int T2TTrainer::LoadBatchLM(FILE * file,
...
@@ -756,15 +780,23 @@ int T2TTrainer::LoadBatchLM(FILE * file,
for
(
int
w
=
0
;
w
<
len
;
w
++
){
for
(
int
w
=
0
;
w
<
len
;
w
++
){
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
batchEncValues
[(
int
)
batchEnc
->
GetOffset2D
(
s
-
seq
,
w
)]
=
num
;
batchEncValues
[(
int
)
batchEnc
->
GetOffset2D
(
s
-
seq
,
w
)]
=
num
;
paddingEncOffsets
[
wCount
]
=
paddingEnc
->
GetOffset2D
(
s
-
seq
,
w
);
if
(
w
>
0
)
paddingDecOffsets
[
wCount
]
=
paddingDec
->
GetOffset2D
(
s
-
seq
,
w
);
if
(
w
>
0
)
{
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
-
1
,
num
);
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
-
1
,
num
);
labelValues
[(
int
)
label
->
GetOffset2D
(
s
-
seq
,
w
-
1
)]
=
buf
[
seqOffset
[
s
]
+
w
];
}
if
(
w
==
len
-
1
)
{
if
(
w
==
len
-
1
)
{
if
(
isDoubledEnd
)
if
(
isDoubledEnd
)
{
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
,
num
);
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
,
num
);
else
labelValues
[(
int
)
label
->
GetOffset2D
(
s
-
seq
,
w
)]
=
buf
[
seqOffset
[
s
]
+
w
];
}
else
{
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
,
buf
[
seqOffset
[
s
]
+
w
+
1
]);
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
s
-
seq
,
w
,
buf
[
seqOffset
[
s
]
+
w
+
1
]);
labelValues
[(
int
)
label
->
GetOffset2D
(
s
-
seq
,
w
)]
=
buf
[
seqOffset
[
s
]
+
w
+
1
];
}
}
}
wCount
++
;
wCount
++
;
...
@@ -780,9 +812,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
...
@@ -780,9 +812,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
}
}
batchEnc
->
SetData
(
batchEncValues
,
batchEnc
->
unitNum
);
batchEnc
->
SetData
(
batchEncValues
,
batchEnc
->
unitNum
);
label
->
SetData
(
labelValues
,
label
->
unitNum
);
gold
->
SetDataBatched
(
goldOffsets
,
1.0
F
,
wGold
);
gold
->
SetDataBatched
(
goldOffsets
,
1.0
F
,
wGold
);
paddingEnc
->
SetDataBatched
(
paddingEncOffsets
,
1.0
F
,
wCount
);
paddingDec
->
SetDataBatched
(
paddingDecOffsets
,
1.0
F
,
wCount
);
XTensor
*
tmp
=
NewTensorBuf
(
paddingEnc
,
devID
,
mem
);
/*
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp);
_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0);
_NotEqual(tmp, paddingEnc, 0);
DelTensorBuf(tmp);
DelTensorBuf(tmp);
...
@@ -790,10 +825,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
...
@@ -790,10 +825,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
_ConvertDataType(batchEnc, tmp2);
_ConvertDataType(batchEnc, tmp2);
_NotEqual(tmp2, paddingDec, 0);
_NotEqual(tmp2, paddingDec, 0);
DelTensorBuf
(
tmp2
);
DelTensorBuf(tmp2);
*/
delete
[]
batchEncValues
;
delete
[]
batchEncValues
;
delete
[]
labelValues
;
delete
[]
goldOffsets
;
delete
[]
goldOffsets
;
delete
[]
paddingEncOffsets
;
delete
[]
paddingDecOffsets
;
fflush
(
tf
);
fflush
(
tf
);
...
@@ -828,10 +866,10 @@ load a batch of sequences (for MT)
...
@@ -828,10 +866,10 @@ load a batch of sequences (for MT)
int
T2TTrainer
::
LoadBatchMT
(
FILE
*
file
,
int
T2TTrainer
::
LoadBatchMT
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
w
s
,
int
&
w
Count
,
int
devID
,
XMem
*
mem
,
int
devID
,
XMem
*
mem
,
bool
isTraining
)
bool
isTraining
)
{
{
...
@@ -915,25 +953,32 @@ int T2TTrainer::LoadBatchMT(FILE * file,
...
@@ -915,25 +953,32 @@ int T2TTrainer::LoadBatchMT(FILE * file,
InitTensor2D
(
paddingEnc
,
sCount
,
maxEnc
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
paddingEnc
,
sCount
,
maxEnc
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
batchDec
,
sCount
,
maxDec
,
X_INT
,
devID
,
mem
);
InitTensor2D
(
batchDec
,
sCount
,
maxDec
,
X_INT
,
devID
,
mem
);
InitTensor2D
(
paddingDec
,
sCount
,
maxDec
,
X_FLOAT
,
devID
,
mem
);
InitTensor2D
(
paddingDec
,
sCount
,
maxDec
,
X_FLOAT
,
devID
,
mem
);
InitTensor
(
gold
,
3
,
dimsDec
,
X_FLOAT
,
1.0
F
,
devID
,
mem
);
InitTensor2D
(
label
,
sCount
,
maxDec
,
X_INT
,
devID
,
mem
);
//InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
batchEnc
->
SetZeroAll
();
batchEnc
->
SetZeroAll
();
paddingEnc
->
SetZeroAll
();
paddingEnc
->
SetZeroAll
();
batchDec
->
SetZeroAll
();
batchDec
->
SetZeroAll
();
paddingDec
->
SetZeroAll
();
paddingDec
->
SetZeroAll
();
gold
->
SetZeroAll
();
label
->
SetZeroAll
();
//gold->SetZeroAll();
int
wCountEnc
=
0
;
int
wCountEnc
=
0
;
int
wCountDec
=
0
;
int
wCountDec
=
0
;
int
wCountPad
=
0
;
int
wGold
=
0
;
int
wGold
=
0
;
wCount
=
0
;
wCount
=
0
;
int
*
batchEncValues
=
new
int
[
batchEnc
->
unitNum
];
int
*
batchEncValues
=
new
int
[
batchEnc
->
unitNum
];
int
*
batchDecValues
=
new
int
[
batchDec
->
unitNum
];
int
*
batchDecValues
=
new
int
[
batchDec
->
unitNum
];
MTYPE
*
goldOffsets
=
new
MTYPE
[
sc
*
maxDec
/
2
];
int
*
labelValues
=
new
int
[
label
->
unitNum
];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE
*
paddingDecOffsets
=
new
MTYPE
[
sc
*
maxDec
/
2
];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
memset
(
batchEncValues
,
0
,
sizeof
(
int
)
*
batchEnc
->
unitNum
);
memset
(
batchEncValues
,
0
,
sizeof
(
int
)
*
batchEnc
->
unitNum
);
memset
(
batchDecValues
,
0
,
sizeof
(
int
)
*
batchDec
->
unitNum
);
memset
(
batchDecValues
,
0
,
sizeof
(
int
)
*
batchDec
->
unitNum
);
memset
(
labelValues
,
0
,
sizeof
(
int
)
*
batchDec
->
unitNum
);
/* batch of the source-side sequences */
/* batch of the source-side sequences */
for
(
int
s
=
seq
;
s
<
seq
+
sc
;
s
+=
2
){
for
(
int
s
=
seq
;
s
<
seq
+
sc
;
s
+=
2
){
...
@@ -942,11 +987,13 @@ int T2TTrainer::LoadBatchMT(FILE * file,
...
@@ -942,11 +987,13 @@ int T2TTrainer::LoadBatchMT(FILE * file,
for
(
int
w
=
0
;
w
<
len
;
w
++
){
for
(
int
w
=
0
;
w
<
len
;
w
++
){
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
batchEncValues
[
batchEnc
->
GetOffset2D
(
sent
,
w
)]
=
num
;
batchEncValues
[
batchEnc
->
GetOffset2D
(
sent
,
w
)]
=
num
;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc
++
;
wCountEnc
++
;
}
}
}
}
ws
=
wCountEnc
;
batchEnc
->
SetData
(
batchEncValues
,
batchEnc
->
unitNum
);
batchEnc
->
SetData
(
batchEncValues
,
batchEnc
->
unitNum
);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor
*
tmp
=
NewTensorBuf
(
paddingEnc
,
devID
,
mem
);
XTensor
*
tmp
=
NewTensorBuf
(
paddingEnc
,
devID
,
mem
);
_ConvertDataType
(
batchEnc
,
tmp
);
_ConvertDataType
(
batchEnc
,
tmp
);
_NotEqual
(
tmp
,
paddingEnc
,
0
);
_NotEqual
(
tmp
,
paddingEnc
,
0
);
...
@@ -960,17 +1007,26 @@ int T2TTrainer::LoadBatchMT(FILE * file,
...
@@ -960,17 +1007,26 @@ int T2TTrainer::LoadBatchMT(FILE * file,
for
(
int
w
=
0
;
w
<
len
;
w
++
){
for
(
int
w
=
0
;
w
<
len
;
w
++
){
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
int
num
=
buf
[
seqOffset
[
s
]
+
w
];
batchDecValues
[
batchDec
->
GetOffset2D
(
sent
,
w
)]
=
num
;
batchDecValues
[
batchDec
->
GetOffset2D
(
sent
,
w
)]
=
num
;
//paddingDecOffsets[wCountDec] = paddingDec->GetOffset2D(sent, w);
if
(
w
>
0
)
if
(
w
<
len
-
1
){
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
sent
,
w
-
1
,
buf
[
seqOffset
[
s
]
+
w
]);
paddingDecOffsets
[
wCountPad
++
]
=
paddingDec
->
GetOffset2D
(
sent
,
w
);
wCount
++
;
}
if
(
w
>
0
)
{
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
labelValues
[
label
->
GetOffset2D
(
sent
,
w
-
1
)]
=
buf
[
seqOffset
[
s
]
+
w
];
}
if
(
w
==
len
-
1
)
{
if
(
w
==
len
-
1
)
{
if
(
isDoubledEnd
)
if
(
isDoubledEnd
)
{
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
sent
,
w
,
buf
[
seqOffset
[
s
]
+
w
]);
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]);
else
labelValues
[
label
->
GetOffset2D
(
sent
,
w
)]
=
buf
[
seqOffset
[
s
]
+
w
];
goldOffsets
[
wGold
++
]
=
gold
->
GetOffset3D
(
sent
,
w
,
buf
[
seqOffset
[
s
]
+
w
+
1
]);
}
else
{
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]);
labelValues
[
label
->
GetOffset2D
(
sent
,
w
)]
=
buf
[
seqOffset
[
s
]
+
w
+
1
];
}
}
}
wCount
++
;
//
wCount++;
wCountDec
++
;
wCountDec
++
;
if
(
seqs
!=
NULL
)
if
(
seqs
!=
NULL
)
seqs
[
seqSize
++
]
=
buf
[
seqOffset
[
s
]
+
w
];
seqs
[
seqSize
++
]
=
buf
[
seqOffset
[
s
]
+
w
];
...
@@ -983,17 +1039,22 @@ int T2TTrainer::LoadBatchMT(FILE * file,
...
@@ -983,17 +1039,22 @@ int T2TTrainer::LoadBatchMT(FILE * file,
}
}
batchDec
->
SetData
(
batchDecValues
,
batchDec
->
unitNum
);
batchDec
->
SetData
(
batchDecValues
,
batchDec
->
unitNum
);
label
->
SetData
(
labelValues
,
label
->
unitNum
);
paddingDec
->
SetDataBatched
(
paddingDecOffsets
,
1.0
F
,
wCountPad
);
XTensor
*
tmp2
=
NewTensorBuf
(
paddingDec
,
devID
,
mem
);
//
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
_ConvertDataType
(
batchDec
,
tmp2
);
//
_ConvertDataType(batchDec, tmp2);
_NotEqual
(
tmp2
,
paddingDec
,
0
);
//
_NotEqual(tmp2, paddingDec, 0);
DelTensorBuf
(
tmp2
);
//
DelTensorBuf(tmp2);
gold
->
SetDataBatched
(
goldOffsets
,
1.0
F
,
wGold
);
//
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete
[]
batchEncValues
;
delete
[]
batchEncValues
;
delete
[]
batchDecValues
;
delete
[]
batchDecValues
;
delete
[]
goldOffsets
;
delete
[]
labelValues
;
//delete[] paddingEncOffsets;
delete
[]
paddingDecOffsets
;
//delete[] goldOffsets;
return
sc
;
return
sc
;
}
}
...
...
source/sample/transformer/T2TTrainer.h
查看文件 @
3bf085db
...
@@ -208,10 +208,10 @@ public:
...
@@ -208,10 +208,10 @@ public:
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
w
s
,
int
&
w
Count
,
int
devID
,
XMem
*
mem
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
bool
isTraining
);
...
@@ -219,7 +219,7 @@ public:
...
@@ -219,7 +219,7 @@ public:
int
LoadBatchLM
(
FILE
*
file
,
int
LoadBatchLM
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
int
devID
,
XMem
*
mem
,
...
@@ -229,9 +229,9 @@ public:
...
@@ -229,9 +229,9 @@ public:
int
LoadBatchMT
(
FILE
*
file
,
int
LoadBatchMT
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
bool
isSorted
,
int
&
w
s
,
int
&
w
Count
,
int
devID
,
XMem
*
mem
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
bool
isTraining
);
...
...
source/sample/transformer/Transformer.cpp
查看文件 @
3bf085db
...
@@ -57,6 +57,8 @@ int TransformerMain(int argc, const char ** argv)
...
@@ -57,6 +57,8 @@ int TransformerMain(int argc, const char ** argv)
LoadParamString
(
argc
,
args
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
args
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
args
,
"output"
,
outputFN
,
""
);
LoadParamString
(
argc
,
args
,
"output"
,
outputFN
,
""
);
srand
((
unsigned
int
)
time
(
NULL
));
T2TTrainer
trainer
;
T2TTrainer
trainer
;
trainer
.
Init
(
argc
,
args
);
trainer
.
Init
(
argc
,
args
);
...
@@ -68,12 +70,12 @@ int TransformerMain(int argc, const char ** argv)
...
@@ -68,12 +70,12 @@ int TransformerMain(int argc, const char ** argv)
trainer
.
Train
(
trainFN
,
testFN
,
strcmp
(
modelFN
,
""
)
?
modelFN
:
"checkpoint.model"
,
&
model
);
trainer
.
Train
(
trainFN
,
testFN
,
strcmp
(
modelFN
,
""
)
?
modelFN
:
"checkpoint.model"
,
&
model
);
/* save the final model */
/* save the final model */
if
(
strcmp
(
modelFN
,
""
)
&&
strcmp
(
trainFN
,
""
))
//
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
model
.
Dump
(
modelFN
);
//
model.Dump(modelFN);
/* load the model if necessary */
/* load the model if necessary */
if
(
strcmp
(
modelFN
,
""
))
//
if(strcmp(modelFN, ""))
model
.
Read
(
modelFN
);
//
model.Read(modelFN);
T2TTrainer
tester
;
T2TTrainer
tester
;
tester
.
Init
(
argc
,
args
);
tester
.
Init
(
argc
,
args
);
...
...
source/tensor/XLink.cpp
查看文件 @
3bf085db
...
@@ -307,6 +307,27 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
...
@@ -307,6 +307,27 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
MakeLink
(
&
list
,
h
,
id
);
MakeLink
(
&
list
,
h
,
id
);
}
}
/*
create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
>> t2 - the second tail tensor
>> t3 - the third tail tensor
>> h - head tensor
>> id - id of the edge type
*/
void
XLink
::
MakeLink
(
const
XTensor
*
t1
,
const
XTensor
*
t2
,
const
XTensor
*
t3
,
XTensor
*
h
,
int
id
)
{
if
(
h
==
NULL
)
return
;
XList
list
(
3
);
list
.
Add
(
t1
);
list
.
Add
(
t2
);
list
.
Add
(
t3
);
MakeLink
(
&
list
,
h
,
id
);
}
/*
/*
create a hyper edge with a list of tensors and a output tensor
create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors
>> list - a list of input tensors
...
...
source/tensor/XLink.h
查看文件 @
3bf085db
...
@@ -138,6 +138,10 @@ struct XLink
...
@@ -138,6 +138,10 @@ struct XLink
static
static
void
MakeLink
(
const
XTensor
*
t1
,
const
XTensor
*
t2
,
XTensor
*
h
,
int
id
);
void
MakeLink
(
const
XTensor
*
t1
,
const
XTensor
*
t2
,
XTensor
*
h
,
int
id
);
/* create a hyper edge with three input tensors and a output tensor */
static
void
MakeLink
(
const
XTensor
*
t1
,
const
XTensor
*
t2
,
const
XTensor
*
t3
,
XTensor
*
h
,
int
id
);
/* create a hyper edge with a list of input tensors and a output tensor */
/* create a hyper edge with a list of input tensors and a output tensor */
static
static
void
MakeLink
(
const
XList
*
list
,
XTensor
*
h
,
int
id
);
void
MakeLink
(
const
XList
*
list
,
XTensor
*
h
,
int
id
);
...
...
source/tensor/XName.cpp
查看文件 @
3bf085db
...
@@ -77,6 +77,8 @@ const char * GetOPName(int type)
...
@@ -77,6 +77,8 @@ const char * GetOPName(int type)
return
"M_POWER"
;
return
"M_POWER"
;
else
if
(
type
==
MATH_SCALEANDSHIFT
)
else
if
(
type
==
MATH_SCALEANDSHIFT
)
return
"M_SCALEANDSHIFT"
;
return
"M_SCALEANDSHIFT"
;
else
if
(
type
==
MATH_MULANDSHIFT
)
return
"M_OPERATION"
;
else
if
(
type
==
MATH_SIGN
)
else
if
(
type
==
MATH_SIGN
)
return
"M_SIGN"
;
return
"M_SIGN"
;
else
if
(
type
==
MATH_SUB
)
else
if
(
type
==
MATH_SUB
)
...
...
source/tensor/XName.h
查看文件 @
3bf085db
...
@@ -57,7 +57,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -57,7 +57,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_MULANDSHIFT MATH_SCALEANDSHIFT + 1
#define MATH_SIGN MATH_MULANDSHIFT + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1
#define MATH_SUM MATH_SUBDIM + 1
...
...
source/tensor/core/CHeader.h
查看文件 @
3bf085db
...
@@ -44,6 +44,7 @@
...
@@ -44,6 +44,7 @@
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "arithmetic/XTensorBLAS.h"
#include "arithmetic/MulAndShift.h"
#include "getandset/ConvertDataType.h"
#include "getandset/ConvertDataType.h"
#include "getandset/OnehotAndIndex.h"
#include "getandset/OnehotAndIndex.h"
...
...
source/tensor/core/arithmetic/MulAndShift.cpp
0 → 100644
查看文件 @
3bf085db
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h"
#include "MulAndShift.h"
#include "MatrixMul.h"
#include "Sum.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
<< return - the dimension of a whose size matches the only dimension of b
            that is not of size 1, or -1 if a has a lower order than b,
            if a and b have the same shape, or if the number of matching
            dimensions is not exactly one
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
    if (a.order < b.order)
        return -1;
    if (XTensor::IsSameShaped(&a, &b))
        return -1;

    int hitCount = 0;
    int hitDim = -1;
    for (int i = 0; i < b.order; i++) {
        /* the two shapes are compared right-aligned, i.e., starting from
           the last dimension of each tensor */
        const int dimB = b.order - 1 - i;
        const int dimA = a.order - 1 - i;

        /* dimensions of size 1 in b never count as a hit */
        if (b.dimSize[dimB] == 1)
            continue;

        if (b.dimSize[dimB] == a.dimSize[dimA]) {
            hitCount++;
            /* record the dimension of a that was actually compared; using
               (a.order - b.order + i) here would point at a different
               dimension as soon as b.order >= 2 */
            hitDim = dimA;
        }
    }

    return (hitCount == 1) ? hitDim : -1;
}
/*
operation c = x * w + b  MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> alpha - coefficient passed to _MatrixMul together with x and w
>> parallelRunner - parallel processing module
<< return - the result tensor c
*/
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
                    DTYPE alpha, XPRunner * parallelRunner)
{
    CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
    CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");

    /* sizes taken from entries 1 and 0 of dimSizeRDI of each input */
    const int rowNumX = x.dimSizeRDI[1];
    const int colNumX = x.dimSizeRDI[0];
    const int rowNumW = w.dimSizeRDI[1];
    const int colNumW = w.dimSizeRDI[0];

    CheckNTErrors(colNumX == rowNumW, "Unmatched tensors in multiplication!");

    /* shape of the product: the remaining dimensions of x, the remaining
       dimensions of w, and finally (rowNumX, colNumW) */
    const int resOrder = x.order + w.order - 2;
    int * resDims = new int[resOrder];
    int filled = 0;
    for (int k = x.order - 1; k >= 2; k--)
        resDims[filled++] = x.dimSizeRDI[k];
    for (int k = w.order - 1; k >= 2; k--)
        resDims[filled++] = w.dimSizeRDI[k];
    resDims[filled++] = rowNumX;
    resDims[filled++] = colNumW;

    /* the dense ratio of the inputs is only used when both of them are sparse */
    float ratio = 1.0F;
    if (x.isSparse && w.isSparse)
        ratio = MAX(x.denseRatio, w.denseRatio);

    /* buffer tensor that keeps the result of the multiplication */
    XTensor * product = NewTensorBuf(resOrder, resDims, x.dataType, ratio, x.devID, x.mem);

    /* call _MatrixMul function */
    _MatrixMul(&x, X_NOTRANS, &w, X_NOTRANS, product, alpha, 0, parallelRunner);

    XTensor c(product);
    c.SetTMPFlag();

    const int n = GetSumIndex(product, b);

    if (n >= 0 && n < product->order) {
        /* call _SumDim function */
        _SumDim(product, &b, &c, n);
    }
    else if (n == -1) {
        /* call _Sum function */
        _Sum(product, &b, &c);

        // TODO!!
        ShowNTErrors("TODO!");
    }
    else {
        ShowNTErrors("Something is wrong!");
    }

    /* tensor connections */
    XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
    XLink::AddParamToHeadInt(&c, n);
    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
    //XLink::AddParamToHead(&c, beta);

    /* destroy variables */
    delete[] resDims;
    DelTensorBuf(product);

    return c;
}
}
\ No newline at end of file
source/tensor/core/arithmetic/MulAndShift.h
0 → 100644
查看文件 @
3bf085db
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#ifndef __MULANDSHIFT_H__
#define __MULANDSHIFT_H__
#include "../../XTensor.h"
#include "../CHeader.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* operation c = x * w + b (MulAndShift), returning c as a new tensor;
   alpha defaults to 1.0 and parallelRunner defaults to NULL */
XTensor
MulAndShift
(
const
XTensor
&
x
,
const
XTensor
&
w
,
const
XTensor
&
b
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
}
// namespace nts(NiuTrans.Tensor)
#endif // __MULANDSHIFT_H__
source/tensor/core/getandset/OnehotAndIndex.cpp
查看文件 @
3bf085db
...
@@ -99,11 +99,11 @@ convert index tensor to onehot tensor
...
@@ -99,11 +99,11 @@ convert index tensor to onehot tensor
>> onehot - onehot tensor, which value is 0 or 1
>> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor
>> size - the last dimension size of the onehot tensor
*/
*/
void
_IndexToOnehot
(
XTensor
*
index
,
XTensor
*
onehot
,
int
size
)
void
_IndexToOnehot
(
XTensor
*
index
,
XTensor
*
onehot
,
int
size
,
float
labelSmoothingP
)
{
{
CheckNTErrors
(
onehot
->
GetDim
(
-
1
)
==
size
,
"Illegal tensor dimension!"
);
CheckNTErrors
(
onehot
->
GetDim
(
-
1
)
==
size
,
"Illegal tensor dimension!"
);
CheckNTErrors
(
onehot
->
order
==
index
->
order
+
1
,
"Illegal tensor order!"
);
CheckNTErrors
(
onehot
->
order
==
index
->
order
+
1
,
"Illegal tensor order!"
);
CheckNTErrors
(
onehot
->
dataType
==
X_INT
,
"The onehot tensor must be in X_INT!"
)
//
CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
CheckNTErrors
(
index
->
dataType
==
X_INT
,
"The index tensor must be in X_INT!"
)
CheckNTErrors
(
index
->
dataType
==
X_INT
,
"The index tensor must be in X_INT!"
)
for
(
int
i
=
0
;
i
<
index
->
order
;
i
++
)
for
(
int
i
=
0
;
i
<
index
->
order
;
i
++
)
...
@@ -111,9 +111,12 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
...
@@ -111,9 +111,12 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
onehot
->
SetZeroAll
();
onehot
->
SetZeroAll
();
float
confidence
=
1
-
labelSmoothingP
;
float
lowconfidence
=
labelSmoothingP
/
size
;
#ifdef USE_CUDA
#ifdef USE_CUDA
if
(
onehot
->
devID
>=
0
&&
index
->
devID
>=
0
)
{
if
(
onehot
->
devID
>=
0
&&
index
->
devID
>=
0
)
{
_CudaIndexToOnehot
(
index
,
onehot
,
size
);
_CudaIndexToOnehot
(
index
,
onehot
,
size
,
confidence
,
lowconfidence
);
return
;
return
;
}
}
#endif
#endif
...
@@ -122,12 +125,13 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
...
@@ -122,12 +125,13 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
int
stride
=
size
;
int
stride
=
size
;
int
*
indexData
=
(
int
*
)
index
->
data
;
int
*
indexData
=
(
int
*
)
index
->
data
;
int
*
onehotData
=
(
int
*
)
onehot
->
data
;
DTYPE
*
onehotData
=
(
DTYPE
*
)
onehot
->
data
;
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
)
{
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
)
{
int
id
=
indexData
[
i
];
int
id
=
indexData
[
i
];
int
*
od
=
onehotData
+
i
*
stride
;
DTYPE
*
od
=
onehotData
+
i
*
stride
;
od
[
id
]
=
1
;
od
[
id
]
=
2
;
//onehotData[i * stride + id] = 1;
}
}
}
}
...
@@ -138,9 +142,10 @@ make a new tensor to keep the result and return it
...
@@ -138,9 +142,10 @@ make a new tensor to keep the result and return it
>> index - index tensor, which value is an integer num
>> index - index tensor, which value is an integer num
>> size - the last dimension size of the onehot tensor
>> size - the last dimension size of the onehot tensor
>> labelSmoothingP - the label smoothing value (the confidence is 1 - labelSmoothingP)
<< return - the onehot tensor
<< return - the onehot tensor
*/
*/
XTensor
IndexToOnehot
(
XTensor
&
index
,
int
size
)
XTensor
IndexToOnehot
(
XTensor
&
index
,
int
size
,
float
labelSmoothingP
)
{
{
CheckNTErrors
(
index
.
dataType
==
X_INT
,
"The onehot tensor must be in X_INT!"
)
CheckNTErrors
(
index
.
dataType
==
X_INT
,
"The onehot tensor must be in X_INT!"
)
...
@@ -151,9 +156,9 @@ XTensor IndexToOnehot(XTensor & index, int size)
...
@@ -151,9 +156,9 @@ XTensor IndexToOnehot(XTensor & index, int size)
int
*
dim
=
new
int
[
order
+
1
];
int
*
dim
=
new
int
[
order
+
1
];
memcpy
(
dim
,
index
.
dimSize
,
order
*
sizeof
(
int
));
memcpy
(
dim
,
index
.
dimSize
,
order
*
sizeof
(
int
));
dim
[
order
]
=
size
;
dim
[
order
]
=
size
;
InitTensor
(
&
onehot
,
index
.
order
+
1
,
dim
,
X_
IN
T
,
1.0
F
,
index
.
devID
,
index
.
mem
);
InitTensor
(
&
onehot
,
index
.
order
+
1
,
dim
,
X_
FLOA
T
,
1.0
F
,
index
.
devID
,
index
.
mem
);
_IndexToOnehot
(
&
index
,
&
onehot
,
size
);
_IndexToOnehot
(
&
index
,
&
onehot
,
size
,
labelSmoothingP
);
delete
[]
dim
;
delete
[]
dim
;
...
...
source/tensor/core/getandset/OnehotAndIndex.cu
查看文件 @
3bf085db
...
@@ -96,7 +96,7 @@ convert index tensor to onehot tensor (kernel version)
...
@@ -96,7 +96,7 @@ convert index tensor to onehot tensor (kernel version)
>> stride - stride of a data block
>> stride - stride of a data block
*/
*/
__global__
__global__
void KernelIndexToOnehot(
int * onehotData, int * indexData, int blockNum, int strid
e)
void KernelIndexToOnehot(
DTYPE * onehotData, int * indexData, int blockNum, int stride, float confidence, float lowconfidenc
e)
{
{
/* block id */
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
@@ -107,10 +107,17 @@ void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int st
...
@@ -107,10 +107,17 @@ void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int st
if (i >= blockNum || offset >= stride)
if (i >= blockNum || offset >= stride)
return;
return;
int
* od = onehotData + i * stride;
DTYPE
* od = onehotData + i * stride;
int id = indexData[i];
int id = indexData[i];
od[id] = 1;
//od[id] = 2.0;
//onehotData[i * stride + id] = 0.1;
if (offset == id)
od[offset] = confidence;
else{
od[offset] = lowconfidence;
}
}
}
/*
/*
...
@@ -120,7 +127,7 @@ convert index tensor to onehot tensor (cuda version)
...
@@ -120,7 +127,7 @@ convert index tensor to onehot tensor (cuda version)
>> onehot - onehot tensor, which value is 0 or 1
>> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor
>> size - the last dimension size of the onehot tensor
*/
*/
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size
, float confidence, float lowconfidence
)
{
{
int devID = onehot->devID;
int devID = onehot->devID;
...
@@ -138,10 +145,10 @@ void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
...
@@ -138,10 +145,10 @@ void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int * onehotData = (int
*)onehot->data;
DTYPE * onehotData = (DTYPE
*)onehot->data;
int * indexData = (int *)index->data;
int * indexData = (int *)index->data;
KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride);
KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride, confidence, lowconfidence);
BacktoCudaDev(devID, devIDBackup);
BacktoCudaDev(devID, devIDBackup);
}
}
...
...
source/tensor/core/getandset/OnehotAndIndex.cuh
查看文件 @
3bf085db
...
@@ -30,7 +30,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
...
@@ -30,7 +30,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size);
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size);
/* convert index tensor to onehot tensor (cuda version) */
/* convert index tensor to onehot tensor (cuda version) */
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size);
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence);
} // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/getandset/OnehotAndIndex.h
查看文件 @
3bf085db
...
@@ -34,11 +34,11 @@ make a new tensor to keep the result and return it */
...
@@ -34,11 +34,11 @@ make a new tensor to keep the result and return it */
XTensor OnehotToIndex(XTensor & onehot, int num);
XTensor OnehotToIndex(XTensor & onehot, int num);
/* convert index tensor to onehot tensor */
/* convert index tensor to onehot tensor */
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size);
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoothingP);
/* convert index tensor to onehot tensor (return an XTensor structure)
/* convert index tensor to onehot tensor (return an XTensor structure)
make a new tensor to keep the result and return it */
make a new tensor to keep the result and return it */
XTensor IndexToOnehot(XTensor & index, int num);
XTensor IndexToOnehot(XTensor & index, int num, float labelSmoothingP);
} // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论