Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
52a50ab2
Commit
52a50ab2
authored
Dec 29, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dropout with no use of broadcasting
parent
b1a9adde
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
121 行增加
和
20 行删除
+121
-20
source/sample/transformer/T2TDecoder.cpp
+4
-4
source/sample/transformer/T2TEncoder.cpp
+3
-3
source/tensor/XDevice.cpp
+5
-0
source/tensor/XDevice.h
+3
-0
source/tensor/core/getandset/SetData.cpp
+34
-1
source/tensor/core/getandset/SetData.cu
+54
-5
source/tensor/core/getandset/SetData.cuh
+5
-1
source/tensor/core/getandset/SetData.h
+5
-1
source/tensor/function/Dropout.cpp
+8
-5
没有找到文件。
source/sample/transformer/T2TDecoder.cpp
查看文件 @
52a50ab2
...
@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
x
=
Dropout
(
x
,
dropoutP
,
0
,
2
);
x
=
Dropout
(
x
,
dropoutP
);
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
XTensor
att
;
XTensor
att
;
...
@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
att
=
Dropout
(
att
,
dropoutP
,
0
,
2
);
att
=
Dropout
(
att
,
dropoutP
);
/* residual connection */
/* residual connection */
res
=
Sum
(
att
,
x
);
res
=
Sum
(
att
,
x
);
...
@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
ende
=
Dropout
(
ende
,
dropoutP
,
0
,
2
);
ende
=
Dropout
(
ende
,
dropoutP
);
/* residual connection */
/* residual connection */
res
=
Sum
(
ende
,
x
);
res
=
Sum
(
ende
,
x
);
...
@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
...
@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
fnn
=
Dropout
(
fnn
,
dropoutP
,
0
,
2
);
fnn
=
Dropout
(
fnn
,
dropoutP
);
/* residual connection */
/* residual connection */
res
=
Sum
(
fnn
,
x
);
res
=
Sum
(
fnn
,
x
);
...
...
source/sample/transformer/T2TEncoder.cpp
查看文件 @
52a50ab2
...
@@ -105,7 +105,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
...
@@ -105,7 +105,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
x
=
Dropout
(
x
,
dropoutP
,
0
,
2
);
x
=
Dropout
(
x
,
dropoutP
);
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
for
(
int
i
=
0
;
i
<
nlayer
;
i
++
){
XTensor
att
;
XTensor
att
;
...
@@ -118,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
...
@@ -118,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
att
=
Dropout
(
att
,
dropoutP
,
0
,
2
);
att
=
Dropout
(
att
,
dropoutP
);
/* residual connection */
/* residual connection */
res
=
Sum
(
att
,
x
);
res
=
Sum
(
att
,
x
);
...
@@ -131,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
...
@@ -131,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
/* dropout */
if
(
isTraining
&&
dropoutP
>
0
)
if
(
isTraining
&&
dropoutP
>
0
)
fnn
=
Dropout
(
fnn
,
dropoutP
,
0
,
2
);
fnn
=
Dropout
(
fnn
,
dropoutP
);
/* residual connection */
/* residual connection */
res
=
Sum
(
fnn
,
x
);
res
=
Sum
(
fnn
,
x
);
...
...
source/tensor/XDevice.cpp
查看文件 @
52a50ab2
...
@@ -60,6 +60,7 @@ XDevice::~XDevice()
...
@@ -60,6 +60,7 @@ XDevice::~XDevice()
cublasDestroy
(
cublasHandle
);
cublasDestroy
(
cublasHandle
);
if
(
stream
!=
NULL
)
if
(
stream
!=
NULL
)
delete
stream
;
delete
stream
;
curandDestroyGenerator
(
gen
);
#endif
#endif
}
}
...
@@ -82,6 +83,10 @@ void XDevice::Init(int myDevID)
...
@@ -82,6 +83,10 @@ void XDevice::Init(int myDevID)
cudaDeviceProp
prop
;
cudaDeviceProp
prop
;
cudaSetDevice
(
myDevID
);
cudaSetDevice
(
myDevID
);
curandCreateGenerator
(
&
gen
,
CURAND_RNG_PSEUDO_DEFAULT
);
curandSetPseudoRandomGeneratorSeed
(
gen
,
seed
);
if
(
cudaGetDeviceProperties
(
&
prop
,
devID
)
!=
cudaSuccess
){
if
(
cudaGetDeviceProperties
(
&
prop
,
devID
)
!=
cudaSuccess
){
XPRINT1
(
0
,
stderr
,
"cannot get GPU(%d) information."
,
devID
);
XPRINT1
(
0
,
stderr
,
"cannot get GPU(%d) information."
,
devID
);
exit
(
1
);
exit
(
1
);
...
...
source/tensor/XDevice.h
查看文件 @
52a50ab2
...
@@ -112,6 +112,9 @@ public:
...
@@ -112,6 +112,9 @@ public:
/* specify if the handle is initialized */
/* specify if the handle is initialized */
bool
isHandleReady
;
bool
isHandleReady
;
/* generater of random numbers */
curandGenerator_t
gen
;
#endif
#endif
...
...
source/tensor/core/getandset/SetData.cpp
查看文件 @
52a50ab2
...
@@ -387,7 +387,7 @@ generate data items with a uniform distribution in [lower, upper]
...
@@ -387,7 +387,7 @@ generate data items with a uniform distribution in [lower, upper]
>> lower - lower value of the range
>> lower - lower value of the range
>> upper - upper value of the range
>> upper - upper value of the range
*/
*/
void
_SetDataRand
(
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
)
void
_SetDataRand
(
const
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
)
{
{
CheckNTErrors
(
upper
>
lower
,
"the high value must be greater than low value!"
);
CheckNTErrors
(
upper
>
lower
,
"the high value must be greater than low value!"
);
...
@@ -432,6 +432,39 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
...
@@ -432,6 +432,39 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
}
}
/*
generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise
>> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range
>> upper - upper value of the range
>> p - the threshold
>> value - the value we intend to assign to the item
*/
void _SetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value)
{
    CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO");

    if (tensor->devID < 0) {
        /* CPU path: draw uniform samples first, then cut each item by the threshold */
        _SetDataRand(tensor, lower, upper);

        DTYPE * data = (DTYPE*)tensor->data;
        for (int i = 0; i < tensor->unitNum; i++)
            data[i] = (data[i] >= p) ? value : 0;
    }
    else {
#ifdef USE_CUDA
        _CudaSetDataRandP(tensor, lower, upper, p, value);
#else
        ShowNTErrors("Please recompile the code by specifying USE_CUDA");
#endif // USE_CUDA
    }
}
/*
generate data items with a normal distribution with specified mean and standard deviation
generate data items with a normal distribution with specified mean and standard deviation
>> tensor - the tensor that keeps the data
>> tensor - the tensor that keeps the data
>> mean - mean or expectation of the distribution
>> mean - mean or expectation of the distribution
...
...
source/tensor/core/getandset/SetData.cu
查看文件 @
52a50ab2
...
@@ -186,6 +186,26 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
...
@@ -186,6 +186,26 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
}
/*
set data items to a pre-defined value if its value >= p, set it to 0 otherwise
(one GPU thread per data item)
>> d - pointer to the data array
>> size - size of the array
>> p - the threshold
>> value - the value we intend to assign to the item
*/
__global__
void KernelSetDataPCut(DTYPE * d, int size, DTYPE p, DTYPE value)
{
    /* global index of the item handled by this thread */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < size) {
        if (d[i] >= p)
            d[i] = value;
        else
            d[i] = 0;
    }
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> tensor - the tensor whose data array would be initialized
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> beg - the beginning position
...
@@ -437,7 +457,7 @@ generate data items with a uniform distribution in [lower, upper]
...
@@ -437,7 +457,7 @@ generate data items with a uniform distribution in [lower, upper]
>> lower - lower value of the range
>> lower - lower value of the range
>> upper - upper value of the range
>> upper - upper value of the range
*/
*/
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
void _CudaSetDataRand(
const
XTensor * tensor, DTYPE lower, DTYPE upper)
{
{
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
...
@@ -452,17 +472,46 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
...
@@ -452,17 +472,46 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
int devIDBackup;
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
ProtectCudaDev(tensor->devID, devIDBackup);
curandGenerator_t gen;
curandGenerator_t & gen = GDevs.GPUs[tensor->devID].gen;
curandCreateGenerator (&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, time(NULL));
curandGenerateUniform(gen , (float*)tensor->data , tensor->unitNum);
curandGenerateUniform(gen , (float*)tensor->data , tensor->unitNum);
curandDestroyGenerator(gen);
DTYPE variance = upper - lower;
DTYPE variance = upper - lower;
if(variance != 1.0F || lower != 0){
if (tensor->dataType == X_FLOAT)
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
}
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise
>> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range
>> upper - upper value of the range
>> p - the threshold
>> value - the value we intend to assign to the item
*/
void _CudaSetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value)
{
_CudaSetDataRand(tensor, lower, upper);
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataPCut << <blocks, threads >> >((float*)tensor->data, tensor->unitNum, p, value);
BacktoCudaDev(tensor->devID, devIDBackup);
BacktoCudaDev(tensor->devID, devIDBackup);
}
}
...
...
source/tensor/core/getandset/SetData.cuh
查看文件 @
52a50ab2
...
@@ -47,7 +47,11 @@ void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
...
@@ -47,7 +47,11 @@ void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
void _CudaSetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _CudaSetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
/* set the data with an array of offsets */
/* set the data with an array of offsets */
void _CudaSetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num);
void _CudaSetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num);
...
...
source/tensor/core/getandset/SetData.h
查看文件 @
52a50ab2
...
@@ -55,7 +55,11 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
...
@@ -55,7 +55,11 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
void
_SetDataLowTri
(
XTensor
*
tensor
,
DTYPE
p
,
int
shift
);
void
_SetDataLowTri
(
XTensor
*
tensor
,
DTYPE
p
,
int
shift
);
/* generate data items with a uniform distribution in [lower, upper] */
/* generate data items with a uniform distribution in [lower, upper] */
void
_SetDataRand
(
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
);
void
_SetDataRand
(
const
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void
_SetDataRandP
(
const
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
,
DTYPE
p
,
DTYPE
value
);
/* generate data items with a normal distribution with specified mean and standard deviation */
/* generate data items with a normal distribution with specified mean and standard deviation */
void
_SetDataRandN
(
XTensor
*
tensor
,
DTYPE
mean
=
0
.
0
F
,
DTYPE
standardDeviation
=
1
.
0
F
);
void
_SetDataRandN
(
XTensor
*
tensor
,
DTYPE
mean
=
0
.
0
F
,
DTYPE
standardDeviation
=
1
.
0
F
);
...
...
source/tensor/function/Dropout.cpp
查看文件 @
52a50ab2
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/getandset/SetData.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor
namespace
nts
{
// namespace nts(NiuTrans.Tensor
...
@@ -147,17 +148,21 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
...
@@ -147,17 +148,21 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
XTensor
mask
;
XTensor
mask
;
DTYPE
*
maskArray
=
NULL
;
DTYPE
*
maskArray
=
NULL
;
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
dropProb
);
if
(
leadingDim
<
0
&&
leadingDim2
<
0
){
if
(
leadingDim
<
0
&&
leadingDim2
<
0
){
ShowNTErrors
(
"TODO"
);
XTensor
mask
;
InitTensor
(
&
mask
,
&
x
);
_SetDataRandP
(
&
mask
,
0
,
1.0
F
,
dropProb
,
scaleFactor
);
return
Multiply
(
x
,
mask
);
}
}
else
if
(
leadingDim2
<
0
){
else
if
(
leadingDim2
<
0
){
int
n
=
leadingDim
;
int
n
=
leadingDim
;
CheckNTErrors
(
n
>=
0
&&
n
<
x
.
order
,
"Wrong leadingDim!"
);
CheckNTErrors
(
n
>=
0
&&
n
<
x
.
order
,
"Wrong leadingDim!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
dropProb
);
/* generate a mask tensor with probability p */
/* generate a mask tensor with probability p */
int
unitNum
=
x
.
dimSize
[
n
];
int
unitNum
=
x
.
dimSize
[
n
];
maskArray
=
new
DTYPE
[
unitNum
];
maskArray
=
new
DTYPE
[
unitNum
];
...
@@ -181,8 +186,6 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
...
@@ -181,8 +186,6 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
CheckNTErrors
(
n
>=
0
&&
n
<
x
.
order
,
"Wrong leadingDim!"
);
CheckNTErrors
(
n
>=
0
&&
n
<
x
.
order
,
"Wrong leadingDim!"
);
CheckNTErrors
(
m
>=
0
&&
m
<
x
.
order
,
"Wrong leadingDim!"
);
CheckNTErrors
(
m
>=
0
&&
m
<
x
.
order
,
"Wrong leadingDim!"
);
DTYPE
scaleFactor
=
(
DTYPE
)
1.0
/
((
DTYPE
)
1.0
-
dropProb
);
/* generate a mask tensor with probability p */
/* generate a mask tensor with probability p */
int
unitNum
=
x
.
dimSize
[
n
]
*
x
.
dimSize
[
m
];
int
unitNum
=
x
.
dimSize
[
n
]
*
x
.
dimSize
[
m
];
maskArray
=
new
DTYPE
[
unitNum
];
maskArray
=
new
DTYPE
[
unitNum
];
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论