Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
bfa6fc90
Commit
bfa6fc90
authored
Feb 10, 2020
by
huchi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
replace cache type with XTensor
parent
e925cfd9
隐藏空白字符变更
内嵌
并排
正在显示
14 个修改的文件
包含
161 行增加
和
247 行删除
+161
-247
source/Main.cpp
+8
-0
source/sample/transformer/T2TAttention.cpp
+54
-95
source/sample/transformer/T2TAttention.h
+20
-35
source/sample/transformer/T2TDecoder.cpp
+2
-2
source/sample/transformer/T2TEncoder.cpp
+1
-0
source/sample/transformer/T2TLayerNormal.cpp
+0
-3
source/sample/transformer/T2TModel.cpp
+13
-3
source/sample/transformer/T2TModel.h
+1
-1
source/sample/transformer/T2TOutput.cpp
+0
-11
source/sample/transformer/T2TPredictor.cpp
+15
-53
source/sample/transformer/T2TPredictor.h
+1
-8
source/sample/transformer/T2TSearch.cpp
+42
-27
source/sample/transformer/T2TTester.cpp
+3
-8
source/sample/transformer/Transformer.cpp
+1
-1
没有找到文件。
source/Main.cpp
查看文件 @
bfa6fc90
...
@@ -41,6 +41,14 @@ int main( int argc, const char ** argv )
...
@@ -41,6 +41,14 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(2708);
//_CrtSetBreakAlloc(2708);
TransformerMain
(
argc
-
1
,
argv
+
1
);
TransformerMain
(
argc
-
1
,
argv
+
1
);
/*XTensor x;
InitTensor2D(&x, 2, 2);
float d[]{ 1,2,3,4 };
x.SetData(d, 4);
XTensor y;
y = ReduceSum(x, 0);
y.Dump(stderr);*/
//_CrtDumpMemoryLeaks();
//_CrtDumpMemoryLeaks();
...
...
source/sample/transformer/T2TAttention.cpp
查看文件 @
bfa6fc90
...
@@ -28,8 +28,6 @@
...
@@ -28,8 +28,6 @@
namespace
transformer
namespace
transformer
{
{
enum
{
NONE
,
SELF
,
CONTEXT
};
/* constructor */
/* constructor */
T2TAttention
::
T2TAttention
()
T2TAttention
::
T2TAttention
()
{
{
...
@@ -86,88 +84,55 @@ void T2TAttention::InitModel(int argc, char** argv,
...
@@ -86,88 +84,55 @@ void T2TAttention::InitModel(int argc, char** argv,
/*
/*
make the network
make the network
>> k - keys. It might be of size B * L * H
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
where B = batch size, L = sequence length,
and H = vector size of each position
and H = vector size of each position
>> q - queries
>> q - queries
>> v - values
>> v - values
>> mask - as it is
>> mask - as it is
>> isTraining - indicates whether the model is used for training
>> isTraining - indicates whether the model is used for training
>> cache -
the
cache list
>> cache -
layer
cache list
>> cacheType -
type of the cache
>> cacheType -
which type that cache is
<< return - multi-attention result
<< return - multi-attention result
*/
*/
XTensor
T2TAttention
::
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
XTensor
T2TAttention
::
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
bool
isTraining
,
Cache
*
cache
,
int
cacheType
)
bool
isTraining
,
Cache
*
cache
,
int
cacheType
)
{
{
bool
is_encoder
=
(
!
cache
)
?
true
:
false
;
const
bool
isEnc
=
(
!
cache
)
?
true
:
false
;
int
k2Dim
[]{
k
.
GetDim
(
0
),
k
.
GetDim
(
1
),
wk
.
GetDim
(
1
)
};
int
v2Dim
[]{
v
.
GetDim
(
0
),
v
.
GetDim
(
1
),
wv
.
GetDim
(
1
)
};
XTensor
*
q2
=
NewTensor3DV2
(
q
.
GetDim
(
0
),
q
.
GetDim
(
1
),
wq
.
GetDim
(
1
),
X_FLOAT
,
q
.
devID
);
XTensor
*
k2
=
NULL
;
XTensor
*
v2
=
NULL
;
XTensor
*
kNewCache
=
NULL
;
XTensor
*
vNewCache
=
NULL
;
/* linear transformation before self-attention */
/* linear transformation before self-attention */
/* notice that all weights are transposed!!! */
XTensor
q2
,
k2
,
v2
;
q2
=
MatrixMul
(
q
,
X_NOTRANS
,
wq
,
X_TRANS
)
+
bq
;
_MatrixMul
(
&
q
,
X_NOTRANS
,
&
wq
,
X_TRANS
,
q2
);
_SumDim
(
q2
,
&
bq
,
2
);
if
(
!
cache
)
{
if
(
!
cache
)
{
k2
=
NewTensor3DV2
(
k
.
GetDim
(
0
),
k
.
GetDim
(
1
),
wk
.
GetDim
(
1
),
X_FLOAT
,
k
.
devID
);
/* self attention for encoder layers */
v2
=
NewTensor3DV2
(
v
.
GetDim
(
0
),
v
.
GetDim
(
1
),
wv
.
GetDim
(
1
),
X_FLOAT
,
v
.
devID
);
k2
=
MatrixMul
(
k
,
X_NOTRANS
,
wk
,
X_TRANS
)
+
bk
;
_MatrixMul
(
&
k
,
X_NOTRANS
,
&
wk
,
X_TRANS
,
k2
);
v2
=
MatrixMul
(
v
,
X_NOTRANS
,
wv
,
X_TRANS
)
+
bv
;
_SumDim
(
k2
,
&
bk
,
2
);
return
MakeRPRAttention
(
k2
,
q2
,
v2
,
mask
,
isTraining
,
isEnc
);
_MatrixMul
(
&
v
,
X_NOTRANS
,
&
wv
,
X_TRANS
,
v2
);
_SumDim
(
v2
,
&
bv
,
2
);
}
}
else
{
else
{
if
(
cacheType
==
SELF
)
{
if
(
cacheType
==
SELF_ATT
)
{
k2
=
NewTensor3DV2
(
q
.
GetDim
(
0
),
q
.
GetDim
(
1
),
wk
.
GetDim
(
1
),
X_FLOAT
,
q
.
devID
);
k2
=
MatrixMul
(
k
,
X_NOTRANS
,
wk
,
X_TRANS
)
+
bk
;
v2
=
NewTensor3DV2
(
q
.
GetDim
(
0
),
q
.
GetDim
(
1
),
wv
.
GetDim
(
1
),
X_FLOAT
,
q
.
devID
);
v2
=
MatrixMul
(
v
,
X_NOTRANS
,
wv
,
X_TRANS
)
+
bv
;
_MatrixMul
(
&
q
,
X_NOTRANS
,
&
wk
,
X_TRANS
,
k2
);
_SumDim
(
k2
,
&
bk
,
2
);
/* if hit, we only concat the cache with the new token */
_MatrixMul
(
&
q
,
X_NOTRANS
,
&
wv
,
X_TRANS
,
v2
);
if
(
!
cache
->
miss
)
{
_SumDim
(
v2
,
&
bv
,
2
);
k2
=
Concatenate
(
cache
->
key
,
k2
,
1
);
if
(
!
cache
->
IsEmpty
())
{
v2
=
Concatenate
(
cache
->
value
,
v2
,
1
);
XTensor
*
kOldCache
=
cache
->
GetK
();
XTensor
*
vOldCache
=
cache
->
GetV
();
kNewCache
=
NewTensor3DV2
(
kOldCache
->
GetDim
(
0
),
kOldCache
->
GetDim
(
1
)
+
k2
->
GetDim
(
1
),
kOldCache
->
GetDim
(
2
),
X_FLOAT
,
k2
->
devID
);
vNewCache
=
NewTensor3DV2
(
vOldCache
->
GetDim
(
0
),
vOldCache
->
GetDim
(
1
)
+
v2
->
GetDim
(
1
),
vOldCache
->
GetDim
(
2
),
X_FLOAT
,
v2
->
devID
);
_Concatenate
(
kOldCache
,
k2
,
kNewCache
,
1
);
_Concatenate
(
vOldCache
,
v2
,
vNewCache
,
1
);
DelTensor
(
k2
);
DelTensor
(
v2
);
k2
=
kNewCache
;
v2
=
vNewCache
;
}
}
cache
->
Update
(
k2
,
v2
);
cache
->
key
=
k2
;
cache
->
value
=
v2
;
cache
->
miss
=
false
;
return
MakeRPRAttention
(
cache
->
key
,
q2
,
cache
->
value
,
mask
,
isTraining
,
isEnc
);
}
}
else
if
(
cacheType
==
CONTEXT
)
{
else
if
(
cacheType
==
EN_DE_ATT
)
{
if
(
cache
->
IsEmpty
())
{
if
(
cache
->
miss
)
{
k2
=
NewTensor3DV2
(
k
.
GetDim
(
0
),
k
.
GetDim
(
1
),
wk
.
GetDim
(
1
),
X_FLOAT
,
k
.
devID
);
cache
->
key
=
MatrixMul
(
k
,
X_NOTRANS
,
wk
,
X_TRANS
)
+
bk
;
v2
=
NewTensor3DV2
(
v
.
GetDim
(
0
),
v
.
GetDim
(
1
),
wv
.
GetDim
(
1
),
X_FLOAT
,
v
.
devID
);
cache
->
value
=
MatrixMul
(
v
,
X_NOTRANS
,
wv
,
X_TRANS
)
+
bv
;
_MatrixMul
(
&
k
,
X_NOTRANS
,
&
wk
,
X_TRANS
,
k2
);
cache
->
miss
=
false
;
_SumDim
(
k2
,
&
bk
,
2
);
_MatrixMul
(
&
v
,
X_NOTRANS
,
&
wv
,
X_TRANS
,
v2
);
_SumDim
(
v2
,
&
bv
,
2
);
cache
->
Update
(
k2
,
v2
);
}
else
{
k2
=
cache
->
GetK
();
v2
=
cache
->
GetV
();
}
}
return
MakeAttention
(
cache
->
key
,
q2
,
cache
->
value
,
mask
,
isTraining
,
isEnc
);
}
}
else
{
CheckNTErrors
(
0
,
"invalid cache type"
);
CheckNTErrors
(
0
,
"invalid cache type"
);
}
}
}
if
(
cacheType
==
CONTEXT
)
return
MakeAttention
(
k2
,
q2
,
v2
,
mask
,
isTraining
,
is_encoder
);
return
MakeRPRAttention
(
k2
,
q2
,
v2
,
mask
,
isTraining
,
is_encoder
);
}
}
/*
/*
...
@@ -180,16 +145,16 @@ make the attention network given keys, queries and values (after linear transfor
...
@@ -180,16 +145,16 @@ make the attention network given keys, queries and values (after linear transfor
>> mask - as it is
>> mask - as it is
>> isTraining - indicates whether the model is used for training
>> isTraining - indicates whether the model is used for training
*/
*/
XTensor
T2TAttention
::
MakeAttention
(
XTensor
*
k
,
XTensor
*
q
,
XTensor
*
v
,
const
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
)
XTensor
T2TAttention
::
MakeAttention
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
)
{
{
XTensor
kheads
;
XTensor
kheads
;
XTensor
qheads
;
XTensor
qheads
;
XTensor
vheads
;
XTensor
vheads
;
/* multi head */
/* multi head */
kheads
=
Split
(
*
k
,
k
->
order
-
1
,
nhead
);
kheads
=
Split
(
k
,
k
.
order
-
1
,
nhead
);
qheads
=
Split
(
*
q
,
q
->
order
-
1
,
nhead
);
qheads
=
Split
(
q
,
q
.
order
-
1
,
nhead
);
vheads
=
Split
(
*
v
,
v
->
order
-
1
,
nhead
);
vheads
=
Split
(
v
,
v
.
order
-
1
,
nhead
);
XTensor
att
;
XTensor
att
;
XTensor
dot
;
XTensor
dot
;
...
@@ -198,16 +163,16 @@ XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XT
...
@@ -198,16 +163,16 @@ XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XT
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
dot
=
BMMul
(
qheads
,
X_NOTRANS
,
kheads
,
X_TRANS
);
dot
=
BMMul
(
qheads
,
X_NOTRANS
,
kheads
,
X_TRANS
);
if
(
isMasked
&&
mask
)
{
/*
if (isMasked && mask) {
_SumMe(&dot, mask);
_SumMe(&dot, mask);
}
}
*/
dot
=
Linear
(
dot
,
1.0
F
/
(
float
)
sqrt
((
float
)
dk
/
nhead
));
dot
=
Linear
(
dot
,
1.0
F
/
(
float
)
sqrt
((
float
)
dk
/
nhead
));
scalar
=
Softmax
(
dot
,
-
1
);
scalar
=
Softmax
(
dot
,
-
1
);
if
(
isTraining
&&
dropoutP
>
0
)
/*
if(isTraining && dropoutP > 0)
scalar
=
Dropout
(
scalar
,
dropoutP
);
scalar = Dropout(scalar, dropoutP);
*/
att
=
BMMul
(
scalar
,
vheads
);
att
=
BMMul
(
scalar
,
vheads
);
...
@@ -225,32 +190,32 @@ make the attention network by incorporating the relative position representation
...
@@ -225,32 +190,32 @@ make the attention network by incorporating the relative position representation
>> mask - as it is
>> mask - as it is
>> isTraining - indicates whether the model is used for training
>> isTraining - indicates whether the model is used for training
*/
*/
XTensor
T2TAttention
::
MakeRPRAttention
(
XTensor
*
k
,
XTensor
*
q
,
XTensor
*
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
)
XTensor
T2TAttention
::
MakeRPRAttention
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
)
{
{
XTensor
kheads
;
XTensor
kheads
;
XTensor
qheads
;
XTensor
qheads
;
XTensor
vheads
;
XTensor
vheads
;
const
int
batch_size
=
q
->
GetDim
(
0
);
const
int
batch_size
=
q
.
GetDim
(
0
);
const
int
len_q
=
q
->
GetDim
(
1
);
const
int
len_q
=
q
.
GetDim
(
1
);
const
int
len_kv
=
k
->
GetDim
(
1
);
const
int
len_kv
=
k
.
GetDim
(
1
);
/* multi head */
/* multi head */
kheads
=
Split
(
*
k
,
k
->
order
-
1
,
nhead
);
kheads
=
Split
(
k
,
k
.
order
-
1
,
nhead
);
qheads
=
Split
(
*
q
,
q
->
order
-
1
,
nhead
);
qheads
=
Split
(
q
,
q
.
order
-
1
,
nhead
);
vheads
=
Split
(
*
v
,
v
->
order
-
1
,
nhead
);
vheads
=
Split
(
v
,
v
.
order
-
1
,
nhead
);
XTensor
att
;
XTensor
att
;
XTensor
dot
;
XTensor
dot
;
XTensor
scalar
;
XTensor
scalar
;
XTensor
emb_matrix
,
relative_key
;
XTensor
emb_matrix
,
relative_key
;
InitTensor2DV2
(
&
emb_matrix
,
len_q
,
len_kv
,
X_INT
,
q
->
devID
);
InitTensor2DV2
(
&
emb_matrix
,
len_q
,
len_kv
,
X_INT
,
q
.
devID
);
InitTensor3DV2
(
&
relative_key
,
len_q
,
len_kv
,
kheads
.
GetDim
(
-
1
),
X_FLOAT
,
q
->
devID
);
InitTensor3DV2
(
&
relative_key
,
len_q
,
len_kv
,
kheads
.
GetDim
(
-
1
),
X_FLOAT
,
q
.
devID
);
InitTensor4DV2
(
&
dot
,
nhead
,
batch_size
,
len_q
,
len_kv
,
X_FLOAT
,
q
->
devID
);
InitTensor4DV2
(
&
dot
,
nhead
,
batch_size
,
len_q
,
len_kv
,
X_FLOAT
,
q
.
devID
);
/* generate the relative emb index (L_q, L_kv) */
/* generate the relative emb index (L_q, L_kv) */
GetRPEmbedding
(
&
emb_matrix
,
len_q
,
len_kv
,
max_relative_position
,
q
->
devID
,
is_encoder
);
GetRPEmbedding
(
&
emb_matrix
,
len_q
,
len_kv
,
max_relative_position
,
q
.
devID
,
is_encoder
);
/* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */
/* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */
...
@@ -261,8 +226,8 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
...
@@ -261,8 +226,8 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
RPDotProduct
(
&
qheads
,
&
kheads
,
&
relative_key
,
&
dot
,
true
);
RPDotProduct
(
&
qheads
,
&
kheads
,
&
relative_key
,
&
dot
,
true
);
if
(
isMasked
&&
mask
)
/*
if (isMasked && mask)
_SumMe
(
&
dot
,
mask
);
_SumMe(&dot, mask);
*/
/* scale the dot result */
/* scale the dot result */
//dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
//dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
...
@@ -270,18 +235,12 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
...
@@ -270,18 +235,12 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
/* softmax */
/* softmax */
scalar
=
Softmax
(
dot
,
-
1
);
scalar
=
Softmax
(
dot
,
-
1
);
if
(
isTraining
&&
dropoutP
>
0
)
/*
if (isTraining && dropoutP > 0)
scalar
=
Dropout
(
scalar
,
dropoutP
);
scalar = Dropout(scalar, dropoutP);
*/
/* generate the relative attention output (K, B, L_q, H/K) */
/* generate the relative attention output (K, B, L_q, H/K) */
att
=
BMMul
(
scalar
,
vheads
);
att
=
BMMul
(
scalar
,
vheads
);
if
(
is_encoder
)
{
DelTensor
(
k
);
DelTensor
(
q
);
DelTensor
(
v
);
}
/* concatenate the heads */
/* concatenate the heads */
return
MulAndShift
(
Merge
(
att
,
att
.
order
-
1
),
X_NOTRANS
,
wa
,
X_TRANS
,
ba
);
return
MulAndShift
(
Merge
(
att
,
att
.
order
-
1
),
X_NOTRANS
,
wa
,
X_TRANS
,
ba
);
}
}
...
...
source/sample/transformer/T2TAttention.h
查看文件 @
bfa6fc90
...
@@ -28,46 +28,31 @@ using namespace nts;
...
@@ -28,46 +28,31 @@ using namespace nts;
namespace
transformer
namespace
transformer
{
{
/* attention type */
enum
{
NONE
,
SELF_ATT
,
EN_DE_ATT
};
/* layer cache for key and value */
/* layer cache for keys and values */
class
Cache
{
class
Cache
{
public
:
public
:
/* cache for keys */
XTensor
key
;
/* cache for key */
/* cache for values */
XTensor
*
k
{
NULL
};
XTensor
value
;
/* cache for value */
XTensor
*
v
{
NULL
};
public
:
public
:
bool
IsEmpty
()
{
bool
miss
;
return
(
k
==
NULL
)
&&
(
v
==
NULL
);
}
void
Clear
()
{
if
(
k
&&
v
&&
k
->
id
>
0
&&
v
->
id
>
0
)
{
DelTensor
(
k
);
DelTensor
(
v
);
}
k
=
NULL
;
v
=
NULL
;
}
void
Update
(
XTensor
*
newK
,
XTensor
*
newV
)
{
if
(
!
newK
||
(
k
==
newK
)
||
!
newV
||
(
v
==
newV
))
return
;
Clear
();
k
=
newK
;
v
=
newV
;
}
XTensor
*
GetK
()
{
Cache
()
{
return
k
;
miss
=
true
;
}
}
XTensor
*
GetV
()
{
void
Update
(
XTensor
&&
k
,
XTensor
&&
v
)
{
return
v
;
key
=
k
;
value
=
v
;
miss
=
false
;
}
}
};
};
...
@@ -153,14 +138,14 @@ public:
...
@@ -153,14 +138,14 @@ public:
int
myDevID
=
-
1
);
int
myDevID
=
-
1
);
/* make the network */
/* make the network */
XTensor
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
XTensor
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
bool
isTraining
,
Cache
*
cache
,
int
cacheType
);
XTensor
*
mask
,
bool
isTraining
,
Cache
*
cache
,
int
cacheType
);
/* make the attention network given keys, queries and values (after linear transformation) */
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor
MakeAttention
(
XTensor
*
k
,
XTensor
*
q
,
XTensor
*
v
,
const
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
);
XTensor
MakeAttention
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
);
/* make the attention network given keys, queries and values (after linear transformation) */
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor
MakeRPRAttention
(
XTensor
*
k
,
XTensor
*
q
,
XTensor
*
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
);
XTensor
MakeRPRAttention
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
*
mask
,
bool
isTraining
,
bool
is_encoder
);
void
GetRPEmbedding
(
XTensor
*
emb_matrix
,
const
int
len_q
,
const
int
len_kv
,
const
int
max_relative_length
,
const
int
device_id
,
const
bool
is_encoder
);
void
GetRPEmbedding
(
XTensor
*
emb_matrix
,
const
int
len_q
,
const
int
len_kv
,
const
int
max_relative_length
,
const
int
device_id
,
const
bool
is_encoder
);
...
...
source/sample/transformer/T2TDecoder.cpp
查看文件 @
bfa6fc90
...
@@ -136,7 +136,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
...
@@ -136,7 +136,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
/******************/
/******************/
/* self attention */
/* self attention */
att
=
attentions
[
i
].
Make
(
inputNorm
,
inputNorm
,
inputNorm
,
NULL
,
isTraining
,
&
selfCache
[
i
],
1
);
att
=
attentions
[
i
].
Make
(
inputNorm
,
inputNorm
,
inputNorm
,
NULL
,
isTraining
,
&
selfCache
[
i
],
SELF_ATT
);
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
...
@@ -151,7 +151,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
...
@@ -151,7 +151,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
//attNorm.Dump(stderr, "attNorm", 10);
//attNorm.Dump(stderr, "attNorm", 10);
/* encoder-decoder attention */
/* encoder-decoder attention */
ende
=
attentionsEnde
[
i
].
Make
(
outputEnc
,
attNorm
,
outputEnc
,
&
maskEncDec
,
isTraining
,
&
contextCache
[
i
],
2
);
ende
=
attentionsEnde
[
i
].
Make
(
outputEnc
,
attNorm
,
outputEnc
,
&
maskEncDec
,
isTraining
,
&
contextCache
[
i
],
EN_DE_ATT
);
//ende.Dump(stderr, "ende atten", 10);
//ende.Dump(stderr, "ende atten", 10);
...
...
source/sample/transformer/T2TEncoder.cpp
查看文件 @
bfa6fc90
...
@@ -123,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo
...
@@ -123,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo
/* fnn */
/* fnn */
x
=
fnns
[
i
].
Make
(
res
,
isTraining
);
x
=
fnns
[
i
].
Make
(
res
,
isTraining
);
}
}
x
=
encodeLayerNorm
->
Make
(
x
);
x
=
encodeLayerNorm
->
Make
(
x
);
...
...
source/sample/transformer/T2TLayerNormal.cpp
查看文件 @
bfa6fc90
...
@@ -55,9 +55,6 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID)
...
@@ -55,9 +55,6 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID)
InitTensor1D
(
&
w
,
d
,
X_FLOAT
,
devID
);
InitTensor1D
(
&
w
,
d
,
X_FLOAT
,
devID
);
InitTensor1D
(
&
b
,
d
,
X_FLOAT
,
devID
);
InitTensor1D
(
&
b
,
d
,
X_FLOAT
,
devID
);
w
.
SetDataRand
(
1.0
F
,
1.0
F
);
b
.
SetZeroAll
();
}
}
/*
/*
...
...
source/sample/transformer/T2TModel.cpp
查看文件 @
bfa6fc90
...
@@ -491,11 +491,10 @@ void T2TModel::Read(const char * fn)
...
@@ -491,11 +491,10 @@ void T2TModel::Read(const char * fn)
GetParams
(
params
);
GetParams
(
params
);
size_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
params
.
count
;
i
++
){
for
(
int
i
=
0
;
i
<
params
.
count
;
i
++
){
XTensor
*
p
=
(
XTensor
*
)
params
.
Get
(
i
);
XTensor
*
p
=
(
XTensor
*
)
params
.
Get
(
i
);
p
->
BinaryRead
(
file
,
offset
);
FastRead
(
p
,
file
);
offset
+=
p
->
unitNum
;
// p->Read(file, "")
;
}
}
fclose
(
file
);
fclose
(
file
);
...
@@ -505,4 +504,14 @@ void T2TModel::Read(const char * fn)
...
@@ -505,4 +504,14 @@ void T2TModel::Read(const char * fn)
XPRINT1
(
0
,
stderr
,
"[INFO] model loaded (took %.1fs)
\n
"
,
elapsed
);
XPRINT1
(
0
,
stderr
,
"[INFO] model loaded (took %.1fs)
\n
"
,
elapsed
);
}
}
void
FastRead
(
XTensor
*
x
,
FILE
*
f
)
{
float
*
dataBuf
=
new
float
[
x
->
unitNum
];
fread
(
dataBuf
,
sizeof
(
char
),
sizeof
(
float
)
*
x
->
unitNum
,
f
);
x
->
SetData
(
dataBuf
,
x
->
unitNum
);
delete
[]
dataBuf
;
}
}
}
\ No newline at end of file
source/sample/transformer/T2TModel.h
查看文件 @
bfa6fc90
...
@@ -103,7 +103,7 @@ public:
...
@@ -103,7 +103,7 @@ public:
/* read the parameters */
/* read the parameters */
void
Read
(
const
char
*
fn
);
void
Read
(
const
char
*
fn
);
};
};
void
FastRead
(
XTensor
*
x
,
FILE
*
f
);
}
}
#endif
#endif
source/sample/transformer/T2TOutput.cpp
查看文件 @
bfa6fc90
...
@@ -61,18 +61,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
...
@@ -61,18 +61,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
InitTensor2D
(
&
w
,
hSize
,
vSize
,
X_FLOAT
,
devID
);
InitTensor2D
(
&
w
,
hSize
,
vSize
,
X_FLOAT
,
devID
);
}
}
/*
make the network
y = softmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor
T2TOutput
::
Make
(
XTensor
&
input
)
{
XTensor
&
x
=
input
;
return
Softmax
(
MMul
(
x
,
X_NOTRANS
,
w
,
X_TRANS
),
-
1
);
}
/*
/*
make the network (redefined output tensor)
make the network (redefined output tensor)
...
...
source/sample/transformer/T2TPredictor.cpp
查看文件 @
bfa6fc90
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "T2TPredictor.h"
#include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/core/CHeader.h"
#include <iostream>
using
namespace
nts
;
using
namespace
nts
;
...
@@ -91,15 +92,6 @@ create an initial state
...
@@ -91,15 +92,6 @@ create an initial state
*/
*/
void
T2TPredictor
::
Create
(
T2TModel
*
model
,
XTensor
*
top
,
const
XTensor
*
input
,
int
beamSize
,
T2TStateBundle
*
state
)
void
T2TPredictor
::
Create
(
T2TModel
*
model
,
XTensor
*
top
,
const
XTensor
*
input
,
int
beamSize
,
T2TStateBundle
*
state
)
{
{
state
->
layersEnc
.
Clear
();
state
->
layersDec
.
Clear
();
XTensor
*
encoding
=
XLink
::
SearchNode
(
top
,
ENCODING_NAME
);
CheckNTErrors
(
encoding
!=
NULL
,
"No encoding layers found!"
);
state
->
layersEnc
.
Add
(
encoding
);
state
->
layersDec
.
Add
(
NULL
);
int
dims
[
MAX_TENSOR_DIM_NUM
];
int
dims
[
MAX_TENSOR_DIM_NUM
];
for
(
int
i
=
0
;
i
<
input
->
order
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
input
->
order
-
1
;
i
++
)
dims
[
i
]
=
input
->
GetDim
(
i
);
dims
[
i
]
=
input
->
GetDim
(
i
);
...
@@ -109,10 +101,18 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
...
@@ -109,10 +101,18 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
InitTensor
(
&
state
->
nstep
,
input
->
order
,
dims
,
X_FLOAT
,
input
->
devID
);
InitTensor
(
&
state
->
nstep
,
input
->
order
,
dims
,
X_FLOAT
,
input
->
devID
);
InitTensor
(
&
state
->
endMark
,
input
->
order
,
dims
,
X_INT
,
input
->
devID
);
InitTensor
(
&
state
->
endMark
,
input
->
order
,
dims
,
X_INT
,
input
->
devID
);
state
->
probPath
.
SetZeroAll
();
float
*
data
=
new
float
[
state
->
probPath
.
unitNum
];
for
(
int
i
=
0
;
i
<
state
->
probPath
.
unitNum
;
++
i
)
{
data
[
i
]
=
-
1e20
F
;
if
(
i
%
beamSize
==
0
)
data
[
i
]
=
0
;
}
state
->
probPath
.
SetData
(
data
,
state
->
probPath
.
unitNum
);
state
->
nstep
.
SetZeroAll
();
state
->
nstep
.
SetZeroAll
();
state
->
endMark
.
SetZeroAll
();
state
->
endMark
.
SetZeroAll
();
delete
[]
data
;
state
->
stateNum
=
0
;
state
->
stateNum
=
0
;
}
}
...
@@ -145,20 +145,13 @@ predict the next state
...
@@ -145,20 +145,13 @@ predict the next state
>> encoding - encoder output
>> encoding - encoder output
>> inputEnc - input of the encoder
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder
>> paddingEnc - padding of the encoder
>>> isStart - is the start or not
*/
*/
void
T2TPredictor
::
Predict
(
T2TStateBundle
*
next
,
XTensor
*
encoding
,
void
T2TPredictor
::
Predict
(
T2TStateBundle
*
next
,
XTensor
*
encoding
,
XTensor
*
inputEnc
,
XTensor
*
paddingEnc
)
XTensor
*
inputEnc
,
XTensor
*
paddingEnc
,
bool
isStart
)
{
{
int
dims
[
MAX_TENSOR_DIM_NUM
];
int
dims
[
MAX_TENSOR_DIM_NUM
];
next
->
layersEnc
.
Clear
();
next
->
layersDec
.
Clear
();
AttDecoder
&
decoder
=
*
m
->
decoder
;
/* word indices of previous positions */
XTensor
*
inputLast
=
(
XTensor
*
)
s
->
layersDec
.
GetItem
(
0
);
/* word indices of positions up to next state */
/* word indices of positions up to next state */
XTensor
inputDec
;
XTensor
inputDec
;
...
@@ -171,10 +164,10 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
...
@@ -171,10 +164,10 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
dims
[
inputEnc
->
order
-
1
]
=
1
;
dims
[
inputEnc
->
order
-
1
]
=
1
;
InitTensor
(
&
first
,
inputEnc
->
order
,
dims
,
X_INT
,
inputEnc
->
devID
);
InitTensor
(
&
first
,
inputEnc
->
order
,
dims
,
X_INT
,
inputEnc
->
devID
);
_SetDataFixedInt
(
&
first
,
startSymbol
);
SetDataFixedInt
(
first
,
startSymbol
);
/* add a new word into the input sequence of the decoder side */
/* add a new word into the input sequence of the decoder side */
if
(
i
nputLast
==
NULL
)
{
if
(
i
sStart
)
{
inputDec
=
Identity
(
first
);
inputDec
=
Identity
(
first
);
}
}
else
{
else
{
...
@@ -186,7 +179,6 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
...
@@ -186,7 +179,6 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
/* prediction probabilities */
/* prediction probabilities */
XTensor
&
output
=
next
->
prob
;
XTensor
&
output
=
next
->
prob
;
XTensor
decoding
;
XTensor
decoding
;
XTensor
decodingStep
;
for
(
int
i
=
0
;
i
<
inputDec
.
order
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
inputDec
.
order
-
1
;
i
++
)
dims
[
i
]
=
inputDec
.
GetDim
(
i
);
dims
[
i
]
=
inputDec
.
GetDim
(
i
);
...
@@ -203,38 +195,12 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
...
@@ -203,38 +195,12 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
m
->
MakeMTMaskDec
(
*
inputEnc
,
inputDec
,
*
paddingEnc
,
paddingDec
,
maskDec
,
maskEncDec
,
0
);
m
->
MakeMTMaskDec
(
*
inputEnc
,
inputDec
,
*
paddingEnc
,
paddingDec
,
maskDec
,
maskEncDec
,
0
);
/* make the decoding network */
/* make the decoding network */
decoding
=
decoder
.
Make
(
inputDec
,
*
encoding
,
&
maskDec
,
maskEncDec
,
false
);
decoding
=
m
->
decoder
->
Make
(
inputDec
,
*
encoding
,
NULL
,
maskEncDec
,
false
);
XTensor
selectSrc
;
XTensor
selectTgt
;
CheckNTErrors
(
decoding
.
order
>=
2
,
"The tensor must be of order 2 or larger!"
);
CheckNTErrors
(
decoding
.
order
>=
2
,
"The tensor must be of order 2 or larger!"
);
int
stride
=
decoding
.
GetDim
(
decoding
.
order
-
2
);
//InitTensor1DV2(&selectSrc, 1, X_INT);
//InitTensor1DV2(&selectTgt, 1, X_INT);
//selectSrc.SetInt(stride - 1, 0);
//selectTgt.SetInt(0, 0);
//XTensor srcGPU;
//InitTensor1DV2(&srcGPU, 1, X_INT, decoding.devID);
//_CopyValues(&selectSrc, &srcGPU);
//XTensor tgtGPU;
//InitTensor1DV2(&tgtGPU, 1, X_INT, decoding.devID);
//_CopyValues(&selectTgt, &tgtGPU);
///* the decoder output of the last position */
//decodingStep = CopyIndexed(decoding, decoding.order - 2, srcGPU, tgtGPU);
/* generate the output probabilities */
/* generate the output probabilities */
m
->
outputLayer
->
Make
(
decoding
,
output
);
m
->
outputLayer
->
Make
(
decoding
,
output
);
next
->
layersEnc
.
AddList
(
&
s
->
layersEnc
);
next
->
layersDec
.
Add
(
&
inputDec
);
next
->
layersDec
.
Add
(
&
output
);
}
}
/*
/*
...
@@ -288,14 +254,10 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state)
...
@@ -288,14 +254,10 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state)
XTensor
lastPred
;
XTensor
lastPred
;
InitTensor2D
(
&
lastPred
,
state
->
stateNum
,
1
,
X_INT
);
InitTensor2D
(
&
lastPred
,
state
->
stateNum
,
1
,
X_INT
);
lastPred
.
SetZeroAll
();
for
(
int
i
=
0
;
i
<
state
->
stateNum
;
i
++
)
{
for
(
int
i
=
0
;
i
<
state
->
stateNum
;
i
++
)
{
T2TState
*
cur
=
state
->
states
+
i
;
T2TState
*
cur
=
state
->
states
+
i
;
while
(
cur
->
last
!=
NULL
)
cur
=
cur
->
last
;
lastPred
.
Set2DInt
(
cur
->
prediction
,
i
,
0
);
lastPred
.
Set2DInt
(
cur
->
prediction
,
i
,
0
);
}
}
...
...
source/sample/transformer/T2TPredictor.h
查看文件 @
bfa6fc90
...
@@ -94,13 +94,6 @@ public:
...
@@ -94,13 +94,6 @@ public:
/* step number of each hypothesis */
/* step number of each hypothesis */
XTensor
nstep
;
XTensor
nstep
;
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
TensorList
layersEnc
;
/* layers on the decoder side */
TensorList
layersDec
;
/* list of states */
/* list of states */
T2TState
*
states
;
T2TState
*
states
;
...
@@ -155,7 +148,7 @@ public:
...
@@ -155,7 +148,7 @@ public:
void
Read
(
T2TModel
*
model
,
T2TStateBundle
*
state
);
void
Read
(
T2TModel
*
model
,
T2TStateBundle
*
state
);
/* predict the next state */
/* predict the next state */
void
Predict
(
T2TStateBundle
*
next
,
XTensor
*
encoding
,
XTensor
*
inputEnc
,
XTensor
*
paddingEnc
);
void
Predict
(
T2TStateBundle
*
next
,
XTensor
*
encoding
,
XTensor
*
inputEnc
,
XTensor
*
paddingEnc
,
bool
isStart
);
/* generate paths up to the states of the current step */
/* generate paths up to the states of the current step */
XTensor
GeneratePaths
(
T2TStateBundle
*
state
);
XTensor
GeneratePaths
(
T2TStateBundle
*
state
);
...
...
source/sample/transformer/T2TSearch.cpp
查看文件 @
bfa6fc90
...
@@ -59,9 +59,9 @@ void T2TSearch::Init(int argc, char ** argv)
...
@@ -59,9 +59,9 @@ void T2TSearch::Init(int argc, char ** argv)
{
{
LoadParamInt
(
argc
,
argv
,
"beamsize"
,
&
beamSize
,
1
);
LoadParamInt
(
argc
,
argv
,
"beamsize"
,
&
beamSize
,
1
);
LoadParamInt
(
argc
,
argv
,
"batchsize"
,
&
batchSize
,
1
);
LoadParamInt
(
argc
,
argv
,
"batchsize"
,
&
batchSize
,
1
);
LoadParamFloat
(
argc
,
argv
,
"lenalpha"
,
&
alpha
,
0.2
F
);
LoadParamFloat
(
argc
,
argv
,
"lenalpha"
,
&
alpha
,
1.0
F
);
LoadParamInt
(
argc
,
argv
,
"endid"
,
endSymbols
,
-
1
);
LoadParamInt
(
argc
,
argv
,
"endid"
,
endSymbols
,
2
);
LoadParamInt
(
argc
,
argv
,
"startid"
,
&
startSymbol
,
-
1
);
LoadParamInt
(
argc
,
argv
,
"startid"
,
&
startSymbol
,
2
);
if
(
endSymbols
[
0
]
>=
0
)
if
(
endSymbols
[
0
]
>=
0
)
endSymbolNum
=
1
;
endSymbolNum
=
1
;
...
@@ -90,13 +90,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
...
@@ -90,13 +90,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* encoder mask */
/* encoder mask */
model
->
MakeMTMaskEnc
(
*
input
,
*
padding
,
maskEnc
);
model
->
MakeMTMaskEnc
(
*
input
,
*
padding
,
maskEnc
);
//input->Dump(stderr, "input:");
//maskEnc.Dump(stderr, "maskenc:");
/* make the encoding network */
/* make the encoding network */
encoding
=
model
->
MakeEncoder
(
*
input
,
&
maskEnc
,
false
);
encoding
=
model
->
MakeEncoder
(
*
input
,
&
maskEnc
,
false
);
encoding
.
SetName
(
ENCODING_NAME
);
encodingBeam
=
Unsqueeze
(
encoding
,
encoding
.
order
-
2
,
beamSize
);
encodingBeam
=
Unsqueeze
(
encoding
,
encoding
.
order
-
2
,
beamSize
);
inputBeam
=
Unsqueeze
(
*
input
,
input
->
order
-
1
,
beamSize
);
inputBeam
=
Unsqueeze
(
*
input
,
input
->
order
-
1
,
beamSize
);
...
@@ -110,9 +106,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
...
@@ -110,9 +106,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
maxLength
=
input
->
GetDim
(
-
1
)
*
2
;
maxLength
=
input
->
GetDim
(
-
1
)
*
2
;
CheckNTErrors
(
maxLength
>
0
,
"no max length specified!"
);
CheckNTErrors
(
maxLength
>
0
,
"no max length specified!"
);
T2TStateBundle
*
states
=
new
T2TStateBundle
[
maxLength
+
1
];
T2TStateBundle
*
states
=
new
T2TStateBundle
[
maxLength
+
1
];
T2TStateBundle
*
first
=
states
;
T2TStateBundle
*
first
=
states
;
T2TStateBundle
*
cur
;
T2TStateBundle
*
next
;
/* create the first state */
/* create the first state */
predictor
.
Create
(
model
,
&
encodingBeam
,
input
,
beamSize
,
first
);
predictor
.
Create
(
model
,
&
encodingBeam
,
input
,
beamSize
,
first
);
predictor
.
SetStartSymbol
(
startSymbol
);
predictor
.
SetStartSymbol
(
startSymbol
);
...
@@ -121,14 +119,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
...
@@ -121,14 +119,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* generate the sequence from left to right */
/* generate the sequence from left to right */
for
(
int
i
=
0
;
i
<
maxLength
;
i
++
){
for
(
int
i
=
0
;
i
<
maxLength
;
i
++
){
T2TStateBundle
*
cur
=
states
+
i
;
cur
=
states
+
i
;
T2TStateBundle
*
next
=
states
+
i
+
1
;
next
=
states
+
i
+
1
;
/* read the current state */
/* read the current state */
predictor
.
Read
(
model
,
cur
);
predictor
.
Read
(
model
,
cur
);
/* predict the next state */
/* predict the next state */
predictor
.
Predict
(
next
,
&
encodingBeam
,
&
inputBeam
,
&
paddingBeam
);
predictor
.
Predict
(
next
,
&
encodingBeam
,
&
inputBeam
,
&
paddingBeam
,
i
==
0
);
/* compute the model score (given the prediction probability) */
/* compute the model score (given the prediction probability) */
Score
(
cur
,
next
);
Score
(
cur
,
next
);
...
@@ -144,7 +142,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
...
@@ -144,7 +142,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
}
}
/* fill the heap with imcomplete hypotheses if neccesary */
/* fill the heap with imcomplete hypotheses if neccesary */
FillHeap
(
&
states
[
maxLength
]
);
FillHeap
(
next
);
Dump
(
output
);
Dump
(
output
);
...
@@ -210,24 +208,25 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
...
@@ -210,24 +208,25 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
_ScaleAndShift
(
&
lenPrev
,
&
len
,
1.0
F
,
1.0
F
);
_ScaleAndShift
(
&
lenPrev
,
&
len
,
1.0
F
,
1.0
F
);
/* the GNMT-like length penalty */
/* the GNMT-like length penalty */
lp
=
T2TLengthPenalizer
::
GNMT
(
len
,
alpha
);
//
lp = T2TLengthPenalizer::GNMT(len, alpha);
lp
.
Reshape
(
lp
.
unitNum
);
//
lp.Reshape(lp.unitNum);
/* score = log-prob/lp */
/* score = log-prob/lp */
_DivDim
(
&
probPath
,
&
lp
,
&
score
,
0
);
//
_DivDim(&probPath, &lp, &score, 0);
if
(
prev
->
isStart
)
{
if
(
prev
->
isStart
)
{
XTensor
firstMask
=
MakeFirstMask
(
beam
);
XTensor
firstMask
=
MakeFirstMask
(
beam
);
firstMask
.
Reshape
(
firstMask
.
unitNum
);
firstMask
.
Reshape
(
firstMask
.
unitNum
);
/* mask the hypotheses in the beam ex
pec
t the first one */
/* mask the hypotheses in the beam ex
cep
t the first one */
_SumDim
(
&
score
,
&
firstMask
,
&
score
,
0
);
_SumDim
(
&
score
,
&
firstMask
,
&
score
,
0
);
}
}
InitTensor
(
&
mask
,
InitTensor
(
&
mask
,
prev
->
endMark
.
order
,
prev
->
endMark
.
dimSize
,
X_FLOAT
,
prev
->
endMark
.
order
,
prev
->
endMark
.
dimSize
,
X_FLOAT
,
prev
->
endMark
.
devID
);
prev
->
endMark
.
devID
);
mask
.
SetZeroAll
();
_SetDataFixedCond
(
&
mask
,
&
prev
->
endMark
,
-
1e9
F
);
_SetDataFixedCond
(
&
mask
,
&
prev
->
endMark
,
-
1e9
F
);
mask
.
Reshape
(
mask
.
unitNum
);
mask
.
Reshape
(
mask
.
unitNum
);
...
@@ -279,17 +278,26 @@ void T2TSearch::Generate(T2TStateBundle * beam)
...
@@ -279,17 +278,26 @@ void T2TSearch::Generate(T2TStateBundle * beam)
dimsTopK
[
order
-
3
]
=
dimsBeam
[
order
-
3
];
dimsTopK
[
order
-
3
]
=
dimsBeam
[
order
-
3
];
dimsTopK
[
order
-
1
]
=
beamSize
;
dimsTopK
[
order
-
1
]
=
beamSize
;
InitTensor
(
&
scoreTopK
,
order
,
dimsTopK
,
score
.
dataType
,
InitTensor
(
&
scoreTopK
,
order
,
dimsTopK
,
score
.
dataType
,
score
.
devID
);
score
.
devID
);
InitTensor
(
&
index
,
order
,
dimsTopK
,
X_INT
,
score
.
devID
);
InitTensor
(
&
index
,
order
,
dimsTopK
,
X_INT
,
score
.
devID
);
InitTensor
(
&
preID
,
order
,
dimsTopK
,
X_INT
,
-
1
);
InitTensor
(
&
preID
,
order
,
dimsTopK
,
X_INT
,
-
1
);
/* mask the first and the padding id */
int
dimMask
[]{
score
.
GetDim
(
-
1
)
};
XTensor
mask
;
InitTensor
(
&
mask
,
1
,
dimMask
,
X_FLOAT
,
-
1
);
mask
.
SetZeroAll
();
mask
.
Set1D
(
-
1e20
F
,
0
);
mask
.
Set1D
(
-
1e20
F
,
1
);
mask
.
SetDevice
(
score
.
devID
,
score
.
mem
);
//_SumDim(&score, &mask, 2);
score
.
Reshape
(
order
,
dimsBeam
);
score
.
Reshape
(
order
,
dimsBeam
);
/* keep the most promissing candidates in the beam */
/* keep the most promissing candidates in the beam */
/* TODO: check this line */
TopK
(
score
,
scoreTopK
,
index
,
-
1
,
beamSize
);
TopK
(
score
,
scoreTopK
,
index
,
-
1
,
beamSize
);
CopyValues
(
index
,
preID
);
CopyValues
(
index
,
preID
);
/* "preID" represents the id (or the offset) of the previous state used to make the current
/* "preID" represents the id (or the offset) of the previous state used to make the current
...
@@ -313,11 +321,14 @@ void T2TSearch::Generate(T2TStateBundle * beam)
...
@@ -313,11 +321,14 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor
indexGPU
;
XTensor
indexGPU
;
indexGPU
=
CopyValues
(
index
);
indexGPU
=
CopyValues
(
index
);
//InitTensorV2(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
//CopyValues(index, indexCPU);
for
(
int
i
=
0
;
i
<
indexGPU
.
unitNum
;
i
++
)
for
(
int
i
=
0
;
i
<
indexGPU
.
unitNum
;
i
+=
beamSize
)
{
indexGPU
.
SetInt
(
i
*
stride
+
indexGPU
.
GetInt
(
i
),
i
);
for
(
int
j
=
0
;
j
<
beamSize
;
j
++
)
indexGPU
.
SetInt
(
i
*
stride
+
indexGPU
.
GetInt
(
i
+
j
),
i
+
j
);
}
/*for (int i = 0; i < indexGPU.unitNum; i++) {
indexGPU.SetInt(i + indexGPU.GetInt(i), i);
}*/
CheckNTErrors
(
IsSameShaped
(
prob
,
probPath
),
"Wrong tensor shape!"
);
CheckNTErrors
(
IsSameShaped
(
prob
,
probPath
),
"Wrong tensor shape!"
);
...
@@ -460,6 +471,10 @@ void T2TSearch::Collect(T2TStateBundle * beam)
...
@@ -460,6 +471,10 @@ void T2TSearch::Collect(T2TStateBundle * beam)
CheckNTErrors
(
state
.
pid
>=
0
&&
state
.
pid
<
batchSize
,
CheckNTErrors
(
state
.
pid
>=
0
&&
state
.
pid
<
batchSize
,
"Invalid sample id!"
);
"Invalid sample id!"
);
/* check if this is the first end symbol. It is false
if there have been end symbols in previously generated words. */
bool
isCompleted
=
state
.
isCompleted
&&
(
state
.
last
==
NULL
||
!
state
.
last
->
isCompleted
);
/* we push the hypothesis into the heap when it is completed */
/* we push the hypothesis into the heap when it is completed */
if
(
state
.
isEnd
!=
0
)
if
(
state
.
isEnd
!=
0
)
...
...
source/sample/transformer/T2TTester.cpp
查看文件 @
bfa6fc90
...
@@ -93,17 +93,13 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
...
@@ -93,17 +93,13 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{
{
count
++
;
count
++
;
wordCount
=
0
;
wordCount
=
0
;
/*if (count % 10 == 0 && sentBatch < 128)
sentBatch *= 2;*/
/* reset cache for decoder */
for
(
int
i
=
0
;
i
<
model
->
decoder
->
nlayer
;
++
i
)
{
for
(
int
i
=
0
;
i
<
model
->
decoder
->
nlayer
;
++
i
)
{
model
->
decoder
->
selfCache
[
i
].
Clear
()
;
model
->
decoder
->
selfCache
[
i
].
miss
=
true
;
model
->
decoder
->
contextCache
[
i
].
Clear
()
;
model
->
decoder
->
contextCache
[
i
].
miss
=
true
;
}
}
vector
<
int
>
indices
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
sentBatch
,
devID
);
vector
<
int
>
indices
=
batchLoader
.
LoadBatch
(
&
batchEnc
,
&
paddingEnc
,
sentBatch
,
devID
);
XTensor
output
;
XTensor
output
;
seacher
.
Search
(
model
,
&
batchEnc
,
&
paddingEnc
,
&
output
);
seacher
.
Search
(
model
,
&
batchEnc
,
&
paddingEnc
,
&
output
);
...
@@ -122,7 +118,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
...
@@ -122,7 +118,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
res
.
id
=
indices
[
i
];
res
.
id
=
indices
[
i
];
batchLoader
.
resBuffer
.
emplace_back
(
res
);
batchLoader
.
resBuffer
.
emplace_back
(
res
);
}
}
wc
=
batchEnc
.
GetDim
(
-
1
);
wc
=
batchEnc
.
GetDim
(
-
1
);
wordCount
+=
wc
;
wordCount
+=
wc
;
...
...
source/sample/transformer/Transformer.cpp
查看文件 @
bfa6fc90
...
@@ -54,7 +54,7 @@ int TransformerMain(int argc, const char ** argv)
...
@@ -54,7 +54,7 @@ int TransformerMain(int argc, const char ** argv)
char
*
rawModel
=
new
char
[
MAX_LINE_LENGTH
];
char
*
rawModel
=
new
char
[
MAX_LINE_LENGTH
];
LoadParamString
(
argc
,
args
,
"model"
,
modelFN
,
""
);
LoadParamString
(
argc
,
args
,
"model"
,
modelFN
,
""
);
LoadParamString
(
argc
,
args
,
"raw
M
odel"
,
rawModel
,
""
);
LoadParamString
(
argc
,
args
,
"raw
m
odel"
,
rawModel
,
""
);
LoadParamString
(
argc
,
args
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
args
,
"test"
,
testFN
,
""
);
LoadParamString
(
argc
,
args
,
"output"
,
outputFN
,
""
);
LoadParamString
(
argc
,
args
,
"output"
,
outputFN
,
""
);
LoadParamBool
(
argc
,
args
,
"beamsearch"
,
&
isBeamSearch
,
false
);
LoadParamBool
(
argc
,
args
,
"beamsearch"
,
&
isBeamSearch
,
false
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论