diff --git a/source/network/Main.cpp b/source/network/Main.cpp
index 3764a03..3e6540c 100644
--- a/source/network/Main.cpp
+++ b/source/network/Main.cpp
@@ -39,7 +39,6 @@ void SumDimTest();
 using namespace nts;
 using namespace fnnlm;
 using namespace transformer;
-using namespace GAN;
 
 int main( int argc, const char ** argv )
 {
@@ -47,9 +46,7 @@ int main( int argc, const char ** argv )
     //BackwardTest();
     //return 0;
 
-    if(argc > 1 && !strcmp(argv[1], "-test"))
-        Test();
-    else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
+    if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
         FNNLMMain(argc - 1, argv + 1);
     else if(argc > 1 && !strcmp(argv[1], "-t2t"))
         TransformerMain(argc - 1, argv + 1);
diff --git a/source/sample/transformer/T2TEncoder.cpp b/source/sample/transformer/T2TEncoder.cpp
index 4508ea3..244981e 100644
--- a/source/sample/transformer/T2TEncoder.cpp
+++ b/source/sample/transformer/T2TEncoder.cpp
@@ -103,6 +103,8 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
         XTensor fnn;
         XTensor res;
 
+        /* we skip the residual connection for the first layer if
+           the encoder is used in language modeling. */
         if(skipInputRes && i == 0){
             /* self attention */
             att = attentions[i].Make(x, x, x, mask);
diff --git a/source/sample/transformer/T2TModel.cpp b/source/sample/transformer/T2TModel.cpp
index bf6c213..fc35f96 100644
--- a/source/sample/transformer/T2TModel.cpp
+++ b/source/sample/transformer/T2TModel.cpp
@@ -60,7 +60,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
 
     if(useMem){
         delete mem;
-        mem = new XMem(devID);
+        mem = new XMem(devID, UNI_FREE, MILLION * 512, 1024, MILLION * 128);
     }
 
     encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
@@ -98,7 +98,9 @@ void T2TModel::Make(XTensor &input, XTensor &output)
         dims[input.order] = len;
         XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);
 
-        /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9 */
+        /* an upper triangular matrix whose upper-triangle cells are set to -1e9.
+           this matrix can be used to prevent the attention from visiting the
+           current or following words in a given sequence. */
         _SetDataLowTri(&mask, 1e9F, -1);
         _ScaleAndShiftMe(&mask, 1.0F, -1e9F);
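
A note on the mask construction above, since the constants are easy to misread: _SetDataLowTri(&mask, 1e9F, -1) is assumed to fill the cells strictly below the main diagonal with 1e9 and everything else with 0, and _ScaleAndShiftMe(&mask, 1.0F, -1e9F) then computes x * 1.0 - 1e9 in place, leaving 0 where attention is allowed and -1e9 on the diagonal and above. A minimal standalone sketch of that arithmetic in plain C++ (illustrative only, not the NiuTensor API):

    #include <cstdio>

    int main()
    {
        const int len = 5;
        float mask[len][len];

        /* step 1: assumed effect of _SetDataLowTri(&mask, 1e9F, -1):
           cells strictly below the diagonal get 1e9, the rest get 0 */
        for(int i = 0; i < len; i++)
            for(int j = 0; j < len; j++)
                mask[i][j] = (j <= i - 1) ? 1e9F : 0.0F;

        /* step 2: _ScaleAndShiftMe(&mask, 1.0F, -1e9F) computes x * 1.0 - 1e9:
           allowed cells become 0, the current and following words become -1e9 */
        for(int i = 0; i < len; i++)
            for(int j = 0; j < len; j++)
                mask[i][j] = mask[i][j] * 1.0F - 1e9F;

        /* added to the attention scores before the softmax, exp(score - 1e9)
           underflows to 0, so position i attends only to positions < i */
        for(int i = 0; i < len; i++){
            for(int j = 0; j < len; j++)
                printf("%8.0e", mask[i][j]);
            printf("\n");
        }
        return 0;
    }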
diff --git a/source/sample/transformer/T2TTrainer.cpp b/source/sample/transformer/T2TTrainer.cpp
index ce6ccc8..5bbbd9c 100644
--- a/source/sample/transformer/T2TTrainer.cpp
+++ b/source/sample/transformer/T2TTrainer.cpp
@@ -53,6 +53,9 @@ initialization
 */
 void T2TTrainer::Init(int argc, const char ** argv)
 {
+    bool useMem = false;
+
+    LoadParamBool(argc, argv, "mem", &useMem, useMem);
     LoadParamInt(argc, argv, "dev", &devID, -1);
     LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
     LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
@@ -68,6 +71,11 @@ void T2TTrainer::Init(int argc, const char ** argv)
     buf = new int[bufSize];
     seqLen = new int[bufSize];
     seqOffset = new int[bufSize];
+
+    if(useMem){
+        delete mem;
+        mem = new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
+    }
 }
 
 /*
@@ -86,6 +94,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     float loss = 0;
     float lr = 0;
 
+    model->mem->SetPin();
+    mem->SetPin();
+
     XNet net;
 
     double startT = GetClockSec();
@@ -96,11 +107,17 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
         CheckNTErrors(file, "cannot open training file!");
 
         wordCount = 0;
+
+        model->mem->BackToPin();
+        mem->BackToPin();
 
         /* batch of input sequences */
         XTensor batch;
+
+        /* padding */
+        XTensor padding;
 
-        while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
+        while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
 
             /* output probabilities */
             XTensor output;
@@ -108,6 +125,10 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
             /* make the network */
             model->Make(batch, output);
 
+            /* make paddings for the output */
+            if(output.GetDim(0) > 1)
+                PadOutput(&output, &padding);
+
             /* back-propagation for obtaining gradients */
             net.Backward(output, batch, CROSSENTROPY);
@@ -135,6 +156,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
                 XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
                         lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
         }
+
+        model->mem->BackToPin();
+        mem->BackToPin();
     }
 
     fclose(file);
@@ -230,6 +254,7 @@ int T2TTrainer::LoadBuf(FILE * file)
 load a batch of sequences
 >> file - the handle to the data file
 >> batch - the batch
+>> padding - padding of the input sequences
 >> step - the step we go over when move to the next sequence
 >> vs - vocabulary size
 >> sBatch - batch size of sequences
@@ -237,7 +262,9 @@ load a batch of sequences
 >> isSorted - indicates whether the sequences are sorted by length
 >> wCount - word count
 */
-int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
+int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
+                          int step, int vs, int sBatch, int wBatch,
+                          bool isSorted, int &wCount)
 {
     if(nextSeq < 0 || nextSeq >= nseqBuf)
         LoadBuf(file);
@@ -273,12 +300,19 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sB
         InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
     }
 
+    if(padding->order != 2 || padding->GetDim(0) != sc ||
+       padding->GetDim(1) != max){
+        InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
+    }
+
     batch->SetZeroAll();
+    padding->SetZeroAll();
 
     /* this might be slow on GPUs :( */
     for(int s = seq; s < seq + sc; s++){
         for(int w = 0; w < seqLen[s]; w++){
             batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
+            padding->Set2D(1.0F, s - seq, w);
             wCount++;
         }
     }
@@ -394,4 +428,35 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
     }
 }
 
+/*
+do padding on the output
+>> output - output tensor of the network
+>> padding - padding of a batch of sentences
+*/
+void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
+{
+    if(output == NULL || padding == NULL)
+        return;
+
+    int on = output->order;
+    int * dimso = new int[on];
+
+    memcpy(dimso, output->dimSize, sizeof(int) * on);
+
+    output->Reshape(output->unitNum/dimso[output->order - 1], dimso[output->order - 1]);
+
+    XTensor * padding2 = NewTensorBuf(1, &padding->unitNum, X_FLOAT, 1.0F, padding->devID, padding->mem);
+
+    _CopyValues(padding, padding2);
+    _ScaleAndShiftMe(padding2, 1e9F, -1e9F);
+
+    _SumDim(output, padding2, output, 0);
+
+    output->Reshape(on, dimso);
+
+    delete[] dimso;
+    DelTensorBuf(padding2);
+}
+
 }
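
The padding machinery added to T2TTrainer above has two halves: LoadBatch builds a 0/1 matrix marking the real words of each sequence, and PadOutput turns that matrix into a -1e9 bias that is broadcast over the vocabulary dimension of the reshaped output, so padded positions drop out of the cross-entropy. A minimal sketch of the same arithmetic on plain arrays, with toy sizes (two sequences of lengths 3 and 1, max length 4, vocabulary 3) that are purely illustrative:

    #include <cstdio>

    int main()
    {
        const int sc = 2, max = 4, vocab = 3;   /* toy sizes */
        int seqLen[sc] = {3, 1};

        /* the 0/1 matrix that LoadBatch fills via padding->Set2D() */
        float padding[sc][max] = {};
        for(int s = 0; s < sc; s++)
            for(int w = 0; w < seqLen[s]; w++)
                padding[s][w] = 1.0F;

        /* the output viewed as (sc * max) rows of vocab entries,
           as after output->Reshape(output->unitNum / vocab, vocab) */
        float output[sc * max][vocab] = {};

        /* padding2 = padding * 1e9 - 1e9: real words -> 0, pads -> -1e9;
           _SumDim(output, padding2, output, 0) broadcasts one bias per
           row across the whole vocabulary dimension */
        for(int r = 0; r < sc * max; r++){
            float bias = padding[r / max][r % max] * 1e9F - 1e9F;
            for(int v = 0; v < vocab; v++)
                output[r][v] += bias;
        }

        for(int r = 0; r < sc * max; r++)
            printf("row %d: %s\n", r, output[r][0] < 0.0F ? "masked" : "kept");
        return 0;
    }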
diff --git a/source/sample/transformer/T2TTrainer.h b/source/sample/transformer/T2TTrainer.h
index a1525bf..c7fbe91 100644
--- a/source/sample/transformer/T2TTrainer.h
+++ b/source/sample/transformer/T2TTrainer.h
@@ -105,13 +105,18 @@ public:
     int LoadBuf(FILE * file);
 
     /* load a batch of sequences */
-    int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);
+    int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
+                  int step, int vs, int sBatch, int wBatch,
+                  bool isSorted, int &wCount);
 
     /* get word probabilities for a batch of sequences */
     float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
 
     /* update the model by delta rule */
     void Update(T2TModel * model, const float lr);
+
+    /* do padding on the output */
+    void PadOutput(XTensor * output, XTensor * padding);
 };
diff --git a/source/tensor/core/reduce/ReduceMax.cu b/source/tensor/core/reduce/ReduceMax.cu
index 39758f6..59976fe 100644
--- a/source/tensor/core/reduce/ReduceMax.cu
+++ b/source/tensor/core/reduce/ReduceMax.cu
@@ -405,7 +405,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
     if (vectorSize % 32 != 0) minWarpNum++;
     warpNum = min(warpNum, minWarpNum);
 
-    grid.x = vectorNum;
+    grid.x = (unsigned int)vectorNum;
     grid.y = 1;
     grid.z = 1;
     block.x = 1;
diff --git a/source/tensor/core/reduce/ReduceSum.cu b/source/tensor/core/reduce/ReduceSum.cu
index 18bd740..ee7507b 100644
--- a/source/tensor/core/reduce/ReduceSum.cu
+++ b/source/tensor/core/reduce/ReduceSum.cu
@@ -629,7 +629,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
     if (vectorSize % 32 != 0) minWarpNum++;
     warpNum = min(warpNum, minWarpNum);
 
-    grid.x = vectorNum;
+    grid.x = (unsigned int)vectorNum;
    grid.y = 1;
     grid.z = 1;
     block.x = 1;
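
The two CUDA hunks are the same one-line fix: the fields of dim3 are unsigned int, while vectorNum is a long long, so the bare assignment narrows implicitly and compilers warn about it (e.g., MSVC C4244, or -Wconversion on GCC/Clang). The explicit cast keeps the behavior and documents the assumption that the vector count fits in 32 bits:

    long long vectorNum = 1024;           /* known to fit in 32 bits here */
    dim3 grid;
    grid.x = (unsigned int)vectorNum;     /* explicit narrowing, warning-free */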