Commit e1630c28 by xiaotong

padding for batch training of t2t

parent 6793e025
...@@ -39,7 +39,6 @@ void SumDimTest(); ...@@ -39,7 +39,6 @@ void SumDimTest();
using namespace nts; using namespace nts;
using namespace fnnlm; using namespace fnnlm;
using namespace transformer; using namespace transformer;
using namespace GAN;
int main( int argc, const char ** argv ) int main( int argc, const char ** argv )
{ {
...@@ -47,9 +46,7 @@ int main( int argc, const char ** argv ) ...@@ -47,9 +46,7 @@ int main( int argc, const char ** argv )
//BackwardTest(); //BackwardTest();
//return 0; //return 0;
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1); FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t")) else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
......
...@@ -103,6 +103,8 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes) ...@@ -103,6 +103,8 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
XTensor fnn; XTensor fnn;
XTensor res; XTensor res;
/* we skip the residual connection for the first layer if
the encoder is used in language modeling. */
if(skipInputRes && i == 0){ if(skipInputRes && i == 0){
/* self attention */ /* self attention */
att = attentions[i].Make(x, x, x, mask); att = attentions[i].Make(x, x, x, mask);
......
...@@ -60,7 +60,7 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -60,7 +60,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
if(useMem){ if(useMem){
delete mem; delete mem;
mem = new XMem(devID); mem = new XMem(devID, UNI_FREE, MILLION * 512, 1024, MILLION * 128);
} }
encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem); encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
...@@ -98,7 +98,9 @@ void T2TModel::Make(XTensor &input, XTensor &output) ...@@ -98,7 +98,9 @@ void T2TModel::Make(XTensor &input, XTensor &output)
dims[input.order] = len; dims[input.order] = len;
XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem); XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9 */ /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri(&mask, 1e9F, -1); _SetDataLowTri(&mask, 1e9F, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F); _ScaleAndShiftMe(&mask, 1.0F, -1e9F);
......
...@@ -53,6 +53,9 @@ initialization ...@@ -53,6 +53,9 @@ initialization
*/ */
void T2TTrainer::Init(int argc, const char ** argv) void T2TTrainer::Init(int argc, const char ** argv)
{ {
bool useMem = false;
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "dev", &devID, -1); LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F); LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1); LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
...@@ -68,6 +71,11 @@ void T2TTrainer::Init(int argc, const char ** argv) ...@@ -68,6 +71,11 @@ void T2TTrainer::Init(int argc, const char ** argv)
buf = new int[bufSize]; buf = new int[bufSize];
seqLen = new int[bufSize]; seqLen = new int[bufSize];
seqOffset = new int[bufSize]; seqOffset = new int[bufSize];
if(useMem){
delete mem;
mem = new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
}
} }
/* /*
...@@ -86,6 +94,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -86,6 +94,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
float loss = 0; float loss = 0;
float lr = 0; float lr = 0;
model->mem->SetPin();
mem->SetPin();
XNet net; XNet net;
double startT = GetClockSec(); double startT = GetClockSec();
...@@ -96,11 +107,17 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -96,11 +107,17 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
CheckNTErrors(file, "cannot open training file!"); CheckNTErrors(file, "cannot open training file!");
wordCount = 0; wordCount = 0;
model->mem->BackToPin();
mem->BackToPin();
/* batch of input sequences */ /* batch of input sequences */
XTensor batch; XTensor batch;
/* padding */
XTensor padding;
while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){ while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
...@@ -108,6 +125,10 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -108,6 +125,10 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
/* make the network */ /* make the network */
model->Make(batch, output); model->Make(batch, output);
/* make paddings for the output */
if(output.GetDim(0) > 1)
PadOutput(&output, &padding);
/* back-propagation for obtaining gradients */ /* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY); net.Backward(output, batch, CROSSENTROPY);
...@@ -135,6 +156,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -135,6 +156,9 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n", XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount)); lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
} }
model->mem->BackToPin();
mem->BackToPin();
} }
fclose(file); fclose(file);
...@@ -230,6 +254,7 @@ int T2TTrainer::LoadBuf(FILE * file) ...@@ -230,6 +254,7 @@ int T2TTrainer::LoadBuf(FILE * file)
load a batch of sequences load a batch of sequences
>> file - the handle to the data file >> file - the handle to the data file
>> batch - the batch >> batch - the batch
>> padding - padding of the input sequences
>> step - the step we go over when move to the next sequence >> step - the step we go over when move to the next sequence
>> vs - vocabulary size >> vs - vocabulary size
>> sBatch - batch size of sequences >> sBatch - batch size of sequences
...@@ -237,7 +262,9 @@ load a batch of sequences ...@@ -237,7 +262,9 @@ load a batch of sequences
>> isSorted - indicates whether the sequences are sorted by length >> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count >> wCount - word count
*/ */
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount) int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount)
{ {
if(nextSeq < 0 || nextSeq >= nseqBuf) if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file); LoadBuf(file);
...@@ -273,12 +300,19 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sB ...@@ -273,12 +300,19 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sB
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
} }
if(padding->order != 2 || padding->GetDim(0) != sc ||
padding->GetDim(1) != max){
InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
}
batch->SetZeroAll(); batch->SetZeroAll();
padding->SetZeroAll();
/* this might be slow on GPUs :( */ /* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){ for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){ for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
padding->Set2D(1.0F, s - seq, w);
wCount++; wCount++;
} }
} }
...@@ -394,4 +428,35 @@ void T2TTrainer::Update(T2TModel * model, const float lr) ...@@ -394,4 +428,35 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
} }
} }
/*
do padding on the output
>> output - output tensor of the network
>> padding - padding of a batch of sentences
*/
void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
{
if(output == NULL || padding == NULL)
return;
int on = output->order;
int * dimso = new int[on];
memcpy(dimso, output->dimSize, sizeof(int) * on);
output->Reshape(output->unitNum/dimso[output->order - 1], dimso[output->order - 1]);
XTensor * padding2 = NewTensorBuf(1, &padding->unitNum, X_FLOAT, 1.0F, padding->devID, padding->mem);
_CopyValues(padding, padding2);
_ScaleAndShiftMe(padding2, 1e9F, -1e9F);
_SumDim(output, padding2, output, 0);
output->Reshape(on, dimso);
delete[] dimso;
DelTensorBuf(padding2);
}
} }
...@@ -105,13 +105,18 @@ public: ...@@ -105,13 +105,18 @@ public:
int LoadBuf(FILE * file); int LoadBuf(FILE * file);
/* load a batch of sequences */ /* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount); int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount);
/* get word probabilities for a batch of sequences */ /* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs); float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */ /* update the model by delta rule */
void Update(T2TModel * model, const float lr); void Update(T2TModel * model, const float lr);
/* do padding on the output */
void PadOutput(XTensor * output, XTensor * padding);
}; };
......
...@@ -405,7 +405,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -405,7 +405,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
if (vectorSize % 32 != 0) minWarpNum++; if (vectorSize % 32 != 0) minWarpNum++;
warpNum = min(warpNum, minWarpNum); warpNum = min(warpNum, minWarpNum);
grid.x = vectorNum; grid.x = (unsigned int)vectorNum;
grid.y = 1; grid.y = 1;
grid.z = 1; grid.z = 1;
block.x = 1; block.x = 1;
......
...@@ -629,7 +629,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long ...@@ -629,7 +629,7 @@ inline void continuousStorageThreadAllocation(dim3& grid, dim3& block, long long
if (vectorSize % 32 != 0) minWarpNum++; if (vectorSize % 32 != 0) minWarpNum++;
warpNum = min(warpNum, minWarpNum); warpNum = min(warpNum, minWarpNum);
grid.x = vectorNum; grid.x = (unsigned int)vectorNum;
grid.y = 1; grid.y = 1;
grid.z = 1; grid.z = 1;
block.x = 1; block.x = 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论