Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
30c3a629
Commit
30c3a629
authored
Aug 19, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better code of t2t mask for lm and add DataSetDim
parent
a7fe4564
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
213 行增加
和
39 行删除
+213
-39
source/sample/transformer/T2TAttention.cpp
+11
-1
source/sample/transformer/T2TAttention.h
+7
-1
source/sample/transformer/T2TEmbedding.cpp
+1
-1
source/sample/transformer/T2TEncoder.cpp
+26
-11
source/sample/transformer/T2TEncoder.h
+10
-7
source/sample/transformer/T2TLayerNormal.cpp
+0
-1
source/sample/transformer/T2TModel.cpp
+16
-10
source/sample/transformer/T2TModel.h
+4
-1
source/tensor/core/getandset/SetData.cpp
+53
-3
source/tensor/core/getandset/SetData.cu
+79
-3
source/tensor/core/getandset/SetData.cuh
+3
-0
source/tensor/core/getandset/SetData.h
+3
-0
没有找到文件。
source/sample/transformer/T2TAttention.cpp
查看文件 @
30c3a629
...
...
@@ -36,6 +36,7 @@ T2TAttention::T2TAttention()
dv
=
-
1
;
d
=
-
1
;
isMasked
=
false
;
ignored
=
0
;
}
/* deconstructor */
...
...
@@ -47,14 +48,19 @@ T2TAttention::~T2TAttention()
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id
>> myMem - the memory pool
*/
void
T2TAttention
::
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myDevID
,
XMem
*
myMem
)
void
T2TAttention
::
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
isMasked
=
myIsMasked
;
ignored
=
myIgnored
;
float
minmax
=
0
;
...
...
@@ -116,6 +122,10 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
if
(
isMasked
)
dot
=
dot
+
mask
;
scalar
=
Softmax
(
Linear
(
dot
,
1
/
(
float
)
sqrt
((
float
)
dk
)),
-
1
);
if
(
ignored
>
0
)
_SetDataDim
(
&
scalar
,
0
,
ignored
,
scalar
.
order
-
2
,
1e-9
F
);
att
=
BMMul
(
scalar
,
vheads
);
/* concatenate the heads */
...
...
source/sample/transformer/T2TAttention.h
查看文件 @
30c3a629
...
...
@@ -69,6 +69,10 @@ public:
/* indicates whether the attention is masked */
bool
isMasked
;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int
ignored
;
public
:
/* constructor */
T2TAttention
();
...
...
@@ -77,7 +81,9 @@ public:
~
T2TAttention
();
/* initialize the model */
void
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the network */
XTensor
Make
(
XTensor
&
k
,
XTensor
&
q
,
XTensor
&
v
,
XTensor
&
mask
);
...
...
source/sample/transformer/T2TEmbedding.cpp
查看文件 @
30c3a629
...
...
@@ -136,7 +136,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
wordEmbedding
=
Linear
(
MMul
(
input
,
w
),
(
float
)
sqrt
((
float
)
d
));
/* we sum over the two embeddings */
return
wordEmbedding
+
posEmbedding
;
return
wordEmbedding
+
posEmbedding
;
}
}
source/sample/transformer/T2TEncoder.cpp
查看文件 @
30c3a629
...
...
@@ -47,14 +47,17 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
void
AttEncoder
::
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myDevID
,
XMem
*
myMem
)
void
AttEncoder
::
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
,
XMem
*
myMem
)
{
devID
=
myDevID
;
mem
=
myMem
;
i
sMasked
=
myIsMask
ed
;
i
gnored
=
myIgnor
ed
;
LoadParamInt
(
argc
,
argv
,
"nlayer"
,
&
nlayer
,
6
);
LoadParamInt
(
argc
,
argv
,
"hsize"
,
&
hSize
,
DEFAULT_EMBEDDING_SIZE
);
...
...
@@ -74,7 +77,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int my
/* initialize the stacked layers */
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
attentions
[
i
].
InitModel
(
argc
,
argv
,
isMask
ed
,
myDevID
,
myMem
);
attentions
[
i
].
InitModel
(
argc
,
argv
,
myIsMasked
,
myIgnor
ed
,
myDevID
,
myMem
);
fnns
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
attLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
fnnLayerNorms
[
i
].
InitModel
(
argc
,
argv
,
myDevID
,
myMem
);
...
...
@@ -85,9 +88,10 @@ void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int my
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - the output tensor of the encoder
*/
XTensor
AttEncoder
::
Make
(
XTensor
&
input
,
XTensor
&
mask
)
XTensor
AttEncoder
::
Make
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
)
{
XTensor
x
;
...
...
@@ -99,16 +103,27 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask)
XTensor
fnn
;
XTensor
res
;
/* self attention */
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
);
if
(
skipInputRes
&&
i
==
0
){
/* self attention */
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
);
/* residual connection */
res
=
Sum
(
att
,
x
);
/* TODO: dropout */
/* TODO: dropout */
/* layer normalization */
x
=
attLayerNorms
[
i
].
Make
(
att
);
}
else
{
/* self attention */
att
=
attentions
[
i
].
Make
(
x
,
x
,
x
,
mask
);
/* layer normalization */
x
=
attLayerNorms
[
i
].
Make
(
res
);
/* residual connection */
res
=
Sum
(
att
,
x
);
/* TODO: dropout */
/* layer normalization */
x
=
attLayerNorms
[
i
].
Make
(
res
);
}
/* fnn */
fnn
=
fnns
[
i
].
Make
(
x
);
...
...
source/sample/transformer/T2TEncoder.h
查看文件 @
30c3a629
...
...
@@ -40,7 +40,7 @@ class T2TEncoder
{
public
:
virtual
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
)
=
0
;
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
)
=
0
;
};
/*
...
...
@@ -49,7 +49,7 @@ the encoder based on RNN
class
RNNEncoder
:
T2TEncoder
{
public
:
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
);
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
);
};
...
...
@@ -76,9 +76,10 @@ public:
/* vocabulary size */
int
vSize
;
/* indicates whether masked attention is employed */
int
isMasked
;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int
ignored
;
/* embedding of word at each position */
T2TEmbedder
embedder
;
...
...
@@ -109,10 +110,12 @@ public:
~
AttEncoder
();
/* initialize the model */
void
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
void
InitModel
(
int
argc
,
const
char
**
argv
,
bool
myIsMasked
,
int
myIgnored
,
int
myDevID
=
-
1
,
XMem
*
myMem
=
NULL
);
/* make the encoding network */
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
);
XTensor
Make
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
);
};
...
...
source/sample/transformer/T2TLayerNormal.cpp
查看文件 @
30c3a629
...
...
@@ -90,7 +90,6 @@ XTensor T2TLN::Make(XTensor &input)
/* standard = sqrt(variance) */
standard
=
Power
(
variance
,
0.5
F
);
/* unsqueeze mean and standard deviation to fit them into
the same shape of x */
meanFilled
=
Unsqueeze
(
mean
,
x
.
order
-
1
,
x
.
GetDim
(
-
1
));
...
...
source/sample/transformer/T2TModel.cpp
查看文件 @
30c3a629
...
...
@@ -34,6 +34,7 @@ T2TModel::T2TModel()
mem
=
NULL
;
isLM
=
false
;
isMT
=
false
;
nhead
=
1
;
}
/* de-constructor */
...
...
@@ -55,13 +56,14 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool
(
argc
,
argv
,
"mem"
,
&
useMem
,
useMem
);
LoadParamBool
(
argc
,
argv
,
"lm"
,
&
isLM
,
true
);
LoadParamBool
(
argc
,
argv
,
"mt"
,
&
isMT
,
false
);
LoadParamInt
(
argc
,
argv
,
"nhead"
,
&
nhead
,
8
);
if
(
useMem
){
delete
mem
;
mem
=
new
XMem
(
devID
);
}
encoder
.
InitModel
(
argc
,
argv
,
isLM
,
devID
,
mem
);
encoder
.
InitModel
(
argc
,
argv
,
isLM
,
isLM
?
1
:
0
,
devID
,
mem
);
outputLayer
.
InitModel
(
argc
,
argv
,
devID
,
mem
);
}
...
...
@@ -69,11 +71,12 @@ void T2TModel::InitModel(int argc, const char ** argv)
make the encoding network
>> input - input tensor
>> mask - the mask for positions that are/not involved in computation
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - encoding result
*/
XTensor
T2TModel
::
MakeEncoding
(
XTensor
&
input
,
XTensor
&
mask
)
XTensor
T2TModel
::
MakeEncoding
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
)
{
return
encoder
.
Make
(
input
,
mask
);
return
encoder
.
Make
(
input
,
mask
,
skipInputRes
);
}
/*
...
...
@@ -88,18 +91,21 @@ void T2TModel::Make(XTensor &input, XTensor &output)
if
(
isLM
){
/* generate mask to see "previous" words only */
int
len
=
input
.
GetDim
(
input
.
order
-
2
);
int
dims
[
MAX_TENSOR_DIM_NUM
];
int
*
dims
=
new
int
[
input
.
order
+
1
];
for
(
int
i
=
0
;
i
<
input
.
order
;
i
++
)
dims
[
i
]
=
input
.
GetDim
(
i
);
dims
[
input
.
order
-
1
]
=
len
;
XTensor
mask
(
input
.
order
,
dims
,
X_FLOAT
,
1.0
F
,
input
.
devID
,
input
.
mem
);
dims
[
i
+
1
]
=
input
.
GetDim
(
i
);
dims
[
0
]
=
nhead
;
dims
[
input
.
order
]
=
len
;
XTensor
mask
(
input
.
order
+
1
,
dims
,
X_FLOAT
,
1.0
F
,
input
.
devID
,
input
.
mem
);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9 */
_SetDataLowTri
(
&
mask
,
1e
-9
,
-
1
);
_ScaleAndShiftMe
(
&
mask
,
1.0
F
,
-
1e
-9
);
_SetDataLowTri
(
&
mask
,
1e
9
F
,
-
1
);
_ScaleAndShiftMe
(
&
mask
,
1.0
F
,
-
1e
9
F
);
encoding
=
MakeEncoding
(
input
,
mask
);
encoding
=
MakeEncoding
(
input
,
mask
,
true
);
outputLayer
.
Make
(
encoding
,
output
);
delete
[]
dims
;
}
else
{
ShowNTErrors
(
"TODO!"
);
...
...
source/sample/transformer/T2TModel.h
查看文件 @
30c3a629
...
...
@@ -55,6 +55,9 @@ public:
/* indicates whether the model is running for machine translation */
bool
isMT
;
/* number of heads in the attention model */
int
nhead
;
public
:
/* constructor */
T2TModel
();
...
...
@@ -66,7 +69,7 @@ public:
void
InitModel
(
int
argc
,
const
char
**
argv
);
/* make the encoding network */
XTensor
MakeEncoding
(
XTensor
&
input
,
XTensor
&
mask
);
XTensor
MakeEncoding
(
XTensor
&
input
,
XTensor
&
mask
,
bool
skipInputRes
);
/* make the entire network (with the output softmax layer) */
void
Make
(
XTensor
&
input
,
XTensor
&
output
);
...
...
source/tensor/core/getandset/SetData.cpp
查看文件 @
30c3a629
...
...
@@ -214,6 +214,56 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
}
/*
set data items along with a given dimension (and keep the remaining items unchanged)
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void
_SetDataDim
(
XTensor
*
tensor
,
int
beg
,
int
len
,
int
dim
,
DTYPE
p
)
{
int
n
=
tensor
->
order
;
CheckNTErrors
(
tensor
->
dataType
==
DEFAULT_DTYPE
,
"TODO!"
);
CheckNTErrors
(
dim
<
n
&&
dim
>
0
,
"Illegal dimension!"
);
CheckNTErrors
(
beg
>=
0
&&
beg
<
tensor
->
GetDim
(
dim
),
"Illegal beginning position!"
);
CheckNTErrors
(
beg
+
len
>=
0
&&
beg
+
len
<
tensor
->
GetDim
(
dim
),
"Illegal length!"
);
if
(
tensor
->
devID
<
0
){
int
stride
=
1
;
int
blockSize
=
1
;
int
blockNum
=
1
;
for
(
int
i
=
n
-
1
;
i
>
dim
;
i
--
){
stride
*=
tensor
->
GetDim
(
i
);
}
blockSize
=
stride
*
tensor
->
GetDim
(
dim
);
blockNum
=
tensor
->
unitNum
/
blockSize
;
int
l
=
len
*
stride
;
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
){
DTYPE
*
d
=
(
DTYPE
*
)
tensor
->
data
+
blockSize
*
i
+
beg
*
stride
;
for
(
int
j
=
0
;
j
<
l
;
j
++
)
d
[
j
]
=
p
;
}
}
else
{
#ifdef USE_CUDA
_CudaSetDataDim
(
tensor
,
beg
,
len
,
dim
,
p
);
#endif
}
}
/*
generate data as lower triangular matrics for last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
...
...
@@ -247,10 +297,10 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
){
DTYPE
*
d
=
(
DTYPE
*
)
tensor
->
data
+
i
*
blockSize
;
for
(
int
row
=
0
;
row
<
l
;
row
++
){
for
(
int
col
=
0
;
col
<
row
+
shift
;
col
++
){
d
[
row
*
l
+
col
]
=
row
;
for
(
int
col
=
0
;
col
<
=
row
+
shift
;
col
++
){
d
[
row
*
l
+
col
]
=
p
;
}
for
(
int
col
=
MAX
(
0
,
row
+
shift
);
col
<
l
;
col
++
){
for
(
int
col
=
MAX
(
0
,
row
+
shift
+
1
);
col
<
l
;
col
++
){
d
[
row
*
l
+
col
]
=
0
;
}
}
...
...
source/tensor/core/getandset/SetData.cu
查看文件 @
30c3a629
...
...
@@ -185,6 +185,82 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length of the segment to be set
>> blockSize - size of a data block
>> blockNum - number of data blocks
*/
__global__
void KernelSetDataDim(DTYPE * d, int beg, int len, int blockSize, int blockNum, DTYPE p)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
if(i < beg || i >= beg + len)
return;
d[blockSize * j + i] = p;
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - cuda version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set lower triangular matrics for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e,
...
...
@@ -204,7 +280,7 @@ e.g., for a 3* 3 tensor,
2 2 0
*/
__global__
void _Kern
a
lSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
void _Kern
e
lSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
...
@@ -219,7 +295,7 @@ void _KernalSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p
int col = i % l;
DTYPE * d2 = d + blockSize * j + row * l + col;
if(col < row + shift)
if(col <
=
row + shift)
*d2 = p;
else
*d2 = 0;
...
...
@@ -266,7 +342,7 @@ void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_Kern
a
lSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
_Kern
e
lSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
BacktoCudaDev(tensor->devID, devIDBackup);
}
...
...
source/tensor/core/getandset/SetData.cuh
查看文件 @
30c3a629
...
...
@@ -37,6 +37,9 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrics for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
...
...
source/tensor/core/getandset/SetData.h
查看文件 @
30c3a629
...
...
@@ -45,6 +45,9 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void
_SetDataFixedDouble
(
XTensor
*
tensor
,
double
p
);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void
_SetDataDim
(
XTensor
*
tensor
,
int
beg
,
int
len
,
int
dim
,
DTYPE
p
);
/* generate data as lower triangular matrics for last two dimensions */
void
_SetDataLowTri
(
XTensor
*
tensor
,
DTYPE
p
,
int
shift
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论