Commit 7b6840d4 by huchi

Refactor the dataloader in transformer

parent 3852f15a
...@@ -83,13 +83,13 @@ public: ...@@ -83,13 +83,13 @@ public:
FILE* fp; FILE* fp;
/* number of training samples */ /* number of training samples */
size_t totalSampleNum; int totalSampleNum;
/* buffer size */ /* buffer size */
size_t bufferSize; int bufferSize;
/* size of the bucket used for grouping sentences */ /* size of the bucket used for grouping sentences */
size_t bucketSize; int bucketSize;
/* indicates whether it is used for training */ /* indicates whether it is used for training */
bool isTraining; bool isTraining;
...@@ -112,44 +112,63 @@ public: ...@@ -112,44 +112,63 @@ public:
/* the maximum length for a target sentence */ /* the maximum length for a target sentence */
int maxTgtLen; int maxTgtLen;
/* batch size (number of words) */
int batchSize;
/* word-counter */
int wc;
/* sentence-counter */
int sc;
/* current index of the buffer */
int curIdx;
/* the buffer (a list) of samples */
XList * buf;
public: public:
/* get the maximum source sentence length in a range */ /* get the maximum source sentence length in a range */
static int MaxSrcLen(int begin, int end);
int MaxSrcLen(XList* buf, int begin, int end);
/* get the maximum target sentence length in a range */ /* get the maximum target sentence length in a range */
static int MaxTgtLen(int begin, int end);
int MaxTgtLen(XList* buf, int begin, int end);
/* sort the input by source sentence length (in descending order) */ /* sort the input by source sentence length (in descending order) */
void SortBySrcLength(XList* buf); void SortBySrcLength();
/* sort the input by target sentence length (in descending order) */ /* sort the input by target sentence length (in descending order) */
void SortByTgtLength(XList* buf); void SortByTgtLength();
/* sort buckets by key (in descending order) */ /* sort buckets by key (in descending order) */
void SortBuckets(XList* buf); void SortBuckets();
/* load the samples into the buffer (a list) */ /* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf); bool LoadBatchToBuf();
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
int &wc, int &sc);
/* release the samples in a buffer */ /* release the samples in a buffer */
static void ClearSamples();
void ClearSamples(XList * buf);
/* initialization function */ /* initialization function */
void Init(const char* dataFile, int bucketSize, bool training); void Init(const char * dataFile, int myBatchSize, int myBucketSize, bool training);
/* group data into buckets with similar length */ /* group data into buckets with similar length */
void BuildBucket(XList * buf); void BuildBucket();
/* get the number of sentences in a mini-batch */
int GetSentNum();
public:
/* start the process */
bool Start();
/* end the process */
bool End();
/* load the samples into tensors from the buffer */
bool GetBatchSimple(XList* inputs, XList* golds);
/* de-constructor */ /* de-constructor */
~TrainDataSet(); ~TrainDataSet();
......
...@@ -179,9 +179,8 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -179,9 +179,8 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec(); double startT = GetClockSec();
int curIdx = 0; int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true); batchLoader.Init(fn, wBatchSize, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) { for (epoch = 1; epoch <= nepoch; epoch++) {
...@@ -204,13 +203,16 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -204,13 +203,16 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc; XTensor paddingEnc;
XTensor paddingDec; XTensor paddingDec;
if (curIdx == 0 || curIdx == buf->Size()) { TensorList inputs;
curIdx = 0; TensorList golds;
batchLoader.LoadBatchToBuf(buf);
} inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
sBatchSize, wBatchSize, devID, wc, sc);
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
...@@ -303,9 +305,6 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -303,9 +305,6 @@ void Trainer::Train(const char* fn, const char* validFN,
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch); MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch); epoch = MIN(epoch, nepoch);
...@@ -341,16 +340,15 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -341,16 +340,15 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
float loss = 0; float loss = 0;
/* data files */ /* data files */
batchLoader.Init(fn, 0, false); batchLoader.Init(fn, wBatchSize, 0, false);
int curIdx = 0; int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of valiadation set */ /* set the buffer size to the size of valiadation set */
batchLoader.bufferSize = batchLoader.totalSampleNum; batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf); batchLoader.LoadBatchToBuf();
while (curIdx < buf->count) while (curIdx < batchLoader.buf->count)
{ {
/* batch of input sequences */ /* batch of input sequences */
XTensor batchEnc; XTensor batchEnc;
...@@ -370,8 +368,16 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -370,8 +368,16 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot; XTensor labelOnehot;
XTensor lossTensor; XTensor lossTensor;
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, TensorList inputs;
sBatchSize, 0, model->devID, wc, sc); TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
...@@ -404,10 +410,6 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -404,10 +410,6 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
model->decoder->history->ClearHistory(/*reset=*/false); model->decoder->history->ClearHistory(/*reset=*/false);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
ENABLE_GRAD; ENABLE_GRAD;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论