Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
1b50554a
Commit
1b50554a
authored
Sep 17, 2018
by
xuchen
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'xuchen' into xiaotong-working
parents
cf43c58c
102db468
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
82 行增加
和
94 行删除
+82
-94
source/tensor/core/arithmetic/Multiply.cpp
+0
-0
source/tensor/core/arithmetic/Multiply.cu
+2
-4
source/tensor/function/Dropout.cpp
+55
-66
source/tensor/function/Dropout.h
+5
-5
source/tensor/test/TDropout.cpp
+20
-19
没有找到文件。
source/tensor/core/arithmetic/Multiply.cpp
查看文件 @
1b50554a
source/tensor/core/arithmetic/Multiply.cu
查看文件 @
1b50554a
...
@@ -171,14 +171,12 @@ void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alph
...
@@ -171,14 +171,12 @@ void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alph
if (alpha == 0) {
if (alpha == 0) {
KernelMulElementWiseTensorDynamic<0> << <blocks, threads >> >
KernelMulElementWiseTensorDynamic<0> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
blockNum);
}
}
else {
else {
KernelMulElementWiseTensorDynamic<1> << <blocks, threads >> >
KernelMulElementWiseTensorDynamic<1> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
blockNum);
}
}
}
}
}
}
...
...
source/tensor/function/Dropout.cpp
查看文件 @
1b50554a
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include "Dropout.h"
#include "Dropout.h"
#include "Dropout.cuh"
#include "Dropout.cuh"
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/
Sum
Dim.h"
#include "../core/arithmetic/
Multiply
Dim.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/math/ScaleAndShift.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor
namespace
nts
{
// namespace nts(NiuTrans.Tensor
...
@@ -44,40 +44,35 @@ the same inference procedure as that with no use of dropout on the test data.
...
@@ -44,40 +44,35 @@ the same inference procedure as that with no use of dropout on the test data.
>> x - input tensor
>> x - input tensor
>> y - output tensor
>> y - output tensor
>> prob - probability to set an element to zero
>> seed - random seed
>> dropProb - probability to set an element to zero
>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
*/
*/
void
_Dropout
(
const
XTensor
*
x
,
XTensor
*
y
,
unsigned
int
seed
,
DTYPE
prob
)
void
_Dropout
(
const
XTensor
*
x
,
XTensor
*
y
,
unsigned
int
seed
,
DTYPE
dropProb
,
int
leadingDim
)
{
{
CheckNTErrors
(
prob
>=
0.0
&&
p
rob
<=
1.0
,
"The probability must be 0-1!"
);
CheckNTErrors
(
dropProb
>=
0.0
&&
dropP
rob
<=
1.0
,
"The probability must be 0-1!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
prob
);
int
n
=
leadingDim
<
0
?
x
->
order
-
1
:
leadingDim
;
CheckNTErrors
(
n
>=
0
&&
n
<
x
->
order
,
"Wrong leadingDim!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
dropProb
);
/* generate a mask tensor again with special probability */
/* generate a mask tensor again with special probability */
srand
(
seed
);
int
unitNum
=
x
->
dimSize
[
n
];
int
unitNum
=
x
->
unitNum
;
DTYPE
*
maskArray
=
new
DTYPE
[
unitNum
];
DTYPE
*
maskArray
=
new
DTYPE
[
unitNum
];
srand
(
seed
);
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
maskArray
[
i
]
=
RandomBernoulli
(
prob
,
1.0
F
);
maskArray
[
i
]
=
RandomBernoulli
(
dropProb
,
scaleFactor
);
XTensor
*
mask
Tensor
=
NewTensorBuf
(
x
,
x
->
devID
,
x
->
mem
);
XTensor
*
mask
=
NewTensor1D
(
unitNum
,
x
->
dataType
,
x
->
devID
,
x
->
mem
);
mask
Tensor
->
SetData
(
maskArray
,
unitNum
);
mask
->
SetData
(
maskArray
,
unitNum
);
#ifdef USE_CUDA
/* call Multiply function for mask */
if
(
x
->
devID
>=
0
||
y
->
devID
>=
0
){
_MultiplyDim
(
x
,
mask
,
y
,
n
,
0
);
_CudaDropout
(
x
,
y
,
maskTensor
,
scaleFactor
);
DelTensorBuf
(
maskTensor
);
delete
mask
;
delete
[]
maskArray
;
return
;
}
#endif
XTensor
*
inter
=
NewTensorBuf
(
x
,
x
->
devID
,
x
->
mem
);
_Multiply
(
x
,
maskTensor
,
inter
);
_ScaleAndShift
(
inter
,
y
,
scaleFactor
,
0
);
DelTensorBuf
(
inter
);
DelTensorBuf
(
maskTensor
);
delete
[]
maskArray
;
delete
[]
maskArray
;
}
}
...
@@ -90,44 +85,39 @@ dE/dx = dE/dy * dy/dx
...
@@ -90,44 +85,39 @@ dE/dx = dE/dy * dy/dx
>> x - input of the dropout function
>> x - input of the dropout function
>> dedy - dE/dy
>> dedy - dE/dy
>> dedx - dE/dx
>> dedx - dE/dx
>> prob - probability to set an element zero
>> seed - random seed
>> dropProb - probability to set an element to zero
>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
*/
*/
void
_DropoutBackward
(
const
XTensor
*
y
,
const
XTensor
*
x
,
void
_DropoutBackward
(
const
XTensor
*
y
,
const
XTensor
*
x
,
const
XTensor
*
dedy
,
XTensor
*
dedx
,
const
XTensor
*
dedy
,
XTensor
*
dedx
,
unsigned
int
seed
,
DTYPE
prob
)
unsigned
int
seed
,
DTYPE
dropProb
,
int
leadingDim
)
{
{
CheckNTErrors
(
dropProb
>=
0.0
&&
dropProb
<=
1.0
,
"The probability must be 0-1!"
);
int
n
=
leadingDim
<
0
?
x
->
order
-
1
:
leadingDim
;
CheckNTErrors
(
n
>=
0
&&
n
<
x
->
order
,
"Wrong leadingDim!"
);
if
(
x
->
dataType
==
DEFAULT_DTYPE
&&
y
->
dataType
==
DEFAULT_DTYPE
)
if
(
x
->
dataType
==
DEFAULT_DTYPE
&&
y
->
dataType
==
DEFAULT_DTYPE
)
{
{
int
unitNum
=
y
->
unitNum
;
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
F
/
((
DTYPE
)
1.0
F
-
dropProb
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
F
/
((
DTYPE
)
1.0
F
-
prob
);
/* generate a mask tensor again with special probability */
/* generate a mask tensor again with special probability */
srand
(
seed
)
;
int
unitNum
=
x
->
dimSize
[
n
]
;
DTYPE
*
maskArray
=
new
DTYPE
[
unitNum
];
DTYPE
*
maskArray
=
new
DTYPE
[
unitNum
];
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
maskArray
[
i
]
=
RandomBernoulli
(
prob
,
1.0
F
);
XTensor
*
maskTensor
=
NewTensorBuf
(
x
,
x
->
devID
,
x
->
mem
);
maskTensor
->
SetData
(
maskArray
,
unitNum
);
#ifdef USE_CUDA
srand
(
seed
);
if
(
x
->
devID
>=
0
||
y
->
devID
>=
0
){
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
_CudaDropoutBackward
(
y
,
x
,
dedy
,
dedx
,
maskTensor
,
scaleFactor
);
maskArray
[
i
]
=
RandomBernoulli
(
dropProb
,
scaleFactor
);
DelTensorBuf
(
maskTensor
);
delete
[]
maskArray
;
return
;
}
#endif
DTYPE
*
dedyp
=
(
DTYPE
*
)
dedy
->
data
;
XTensor
*
mask
=
NewTensor1D
(
unitNum
,
x
->
dataType
,
x
->
devID
,
x
->
mem
)
;
DTYPE
*
dedxp
=
(
DTYPE
*
)
dedx
->
data
;
mask
->
SetData
(
maskArray
,
unitNum
)
;
/* dE/dx = dE/dy * dy/dx */
/* call MultiplyDim function for mask */
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
_MultiplyDim
(
dedy
,
mask
,
dedx
,
n
,
0
);
dedxp
[
i
]
=
dedyp
[
i
]
*
maskArray
[
i
]
*
scaleFactor
;
DelTensorBuf
(
maskTensor
)
;
delete
mask
;
delete
[]
maskArray
;
delete
[]
maskArray
;
}
}
else
else
...
@@ -147,14 +137,18 @@ to mark the tensor with probability p in the inference phase. Instead we perform
...
@@ -147,14 +137,18 @@ to mark the tensor with probability p in the inference phase. Instead we perform
the same inference procedure as that with no use of dropout on the test data.
the same inference procedure as that with no use of dropout on the test data.
>> x - input tensor
>> x - input tensor
>> y - output tensor
>> dropProb - probability to set an element to zero
>> prob - probability to set an element to zero
>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
>> leadDim - the dimension along which we generate the random numbers
*/
*/
XTensor
Dropout
(
const
XTensor
&
x
,
DTYPE
prob
,
int
lead
Dim
)
XTensor
Dropout
(
const
XTensor
&
x
,
DTYPE
dropProb
,
int
leading
Dim
)
{
{
int
n
=
leadDim
<
0
?
x
.
order
-
1
:
leadDim
;
CheckNTErrors
(
dropProb
>=
0.0
&&
dropProb
<=
1.0
,
"The probability must be 0-1!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
prob
);
int
n
=
leadingDim
<
0
?
x
.
order
-
1
:
leadingDim
;
CheckNTErrors
(
n
>=
0
&&
n
<
x
.
order
,
"Wrong leadingDim!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
dropProb
);
/* generate a mask tensor with probability p */
/* generate a mask tensor with probability p */
int
unitNum
=
x
.
dimSize
[
n
];
int
unitNum
=
x
.
dimSize
[
n
];
...
@@ -162,20 +156,15 @@ XTensor Dropout(const XTensor &x, DTYPE prob, int leadDim)
...
@@ -162,20 +156,15 @@ XTensor Dropout(const XTensor &x, DTYPE prob, int leadDim)
srand
((
unsigned
int
)
time
(
NULL
));
srand
((
unsigned
int
)
time
(
NULL
));
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
for
(
int
i
=
0
;
i
<
unitNum
;
i
++
)
maskArray
[
i
]
=
RandomBernoulli
(
prob
,
scaleFactor
);
maskArray
[
i
]
=
RandomBernoulli
(
dropProb
,
scaleFactor
);
XTensor
mask
(
&
x
);
mask
.
SetZeroAll
();
XTensor
*
maskVector
=
NewTensorBuf
(
1
,
&
unitNum
,
X_FLOAT
,
1.0
F
,
x
.
devID
,
x
.
mem
);
maskVector
->
SetData
(
maskArray
,
unitNum
);
_SumDim
(
&
mask
,
maskVector
,
&
mask
,
n
);
XTensor
mask
;
InitTensor1D
(
&
mask
,
unitNum
,
x
.
dataType
,
x
.
devID
,
x
.
mem
);
mask
.
SetData
(
maskArray
,
unitNum
);
delete
[]
maskArray
;
delete
[]
maskArray
;
DelTensorBuf
(
maskVector
);
return
Multiply
(
x
,
mask
);
return
Multiply
Dim
(
x
,
mask
,
n
,
0
);
}
}
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
source/tensor/function/Dropout.h
查看文件 @
1b50554a
...
@@ -28,21 +28,21 @@
...
@@ -28,21 +28,21 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* generate a random bernoulli number */
/* generate a random bernoulli number */
inline
DTYPE
RandomBernoulli
(
DTYPE
p
rob
,
DTYPE
value
)
inline
DTYPE
RandomBernoulli
(
DTYPE
dropP
rob
,
DTYPE
value
)
{
{
return
(
DTYPE
)
rand
()
/
(
DTYPE
)
RAND_MAX
>=
p
rob
?
(
DTYPE
)
value
:
0
;
return
(
DTYPE
)
rand
()
/
(
DTYPE
)
RAND_MAX
>=
dropP
rob
?
(
DTYPE
)
value
:
0
;
}
}
/* dropout function */
/* dropout function */
void
_Dropout
(
const
XTensor
*
x
,
XTensor
*
y
,
unsigned
int
seed
,
DTYPE
prob
);
void
_Dropout
(
const
XTensor
*
x
,
XTensor
*
y
,
unsigned
int
seed
,
DTYPE
dropProb
,
int
leadingDim
=
-
1
);
/* de/dx */
/* de/dx */
void
_DropoutBackward
(
const
XTensor
*
y
,
const
XTensor
*
x
,
void
_DropoutBackward
(
const
XTensor
*
y
,
const
XTensor
*
x
,
const
XTensor
*
dedy
,
XTensor
*
dedx
,
const
XTensor
*
dedy
,
XTensor
*
dedx
,
unsigned
int
seed
,
DTYPE
prob
);
unsigned
int
seed
,
DTYPE
dropProb
,
int
leadingDim
=
-
1
);
/* dropout function */
/* dropout function */
XTensor
Dropout
(
const
XTensor
&
x
,
DTYPE
prob
,
int
lead
Dim
=
-
1
);
XTensor
Dropout
(
const
XTensor
&
x
,
DTYPE
dropProb
,
int
leading
Dim
=
-
1
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/test/TDropout.cpp
查看文件 @
1b50554a
...
@@ -31,10 +31,11 @@ case 1: test Dropout function.
...
@@ -31,10 +31,11 @@ case 1: test Dropout function.
bool
TestDropout1
()
bool
TestDropout1
()
{
{
/* a input tensor of size (4, 5) */
/* a input tensor of size (4, 5) */
int
order
=
2
;
int
order
=
3
;
int
*
dimSize
=
new
int
[
order
];
int
*
dimSize
=
new
int
[
order
];
dimSize
[
0
]
=
40
;
dimSize
[
0
]
=
40
;
dimSize
[
1
]
=
50
;
dimSize
[
1
]
=
50
;
dimSize
[
2
]
=
60
;
int
unitNum
=
1
;
int
unitNum
=
1
;
for
(
int
i
=
0
;
i
<
order
;
i
++
)
for
(
int
i
=
0
;
i
<
order
;
i
++
)
...
@@ -49,14 +50,14 @@ bool TestDropout1()
...
@@ -49,14 +50,14 @@ bool TestDropout1()
XTensor
yUser
;
XTensor
yUser
;
/* initialize variables */
/* initialize variables */
x
->
SetDataRand
(
0
,
1
);
_SetDataFixedFloat
(
x
,
1.0
F
);
y
->
SetZeroAll
();
y
->
SetZeroAll
();
/* call Dropout function */
/* call Dropout function */
float
p
rob
=
0.2
F
;
float
dropP
rob
=
0.2
F
;
int
seed
=
20
;
int
seed
=
20
;
_Dropout
(
x
,
y
,
seed
,
p
rob
);
_Dropout
(
x
,
y
,
seed
,
dropP
rob
);
yUser
=
Dropout
(
*
x
,
0.5
F
);
yUser
=
Dropout
(
*
x
,
dropProb
);
/* check result */
/* check result */
int
zeroNum1
=
0
;
int
zeroNum1
=
0
;
...
@@ -73,9 +74,9 @@ bool TestDropout1()
...
@@ -73,9 +74,9 @@ bool TestDropout1()
}
}
printf
(
"CPU Test:
\n
"
);
printf
(
"CPU Test:
\n
"
);
printf
(
"In tensor y, there are %d units.
\n
"
,
unitNum
);
printf
(
"In tensor y, there are %d units.
\n
"
,
unitNum
);
printf
(
"There are %d zero units by Dropout layer with probability %.2f.
\n
"
,
zeroNum1
,
p
rob
);
printf
(
"There are %d zero units by Dropout layer with probability %.2f.
\n
"
,
zeroNum1
,
dropP
rob
);
printf
(
"In tensor yUser, there are %d units.
\n
"
,
unitNum
);
printf
(
"In tensor yUser, there are %d units.
\n
"
,
unitNum
);
printf
(
"There are %d zero units by Dropout layer with default probability %.2f.
\n
"
,
zeroNum2
,
0.5
F
);
printf
(
"There are %d zero units by Dropout layer with default probability %.2f.
\n
"
,
zeroNum2
,
dropProb
);
#ifdef USE_CUDA
#ifdef USE_CUDA
/* GPU test */
/* GPU test */
...
@@ -87,12 +88,12 @@ bool TestDropout1()
...
@@ -87,12 +88,12 @@ bool TestDropout1()
XTensor
yUserGPU
;
XTensor
yUserGPU
;
/* initialize variables */
/* initialize variables */
xGPU
->
SetDataRand
(
0
,
1
);
_SetDataFixedFloat
(
xGPU
,
1.0
F
);
yGPU
->
SetZeroAll
();
yGPU
->
SetZeroAll
();
/* call Dropout function */
/* call Dropout function */
_Dropout
(
xGPU
,
yGPU
,
seed
,
p
rob
);
_Dropout
(
xGPU
,
yGPU
,
seed
,
dropP
rob
);
yUserGPU
=
Dropout
(
*
xGPU
,
0.5
F
);
yUserGPU
=
Dropout
(
*
xGPU
,
dropProb
);
/* check result */
/* check result */
zeroNum1
=
0
;
zeroNum1
=
0
;
...
@@ -109,9 +110,9 @@ bool TestDropout1()
...
@@ -109,9 +110,9 @@ bool TestDropout1()
}
}
printf
(
"CPU Test:
\n
"
);
printf
(
"CPU Test:
\n
"
);
printf
(
"In tensor y, there are %d units.
\n
"
,
unitNum
);
printf
(
"In tensor y, there are %d units.
\n
"
,
unitNum
);
printf
(
"There are %d zero units by Dropout layer with probability %.2f.
\n
"
,
zeroNum1
,
p
rob
);
printf
(
"There are %d zero units by Dropout layer with probability %.2f.
\n
"
,
zeroNum1
,
dropP
rob
);
printf
(
"In tensor yUser, there are %d units.
\n
"
,
unitNum
);
printf
(
"In tensor yUser, there are %d units.
\n
"
,
unitNum
);
printf
(
"There are %d zero units by Dropout layer with default probability %.2f.
\n
"
,
zeroNum2
,
0.5
F
);
printf
(
"There are %d zero units by Dropout layer with default probability %.2f.
\n
"
,
zeroNum2
,
dropProb
);
/* destroy variables */
/* destroy variables */
delete
x
;
delete
x
;
...
@@ -159,13 +160,13 @@ bool TestDropout2()
...
@@ -159,13 +160,13 @@ bool TestDropout2()
_SetDataFixedFloat
(
x
,
1.0
F
);
_SetDataFixedFloat
(
x
,
1.0
F
);
y
->
SetZeroAll
();
y
->
SetZeroAll
();
dedx
->
SetZeroAll
();
dedx
->
SetZeroAll
();
_SetDataFixedFloat
(
dedy
,
1.
0
F
);
_SetDataFixedFloat
(
dedy
,
1.
5
F
);
/* call Dropout function */
/* call Dropout function */
float
p
rob
=
0.5
F
;
float
dropP
rob
=
0.5
F
;
int
seed
=
1
;
int
seed
=
1
;
_Dropout
(
x
,
y
,
seed
,
p
rob
);
_Dropout
(
x
,
y
,
seed
,
dropP
rob
);
_DropoutBackward
(
y
,
x
,
dedy
,
dedx
,
1
,
p
rob
);
_DropoutBackward
(
y
,
x
,
dedy
,
dedx
,
1
,
dropP
rob
);
/* check result */
/* check result */
y
->
Dump
(
stderr
,
"y"
);
y
->
Dump
(
stderr
,
"y"
);
...
@@ -185,11 +186,11 @@ bool TestDropout2()
...
@@ -185,11 +186,11 @@ bool TestDropout2()
_SetDataFixedFloat
(
xGPU
,
1.0
F
);
_SetDataFixedFloat
(
xGPU
,
1.0
F
);
yGPU
->
SetZeroAll
();
yGPU
->
SetZeroAll
();
dedxGPU
->
SetZeroAll
();
dedxGPU
->
SetZeroAll
();
_SetDataFixedFloat
(
dedyGPU
,
1.
0
F
);
_SetDataFixedFloat
(
dedyGPU
,
1.
5
F
);
/* call Dropout function */
/* call Dropout function */
_Dropout
(
xGPU
,
yGPU
,
seed
,
p
rob
);
_Dropout
(
xGPU
,
yGPU
,
seed
,
dropP
rob
);
_DropoutBackward
(
yGPU
,
xGPU
,
dedyGPU
,
dedxGPU
,
1
,
p
rob
);
_DropoutBackward
(
yGPU
,
xGPU
,
dedyGPU
,
dedxGPU
,
1
,
dropP
rob
);
/* check result */
/* check result */
yGPU
->
Dump
(
stderr
,
"yGPU"
);
yGPU
->
Dump
(
stderr
,
"yGPU"
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论