Commit 52a50ab2 by xiaotong

dropout withno use of broadcasting

parent b1a9adde
......@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 0, 2);
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
XTensor att;
......@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 0, 2);
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
......@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP, 0, 2);
ende = Dropout(ende, dropoutP);
/* residual connection */
res = Sum(ende, x);
......@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 0, 2);
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
......
......@@ -105,7 +105,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 0, 2);
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
XTensor att;
......@@ -118,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 0, 2);
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
......@@ -131,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 0, 2);
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
......
......@@ -60,6 +60,7 @@ XDevice::~XDevice()
cublasDestroy(cublasHandle);
if(stream != NULL)
delete stream;
curandDestroyGenerator(gen);
#endif
}
......@@ -82,6 +83,10 @@ void XDevice::Init(int myDevID)
cudaDeviceProp prop;
cudaSetDevice(myDevID);
curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, seed);
if(cudaGetDeviceProperties(&prop, devID) != cudaSuccess){
XPRINT1(0, stderr, "cannot get GPU(%d) information.", devID);
exit(1);
......
......@@ -112,6 +112,9 @@ public:
/* specify if the handle is initialized */
bool isHandleReady;
/* generater of random numbers */
curandGenerator_t gen;
#endif
......
......@@ -387,7 +387,7 @@ generate data items with a uniform distribution in [lower, upper]
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
void _SetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper)
{
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
......@@ -430,6 +430,39 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
//delete t2;
}
}
/*
generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise
>> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range
>> upper - upper value of the range
>> p - the threshold
>> value - the value we intend to assign to the item
*/
void _SetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value)
{
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO");
if (tensor->devID < 0) {
_SetDataRand(tensor, lower, upper);
DTYPE * data = (DTYPE*)tensor->data;
for (int i = 0; i < tensor->unitNum; i++) {
if (data[i] >= p)
data[i] = value;
else
data[i] = 0;
}
}
else {
#ifdef USE_CUDA
_CudaSetDataRandP(tensor, lower, upper, p, value);
#else
ShowNTErrors("Please recompile the code by specifying USE_CUDA");
#endif // USE_CUDA
}
}
/*
generate data items with a normal distribution with specified mean and standard deviation
......
......@@ -185,6 +185,26 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
}
/*
set data items to a pre-defined value if its value >= p, set it to 0 otherwise
>> d - pointer to the data array
>> size - size of the array
>> lower - low value of the range
>> variance - the variance of the range
*/
__global__
void KernelSetDataPCut(DTYPE * d, int size, DTYPE p, DTYPE value)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
if (d[i] >= p)
d[i] = value;
else
d[i] = 0;
}
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> tensor - the tensor whose data array would be initialized
......@@ -437,7 +457,7 @@ generate data items with a uniform distribution in [lower, upper]
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
void _CudaSetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper)
{
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
......@@ -452,17 +472,46 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
curandGenerator_t gen;
curandCreateGenerator (&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, time(NULL));
curandGenerator_t & gen = GDevs.GPUs[tensor->devID].gen;
curandGenerateUniform(gen , (float*)tensor->data , tensor->unitNum);
curandDestroyGenerator(gen);
DTYPE variance = upper - lower;
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
if(variance != 1.0F || lower != 0){
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
}
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise
>> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range
>> upper - upper value of the range
>> p - the threshold
>> value - the value we intend to assign to the item
*/
void _CudaSetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value)
{
_CudaSetDataRand(tensor, lower, upper);
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataPCut << <blocks, threads >> >((float*)tensor->data, tensor->unitNum, p, value);
BacktoCudaDev(tensor->devID, devIDBackup);
}
......
......@@ -47,7 +47,11 @@ void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
void _CudaSetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _CudaSetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
/* set the data with an array of offsets */
void _CudaSetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num);
......
......@@ -55,7 +55,11 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
void _SetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _SetDataRandP(const XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean = 0.0F, DTYPE standardDeviation = 1.0F);
......
......@@ -26,6 +26,7 @@
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/getandset/SetData.h"
namespace nts{ // namespace nts(NiuTrans.Tensor
......@@ -147,17 +148,21 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
XTensor mask;
DTYPE * maskArray = NULL;
DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
if(leadingDim < 0 && leadingDim2 < 0){
ShowNTErrors("TODO");
XTensor mask;
InitTensor(&mask, &x);
_SetDataRandP(&mask, 0, 1.0F, dropProb, scaleFactor);
return Multiply(x, mask);
}
else if(leadingDim2 < 0){
int n = leadingDim;
CheckNTErrors(n >= 0 && n < x.order, "Wrong leadingDim!");
DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
/* generate a mask tensor with probability p */
int unitNum = x.dimSize[n];
maskArray = new DTYPE[unitNum];
......@@ -180,8 +185,6 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
CheckNTErrors(n >= 0 && n < x.order, "Wrong leadingDim!");
CheckNTErrors(m >= 0 && m < x.order, "Wrong leadingDim!");
DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
/* generate a mask tensor with probability p */
int unitNum = x.dimSize[n] * x.dimSize[m];
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论