Commit 102db468 by xuchen

better code for dropout function (by broadcasting)

parent 4336f2f9
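The core idea of the change: instead of sampling a full-size Bernoulli mask and multiplying it element-wise with the input, a 1D mask is sampled along one (leading) dimension, pre-scaled by 1/(1-p), and broadcast-multiplied over the tensor. A minimal self-contained sketch of that scheme follows (plain C++, not the NiuTrans.Tensor API):

    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    /* inverted dropout with a broadcast mask: the mask is sampled once per
       column (the "leading" dimension) and reused for every row, so only
       `cols` random numbers are needed instead of rows*cols */
    int main()
    {
        const int rows = 4, cols = 5;
        const float dropProb = 0.2F;
        const float scale = 1.0F / (1.0F - dropProb);   /* kept units scaled by 1/(1-p) */

        std::vector<float> x(rows * cols, 1.0F);        /* toy input, all ones */
        std::vector<float> mask(cols);

        srand(20);
        for (int j = 0; j < cols; j++)                  /* 0 with prob p, else 1/(1-p) */
            mask[j] = (float)rand() / RAND_MAX >= dropProb ? scale : 0.0F;

        for (int i = 0; i < rows; i++)                  /* broadcast the 1D mask over rows */
            for (int j = 0; j < cols; j++)
                x[i * cols + j] *= mask[j];

        for (int i = 0; i < rows; i++, printf("\n"))
            for (int j = 0; j < cols; j++)
                printf("%.2f ", x[i * cols + j]);
        return 0;
    }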
...@@ -125,17 +125,8 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
     dot = Linear(dot, 1.0F/(float)sqrt((float)dk));

-    //if(llnum == 1)
-    //    dot.Dump(tf, "dot:");

     scalar = Softmax(dot, -1);

-    //if(llnum == 1)
-    //    scalar.Dump(tf, "scalar:");
-    //if(ignored > 0)
-    //    _SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);

     att = BMMul(scalar, vheads);

     /* concatenate the heads */
......
...@@ -73,6 +73,9 @@ public:
        special design for the attention model. */
     int ignored;

+    /* indicates whether the model is used for training */
+    bool isTraining;

 public:
     /* constructor */
     T2TAttention();
......
...@@ -63,6 +63,7 @@ void AttEncoder::InitModel(int argc, const char ** argv,
     LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
     LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
     LoadParamInt(argc, argv, "vsize", &vSize, -1);
+    LoadParamFloat(argc, argv, "dropout", &dropoutP, 0);

     CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
     CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
...@@ -89,29 +90,34 @@ make the encoding network
 >> input - the input tensor of the encoder
 >> mask - the mask that indicate each position is valid
 >> skipInputRes - indicates whether we skip the residual connection of the first layer
+>> isTraining - indicates whether the model is for training
 << return - the output tensor of the encoder
 */
-XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
 {
     XTensor x;

     x = embedder.Make(input);

+    /* dropout */
+    if(isTraining && dropoutP > 0)
+        x = Dropout(x, dropoutP);
+
     for(int i = 0; i < nlayer; i++){
         XTensor att;
         XTensor ln;
         XTensor fnn;
         XTensor res;

-        llnum = -1;
-
         /* we skip the residual connection for the first layer if
            the encoder is used in language modeling. */
         if(skipInputRes && i == 0){
             /* self attention */
             att = attentions[i].Make(x, x, x, mask);

-            /* TODO: dropout */
+            /* dropout */
+            if(isTraining && dropoutP > 0)
+                att = Dropout(att, dropoutP);

             /* layer normalization */
             x = attLayerNorms[i].Make(att);
...@@ -121,27 +127,32 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
             /* self attention */
             att = attentions[i].Make(x, x, x, mask);

+            /* dropout */
+            if(isTraining && dropoutP > 0)
+                att = Dropout(att, dropoutP);
+
             /* residual connection */
             res = Sum(att, x);

-            /* TODO: dropout */
-
             /* layer normalization */
             x = attLayerNorms[i].Make(res);
-
-            llnum = -1;
         }

         /* fnn */
         fnn = fnns[i].Make(x);

+        /* dropout */
+        if(isTraining && dropoutP > 0)
+            fnn = Dropout(fnn, dropoutP);
+
         /* residual connection */
         res = Sum(fnn, x);

-        /* TODO: dropout */
-
         /* layer normalization */
         x = fnnLayerNorms[i].Make(res);
+
+        if(isTraining && dropoutP > 0)
+            x = Dropout(x, dropoutP);
     }

     return x;
......
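With the new isTraining flag, whether dropout fires is decided at the call site rather than inside the layers. As the trainer changes further down show, the calls end up looking like this (usage sketch):

    /* training pass: dropout layers are active */
    model->Make(batch, output, padding, true);

    /* test pass: dropout is skipped entirely; no rescaling is needed at
       inference because kept units were already scaled by 1/(1-p) */
    model->Make(batch, output, padding, false);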
...@@ -40,7 +40,7 @@ class T2TEncoder
 {
 public:
     virtual
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes) = 0;
+    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining) = 0;
 };

 /*
...@@ -49,7 +49,7 @@ the encoder based on RNN
 class RNNEncoder : T2TEncoder
 {
 public:
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
+    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
 };
...@@ -77,6 +77,9 @@ public:
     /* vocabulary size */
     int vSize;

+    /* dropout probability */
+    DTYPE dropoutP;

     /* some positions can be ignored in attention. this is useful in lm where the first position needs
        special design for the attention model. */
     int ignored;
...@@ -115,7 +118,7 @@ public:
     int myDevID = -1, XMem * myMem = NULL);

     /* make the encoding network */
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
+    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
 };
......
...@@ -58,7 +58,7 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
     LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
     LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
-    LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE);
+    LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE * 4);
     LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);

     InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
......
...@@ -77,11 +77,12 @@ make the encoding network
 >> input - input tensor
 >> mask - the mask for positions that are/not involved in computation
 >> skipInputRes - indicates whether we skip the residual connection of the first layer
+>> isTraining - indicates whether we are training the model
 << return - encoding result
 */
-XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes)
+XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
 {
-    return encoder.Make(input, mask, skipInputRes);
+    return encoder.Make(input, mask, skipInputRes, isTraining);
 }

 /*
...@@ -89,8 +90,9 @@ make the entire network (with the output softmax layer)
 >> input - input tensor
 >> output - output tensor (distribution)
 >> padding - padding of the sequences
+>> isTraining - indicates whether the model is for training
 */
-void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding)
+void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining)
 {
     XTensor encoding;
...@@ -134,7 +136,7 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding)
     //_Sum(&mask, padding3, &mask);

-    encoding = MakeEncoding(input, mask, true);
+    encoding = MakeEncoding(input, mask, true, isTraining);
     outputLayer.Make(encoding, output);

     delete[] dims;
......
...@@ -69,10 +69,10 @@ public:
     void InitModel(int argc, const char ** argv);

     /* make the encoding network */
-    XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes);
+    XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);

     /* make the entire network (with the output softmax layer) */
-    void Make(XTensor &input, XTensor &output, XTensor &padding);
+    void Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);

     /* get parameter matrics */
     void GetParams(XList &list);
......
...@@ -149,7 +149,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     XTensor output;

     /* make the network */
-    model->Make(batch, output, padding);
+    model->Make(batch, output, padding, true);

     /* make paddings for the output */
     if(output.GetDim(0) > 1)
...@@ -167,16 +167,6 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     /* get probabilities */
     float prob = GetProb(&output, &gold, NULL);

-    MTYPE totalUsed = 0;
-    MTYPE totalSize = 0;
-    for (int i = 0; i <= mem->curBlockID; i++) {
-        totalSize += mem->blocks[i].size;
-        totalUsed += mem->blocks[i].used;
-    }
-    //fprintf(stderr, "%d(%ld,%ld,%f)\n", mem->curBlockID, totalUsed, totalSize, (float)totalUsed/totalSize);

     loss += -prob;
     wordCount += wc;
     wordCountTotal += wc;
...@@ -209,6 +199,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     fclose(tf);

+    epoch = MIN(epoch, nepoch);

     XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
             lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
     XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
...@@ -271,7 +263,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
     XTensor output;

     /* make the network */
-    model->Make(batch, output, padding);
+    model->Make(batch, output, padding, false);

     int bSize = batch.GetDim(0);
     int length = batch.GetDim(1);
...@@ -333,11 +325,19 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
     char line[MAX_SEQUENCE_LENGTH];

+struct SampleNode
+{
+    int id;
+    int size;
+};

 /*
 load data to buffer
 >> file - where to load data
+>> isSorted - indicates whether the samples are sorted by length
+>> step - the number of sequences we go over when move to the next sample
 */
-int T2TTrainer::LoadBuf(FILE * file)
+int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
 {
     int lineCount = 0;
     int seqCount = 0;
...@@ -403,6 +403,17 @@ int T2TTrainer::LoadBuf(FILE * file)
     nseqBuf = seqCount;
     nextSeq = 0;

+    if (isSorted) {
+        SampleNode * nodes = new SampleNode[seqCount];
+        int count = 0;
+        for (int i = 0; i < seqCount; i += step) {
+            nodes[count].id = count;
+            nodes[count].size = seqLen[i];
+            count++;
+        }
+        delete[] nodes;
+    }

     return lineCount;
 }
...@@ -438,7 +449,7 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
     int devID, XMem * mem)
 {
     if(nextSeq < 0 || nextSeq >= nseqBuf)
-        LoadBuf(file);
+        LoadBuf(file, isSorted);

     int seq = MAX(nextSeq, 0);
     int wc = 0;
......
...@@ -118,7 +118,7 @@ public:
     void Test(const char * fn, const char * ofn, T2TModel * model);

     /* load data to buffer */
-    int LoadBuf(FILE * file);
+    int LoadBuf(FILE * file, bool isSorted, int step);

     /* clear data buffer */
     void ClearBuf();
......
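The new isSorted branch in LoadBuf only fills a SampleNode array and then frees it again, so the actual sort by length is evidently still to come. A hypothetical completion of that step is sketched below; the comparator name and the use of qsort are assumptions, not part of this commit:

    #include <cstdlib>

    struct SampleNode
    {
        int id;    /* index of the sample */
        int size;  /* its sequence length */
    };

    /* hypothetical comparator: order samples by sequence length so that
       similarly sized sequences land in the same batch */
    static int CompareSampleNode(const void * a, const void * b)
    {
        return ((const SampleNode*)a)->size - ((const SampleNode*)b)->size;
    }

    /* would be called right after the nodes array is filled:
       qsort(nodes, count, sizeof(SampleNode), CompareSampleNode); */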
...@@ -747,6 +747,64 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
         CheckNTErrors(nodeNumUsed < nodeNum, "No enough index nodes for the memory pool!");
     }

+    /*if(testxmemid == 30){
+        recordp = result;
+    }
+
+    if(curBlockID >= 25){
+        MHeader * head = blocks[25].head;
+        while(head != NULL){
+            fprintf(stderr, "head: %ld %ld\n", head->indexNode->pReal, head->indexNode->size);
+            head = head->next;
+        }
+    }
+
+    if(testxmemid == 32){
+        int nnn = 0;
+    }
+
+    if(recordp != NULL){
+        MTYPE size = mySize;
+        if(size <= minSizeIndex[0])
+            size = minSizeIndex[0];
+
+        MPieceNode * entry = NULL;
+        MPieceNode * node = NULL;
+        MPieceNode * hit = NULL;
+        MPieceNode * last = NULL;
+
+        entry = memIndex + indexEntryNum + FindIndexEntry(size);
+
+        last = entry;
+        node = entry->next;
+
+        while(node != NULL){
+            CheckNTErrors(node->pre == last, "Something is wrong!");
+            CheckNTErrors(last->next == node, "Something is wrong!");
+            CheckNTErrors(node->head.state == 2, "Something is wrong!");
+            last = node;
+
+            if(node->size == 0){
+                MPieceNode * next = node->next;
+                RemoveFreeIndexNode(node, entry);
+                node = next;
+                ShowNTErrors("Something is wrong!");
+            }
+            else{
+                CheckNTErrors(node->pReal != NULL, "Illegal pointer!");
+                if(node->pReal == recordp){
+                    hit = node;
+                    break;
+                }
+                node = node->next;
+            }
+        }
+
+        if(hit == NULL){
+            int nnn = 0;
+        }
+    }*/

     return result;
 }
...@@ -918,6 +976,8 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
     hit->head.state = 1;
     RemoveAllocIndexNode(hit);
+
+    hit->size = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
+
     AddFreeIndexNode(hit);

     blocks[hit->head.blockID].used -= hit->head.size;
...@@ -981,8 +1041,9 @@ void XMem::RebuildIndex()
     /* make a new index node */
     MPieceNode * newNode = memIndex2 + nodeNumUsed2++;
     newNode->p = p;
-    newNode->size = (char*)p + head->size -
-                    ( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
+    newNode->size = node->size;
+    //newNode->size = (char*)p + head->size -
+    //                ( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);

     newNode->pre = NULL;
     newNode->next = NULL;
......
...@@ -553,10 +553,16 @@ void XTensor::SetZeroAll(XStream * stream)
#ifdef USE_CUDA
         int size = sizeof(int) + (sizeof(int)+sizeof(DTYPE)) * unitNumNonZero;

+        int devIDBackup = 0;
+        cudaGetDevice(&devIDBackup);
+        cudaSetDevice(devID);

         if(stream == NULL)
             cudaMemset(data, 0, size);
         else
             cudaMemsetAsync(data, 0, size, stream->stream);

+        cudaSetDevice(devIDBackup);
#endif
     }
     else
...@@ -567,10 +573,16 @@ void XTensor::SetZeroAll(XStream * stream)
     else{
         if(devID >= 0){
#ifdef USE_CUDA
+            int devIDBackup = 0;
+            cudaGetDevice(&devIDBackup);
+            cudaSetDevice(devID);

             if(stream == NULL)
                 cudaMemset(data, 0, unitNum * unitSize);
             else
                 cudaMemsetAsync(data, 0, unitNum * unitSize, stream->stream);

+            cudaSetDevice(devIDBackup);
#endif
         }
         else
......
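The SetZeroAll change wraps the memset in a save/switch/restore of the current CUDA device so the call runs on the tensor's device without disturbing the caller. The same pattern could be factored into a small RAII guard; a sketch under that assumption (not part of this commit):

    #ifdef USE_CUDA
    #include <cuda_runtime.h>

    /* restores the previously active CUDA device when it goes out of scope */
    class DeviceGuard
    {
    public:
        explicit DeviceGuard(int devID) : backup(0)
        {
            cudaGetDevice(&backup);   /* remember the caller's device */
            cudaSetDevice(devID);     /* switch to the tensor's device */
        }
        ~DeviceGuard()
        {
            cudaSetDevice(backup);    /* switch back */
        }
    private:
        int backup;
    };
    #endif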
...@@ -171,14 +171,12 @@ void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha
     if (alpha == 0) {
         KernelMulElementWiseTensorDynamic<0> << <blocks, threads >> >
             ((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
-             stride, dimensionSizeA, dimensionSizeB, dimensionSizeC,
-             blockNum);
+             stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
     }
     else {
         KernelMulElementWiseTensorDynamic<1> << <blocks, threads >> >
             ((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
-             stride, dimensionSizeA, dimensionSizeB, dimensionSizeC,
-             blockNum);
+             stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
     }
 }
 }
......
...@@ -25,120 +25,59 @@
 #include "Dropout.h"
 #include "Dropout.cuh"
 #include "../core/arithmetic/Multiply.h"
+#include "../core/arithmetic/MultiplyDim.h"
 #include "../core/math/ScaleAndShift.h"

 namespace nts{ // namespace nts(NiuTrans.Tensor)

-/*
-generate a random bernoulli number
-*/
-DTYPE RandomBernoulli(DTYPE prob)
-{
-    return (DTYPE)rand()/(DTYPE)RAND_MAX > prob ? (DTYPE)1.0 : (DTYPE)0.0;
-}
-
-/*
-dropout function
-During training, randomly zeroes some of the elements of the input tensor
-with probability p using samples from a Bernoulli distribution.
-The elements to zero are randomized on every forward call.
-This has proven to be an effective technique for regularization and
-preventing the co-adaptation of neurons as described in the paper
-"Improving neural networks by preventing co-adaptation of feature detectors".
-Furthermore, the outputs are scaled by a factor of \frac{1}{1-p} during training.
-This means that during evaluation the module simply computes an identity function.
-
->> x - input tensor
->> y - output tensor
->> prob - probability to set an element zero
-*/
-void _Dropout(const XTensor *x, XTensor *y, unsigned int seed, DTYPE prob)
-{
-    CheckNTErrors(prob >= 0.0 && prob <= 1.0, "The probability must be 0-1!");
-
-    DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - prob);
-
-    /* generate a mask tensor again with special probability */
-    srand(seed);
-    int unitNum = x->unitNum;
-    DTYPE * maskArray = new DTYPE[unitNum];
-    for (int i = 0; i < unitNum; i++)
-        maskArray[i] = RandomBernoulli(prob);
-
-    XTensor * maskTensor = NewTensorBuf(x, x->devID, x->mem);
-    maskTensor->SetData(maskArray, unitNum);
-
-#ifdef USE_CUDA
-    if(x->devID >=0 || y->devID >= 0){
-        _CudaDropout(x, y, maskTensor, scaleFactor);
-        DelTensorBuf(maskTensor);
-        delete[] maskArray;
-        return;
-    }
-#endif
-
-    XTensor * inter = NewTensorBuf(x, x->devID, x->mem);
-
-    _Multiply(x, maskTensor, inter);
-    _ScaleAndShift(inter, y, scaleFactor, 0);
-
-    DelTensorBuf(inter);
-    DelTensorBuf(maskTensor);
-    delete[] maskArray;
-}
-
-/*
-dropout function (return a XTensor structure)
-make a new tensor to keep the result and return it
-During training, randomly zeroes some of the elements of the input tensor
-with probability p using samples from a Bernoulli distribution.
-The elements to zero are randomized on every forward call.
-This has proven to be an effective technique for regularization and
-preventing the co-adaptation of neurons as described in the paper
-"Improving neural networks by preventing co-adaptation of feature detectors".
-Furthermore, the outputs are scaled by a factor of \frac{1}{1-p} during training.
-This means that during evaluation the module simply computes an identity function.
-
->> x - input tensor
->> y - output tensor
->> prob - probability to set an element zero
-*/
-XTensor Dropout(const XTensor &x, DTYPE prob)
-{
-    XTensor y(&x);
-    y.SetTMP();
-
-    DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - prob);
-
-    /* generate a mask tensor again with special probability */
-    srand((unsigned int)time(NULL));
-    int unitNum = x.unitNum;
-    DTYPE * maskArray = new DTYPE[unitNum];
-    for (int i = 0; i < unitNum; i++)
-        maskArray[i] = RandomBernoulli(prob);
-
-    XTensor maskTensor(&x);
-    maskTensor.SetData(maskArray, unitNum);
-
-    XTensor inter;
-    inter = Multiply(x, maskTensor);
-    y = ScaleAndShift(inter, scaleFactor, 0);
-
-    delete[] maskArray;
-
-    ///* tensor connection */
-    //XLink::MakeLink(&x, NULL, &y, FUNC_DROPOUT);
-    //XLink::AddParamToHead(&y, prob);
-
-    return y;
-}
+/*
+dropout function
+It randomly zeroes some of the elements of the input tensor
+with probability p via a Bernoulli distribution.
+
+See "Improving neural networks by preventing co-adaptation of feature detectors"
+for more details.
+
+Here, the output is scaled by a factor of \frac{1}{1-p} so that we do not need
+to mark the tensor with probability p in the inference phase. Instead we perform
+the same inference procedure as that with no use of dropout on the test data.
+
+>> x - input tensor
+>> y - output tensor
+>> seed - random seed
+>> dropProb - probability to set an element to zero
+>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
+*/
+void _Dropout(const XTensor * x, XTensor * y, unsigned int seed, DTYPE dropProb, int leadingDim)
+{
+    CheckNTErrors(dropProb >= 0.0 && dropProb <= 1.0, "The probability must be 0-1!");
+
+    int n = leadingDim < 0 ? x->order - 1 : leadingDim;
+
+    CheckNTErrors(n >= 0 && n < x->order, "Wrong leadingDim!");
+
+    DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
+
+    /* generate a mask tensor again with special probability */
+    int unitNum = x->dimSize[n];
+    DTYPE * maskArray = new DTYPE[unitNum];
+
+    srand(seed);
+    for (int i = 0; i < unitNum; i++)
+        maskArray[i] = RandomBernoulli(dropProb, scaleFactor);
+
+    XTensor * mask = NewTensor1D(unitNum, x->dataType, x->devID, x->mem);
+    mask->SetData(maskArray, unitNum);
+
+    /* call Multiply function for mask */
+    _MultiplyDim(x, mask, y, n, 0);
+
+    delete mask;
+    delete[] maskArray;
+}

 /*
-backward computation of dropout function
+backward computation of the dropout function
 dE/dx = dE/dy * dy/dx
...@@ -146,48 +85,86 @@ dE/dx = dE/dy * dy/dx
 >> x - input of the dropout function
 >> dedy - dE/dy
 >> dedx - dE/dx
->> prob - probability to set an element zero
+>> seed - random seed
+>> dropProb - probability to set an element to zero
+>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
 */
 void _DropoutBackward(const XTensor * y, const XTensor * x,
                       const XTensor * dedy, XTensor * dedx,
-                      unsigned int seed, DTYPE prob)
+                      unsigned int seed, DTYPE dropProb, int leadingDim)
 {
+    CheckNTErrors(dropProb >= 0.0 && dropProb <= 1.0, "The probability must be 0-1!");
+
+    int n = leadingDim < 0 ? x->order - 1 : leadingDim;
+
+    CheckNTErrors(n >= 0 && n < x->order, "Wrong leadingDim!");
+
     if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
     {
-        int unitNum = y->unitNum;
-        DTYPE scaleFactor = (DTYPE)1.0F / ((DTYPE)1.0F - prob);
+        DTYPE scaleFactor = (DTYPE)1.0F / ((DTYPE)1.0F - dropProb);

         /* generate a mask tensor again with special probability */
-        srand(seed);
+        int unitNum = x->dimSize[n];
         DTYPE * maskArray = new DTYPE[unitNum];
+
+        srand(seed);
         for (int i = 0; i < unitNum; i++)
-            maskArray[i] = RandomBernoulli(prob);
+            maskArray[i] = RandomBernoulli(dropProb, scaleFactor);

-        XTensor * maskTensor = NewTensorBuf(x, x->devID, x->mem);
-        maskTensor->SetData(maskArray, unitNum);
+        XTensor * mask = NewTensor1D(unitNum, x->dataType, x->devID, x->mem);
+        mask->SetData(maskArray, unitNum);

-#ifdef USE_CUDA
-        if(x->devID >= 0 || y->devID >= 0){
-            _CudaDropoutBackward(y, x, dedy, dedx, maskTensor, scaleFactor);
-            DelTensorBuf(maskTensor);
-            delete[] maskArray;
-            return;
-        }
-#endif
-
-        DTYPE * dedyp = (DTYPE*)dedy->data;
-        DTYPE * dedxp = (DTYPE*)dedx->data;
-
-        /* dE/dx = dE/dy * dy/dx */
-        for(int i = 0; i < unitNum; i++)
-            dedxp[i] = dedyp[i] * maskArray[i] * scaleFactor;
-
-        DelTensorBuf(maskTensor);
+        /* call MultiplyDim function for mask */
+        _MultiplyDim(dedy, mask, dedx, n, 0);
+
+        delete mask;
         delete[] maskArray;
     }
     else
         ShowNTErrors("TODO!");
 }
+
+/*
+dropout function (we make tensor connections here)
+It randomly zeroes some of the elements of the input tensor
+with probability p via a Bernoulli distribution.
+
+See "Improving neural networks by preventing co-adaptation of feature detectors"
+for more details.
+
+Here, the output is scaled by a factor of \frac{1}{1-p} so that we do not need
+to mark the tensor with probability p in the inference phase. Instead we perform
+the same inference procedure as that with no use of dropout on the test data.
+
+>> x - input tensor
+>> dropProb - probability to set an element to zero
+>> leadingDim - the dimension which we generate the random numbers and perform broadcasting
+*/
+XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim)
+{
+    CheckNTErrors(dropProb >= 0.0 && dropProb <= 1.0, "The probability must be 0-1!");
+
+    int n = leadingDim < 0 ? x.order - 1 : leadingDim;
+
+    CheckNTErrors(n >= 0 && n < x.order, "Wrong leadingDim!");
+
+    DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
+
+    /* generate a mask tensor with probability p */
+    int unitNum = x.dimSize[n];
+    DTYPE * maskArray = new DTYPE[unitNum];
+
+    srand((unsigned int)time(NULL));
+    for (int i = 0; i < unitNum; i++)
+        maskArray[i] = RandomBernoulli(dropProb, scaleFactor);
+
+    XTensor mask;
+    InitTensor1D(&mask, unitNum, x.dataType, x.devID, x.mem);
+    mask.SetData(maskArray, unitNum);
+
+    delete[] maskArray;
+
+    return MultiplyDim(x, mask, n, 0);
+}

 } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
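Why the mask values are 0 or 1/(1-p): with keep probability 1-p and kept value 1/(1-p), the mask has expectation (1-p)·1/(1-p) = 1, so dropout is an identity map in expectation and no extra scaling is needed at inference time. A quick numeric check of that expectation, as standalone C++ mirroring the RandomBernoulli logic in this commit:

    #include <cstdio>
    #include <cstdlib>

    int main()
    {
        const float p = 0.2F;                 /* drop probability */
        const float keepValue = 1.0F / (1.0F - p);
        const int trials = 1000000;

        double sum = 0.0;
        srand(20);
        for (int i = 0; i < trials; i++)      /* 0 with prob p, else 1/(1-p) */
            sum += (float)rand() / RAND_MAX >= p ? keepValue : 0.0F;

        /* prints a value close to 1.0: dropout is identity in expectation */
        printf("mean mask value = %.4f\n", sum / trials);
        return 0;
    }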
...@@ -27,16 +27,22 @@
 namespace nts{ // namespace nts(NiuTrans.Tensor)

-/* dropout function */
-void _Dropout(const XTensor * x, XTensor * y, unsigned int seed, DTYPE prob = 0.5);
+/* generate a random bernoulli number */
+inline DTYPE RandomBernoulli(DTYPE dropProb, DTYPE value)
+{
+    return (DTYPE)rand()/(DTYPE)RAND_MAX >= dropProb ? (DTYPE)value : 0;
+}

 /* dropout function */
-XTensor Dropout(const XTensor &x, DTYPE prob = 0.5);
+void _Dropout(const XTensor * x, XTensor * y, unsigned int seed, DTYPE dropProb, int leadingDim = -1);

 /* de/dx */
 void _DropoutBackward(const XTensor * y, const XTensor * x,
                       const XTensor * dedy, XTensor * dedx,
-                      unsigned int seed, DTYPE prob = 0.5);
+                      unsigned int seed, DTYPE dropProb, int leadingDim = -1);

+/* dropout function */
+XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim = -1);

 } // namespace nts(NiuTrans.Tensor)
......
...@@ -26,6 +26,7 @@
 #include "../XTensor.h"

+#include "Dropout.h"
 #include "HardTanH.h"
 #include "Identity.h"
 #include "LogSoftmax.h"
......
...@@ -31,10 +31,11 @@ case 1: test Dropout function.
 bool TestDropout1()
 {
     /* a input tensor of size (4, 5) */
-    int order = 2;
+    int order = 3;

     int * dimSize = new int[order];
     dimSize[0] = 40;
     dimSize[1] = 50;
+    dimSize[2] = 60;

     int unitNum = 1;
     for (int i = 0; i < order; i++)
...@@ -49,14 +50,14 @@ bool TestDropout1()
     XTensor yUser;

     /* initialize variables */
-    x->SetDataRand(0, 1);
+    _SetDataFixedFloat(x, 1.0F);
     y->SetZeroAll();

     /* call Dropout function */
-    float prob = 0.2F;
+    float drop_prob = 0.2F;
     int seed = 20;
-    _Dropout(x, y, seed, prob);
-    yUser = Dropout(*x);
+    _Dropout(x, y, seed, drop_prob);
+    yUser = Dropout(*x, drop_prob);

     /* check result */
     int zeroNum1 = 0;
...@@ -73,9 +74,9 @@ bool TestDropout1()
     }
     printf("CPU Test:\n");
     printf("In tensor y, there are %d units.\n", unitNum);
-    printf("There are %d zero units by Dropout layer with probability %.2f.\n", zeroNum1, prob);
+    printf("There are %d zero units by Dropout layer with probability %.2f.\n", zeroNum1, drop_prob);
     printf("In tensor yUser, there are %d units.\n", unitNum);
-    printf("There are %d zero units by Dropout layer with default probability %.2f.\n", zeroNum2, 0.5F);
+    printf("There are %d zero units by Dropout layer with default probability %.2f.\n", zeroNum2, drop_prob);

#ifdef USE_CUDA
     /* GPU test */
...@@ -87,12 +88,12 @@ bool TestDropout1()
     XTensor yUserGPU;

     /* initialize variables */
-    xGPU->SetDataRand(0, 1);
+    _SetDataFixedFloat(xGPU, 1.0F);
     yGPU->SetZeroAll();

     /* call Dropout function */
-    _Dropout(xGPU, yGPU, seed, prob);
-    yUserGPU = Dropout(*xGPU);
+    _Dropout(xGPU, yGPU, seed, drop_prob);
+    yUserGPU = Dropout(*xGPU, drop_prob);

     /* check result */
     zeroNum1 = 0;
...@@ -109,9 +110,9 @@ bool TestDropout1()
     }
     printf("CPU Test:\n");
     printf("In tensor y, there are %d units.\n", unitNum);
-    printf("There are %d zero units by Dropout layer with probability %.2f.\n", zeroNum1, prob);
+    printf("There are %d zero units by Dropout layer with probability %.2f.\n", zeroNum1, drop_prob);
     printf("In tensor yUser, there are %d units.\n", unitNum);
-    printf("There are %d zero units by Dropout layer with default probability %.2f.\n", zeroNum2, 0.5F);
+    printf("There are %d zero units by Dropout layer with default probability %.2f.\n", zeroNum2, drop_prob);

     /* destroy variables */
     delete x;
...@@ -159,13 +160,13 @@ bool TestDropout2()
     _SetDataFixedFloat(x, 1.0F);
     y->SetZeroAll();
     dedx->SetZeroAll();
-    _SetDataFixedFloat(dedy, 1.0F);
+    _SetDataFixedFloat(dedy, 1.5F);

     /* call Dropout function */
-    float prob = 0.5F;
+    float drop_prob = 0.5F;
     int seed = 1;
-    _Dropout(x, y, seed, prob);
-    _DropoutBackward(y, x, dedy, dedx, 1, prob);
+    _Dropout(x, y, seed, drop_prob);
+    _DropoutBackward(y, x, dedy, dedx, 1, drop_prob);

     /* check result */
     y->Dump(stderr, "y");
...@@ -185,11 +186,11 @@ bool TestDropout2()
     _SetDataFixedFloat(xGPU, 1.0F);
     yGPU->SetZeroAll();
     dedxGPU->SetZeroAll();
-    _SetDataFixedFloat(dedyGPU, 1.0F);
+    _SetDataFixedFloat(dedyGPU, 1.5F);

     /* call Dropout function */
-    _Dropout(xGPU, yGPU, seed, prob);
-    _DropoutBackward(yGPU, xGPU, dedyGPU, dedxGPU, 1, prob);
+    _Dropout(xGPU, yGPU, seed, drop_prob);
+    _DropoutBackward(yGPU, xGPU, dedyGPU, dedxGPU, 1, drop_prob);

     /* check result */
     yGPU->Dump(stderr, "yGPU");
......
...@@ -65,9 +65,10 @@ bool TestXMemCase1()
     for (int i = 0; i < testNum * scalar; i++) {
         testxmemid++;
-        //fprintf(stderr, "%d %d\n", testxmemid, ok);

         int j = rand() % caseNum;

+        //fprintf(stderr, "%d %d %d\n", testxmemid, j, ok);

         if (p[j] == NULL) {
             p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));
             for (int k = 0; k < size[j]; k++)
......