Commit 7b6840d4 by huchi

Refactor the dataloader in transformer

parent 3852f15a
@@ -83,13 +83,13 @@ public:
FILE* fp;
/* number of training samples */
size_t totalSampleNum;
int totalSampleNum;
/* buffer size */
size_t bufferSize;
int bufferSize;
/* size of the bucket used for grouping sentences */
size_t bucketSize;
int bucketSize;
/* indicates whether it is used for training */
bool isTraining;
@@ -112,44 +112,63 @@ public:
/* the maximum length for a target sentence */
int maxTgtLen;
/* batch size (number of words) */
int batchSize;
/* word-counter */
int wc;
/* sentence-counter */
int sc;
/* current index of the buffer */
int curIdx;
/* the buffer (a list) of samples */
XList * buf;
public:
/* get the maximum source sentence length in a range */
static
int MaxSrcLen(XList* buf, int begin, int end);
int MaxSrcLen(int begin, int end);
/* get the maximum target sentence length in a range */
static
int MaxTgtLen(XList* buf, int begin, int end);
int MaxTgtLen(int begin, int end);
/* sort the input by source sentence length (in descending order) */
void SortBySrcLength(XList* buf);
void SortBySrcLength();
/* sort the input by target sentence length (in descending order) */
void SortByTgtLength(XList* buf);
void SortByTgtLength();
/* sort buckets by key (in descending order) */
void SortBuckets(XList* buf);
void SortBuckets();
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
int &wc, int &sc);
bool LoadBatchToBuf();
/* release the samples in a buffer */
static
void ClearSamples(XList * buf);
void ClearSamples();
/* initialization function */
void Init(const char* dataFile, int bucketSize, bool training);
void Init(const char * dataFile, int myBatchSize, int myBucketSize, bool training);
/* group data into buckets with similar length */
void BuildBucket(XList * buf);
void BuildBucket();
/* get the number of sentences in a mini-batch */
int GetSentNum();
public:
/* start the process */
bool Start();
/* end the process */
bool End();
/* load the samples into tensors from the buffer */
bool GetBatchSimple(XList* inputs, XList* golds);
/* de-constructor */
~TrainDataSet();
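Taken together, the header changes above move the batching state into the loader itself: the sample buffer buf, the cursor curIdx, the word/sentence counters wc/sc, and the word-level batchSize become members, the former static helpers become instance methods, and the wide static LoadBatch(...) entry point is replaced by the member GetBatchSimple(inputs, golds). Below is a minimal sketch of how the refactored loader is driven, mirroring the Trainer changes that follow; the surrounding variable names are assumptions, and only the TrainDataSet calls themselves come from this commit.

/* sketch only: assumes the NiuTensor tensor/list types used elsewhere in the
   diff and a word-level batch size (wBatchSize) chosen by the caller */
TrainDataSet batchLoader;
batchLoader.Init(trainFile, wBatchSize, bucketSize, /*training=*/true);

/* fill the loader's internal buffer; sentences are grouped into buckets of
   similar length when bucketSize > 0 */
batchLoader.LoadBatchToBuf();

/* tensors the loader will fill for one mini-batch */
XTensor batchEnc, paddingEnc, batchDec, paddingDec, label;

TensorList inputs;
TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);

/* consume one mini-batch from the buffer; curIdx, wc and sc are presumably
   advanced inside the loader instead of being passed back by reference */
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));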
@@ -179,9 +179,8 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true);
batchLoader.Init(fn, wBatchSize, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) {
@@ -204,13 +203,16 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc;
XTensor paddingDec;
if (curIdx == 0 || curIdx == buf->Size()) {
curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
TensorList inputs;
TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
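The packing order used here is the implicit contract between Trainer and the loader: inputs carries { batchEnc, paddingEnc } and golds carries { batchDec, paddingDec, label }. A sketch of how the receiving side presumably unpacks the two lists; the accessor name (GetItem) and the body are assumptions, not code from this commit:

/* sketch only: assumed unpacking inside the new GetBatchSimple */
bool TrainDataSet::GetBatchSimple(XList* inputs, XList* golds)
{
    XTensor* batchEnc   = (XTensor*)inputs->GetItem(0);
    XTensor* paddingEnc = (XTensor*)inputs->GetItem(1);
    XTensor* batchDec   = (XTensor*)golds->GetItem(0);
    XTensor* paddingDec = (XTensor*)golds->GetItem(1);
    XTensor* label      = (XTensor*)golds->GetItem(2);

    /* build the tensors for the next GetSentNum() sentences starting at curIdx,
       then advance curIdx and update the wc/sc counters (construction elided) */
    return true;
}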
@@ -303,9 +305,6 @@ void Trainer::Train(const char* fn, const char* validFN,
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
@@ -341,16 +340,15 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
float loss = 0;
/* data files */
batchLoader.Init(fn, 0, false);
batchLoader.Init(fn, wBatchSize, 0, false);
int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of the validation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
batchLoader.LoadBatchToBuf();
while (curIdx < buf->count)
while (curIdx < batchLoader.buf->count)
{
/* batch of input sequences */
XTensor batchEnc;
@@ -370,8 +368,16 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot;
XTensor lossTensor;
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID, wc, sc);
TensorList inputs;
TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
@@ -404,10 +410,6 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
model->decoder->history->ClearHistory(/*reset=*/false);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
ENABLE_GRAD;
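For validation, the loader is initialized with a bucket size of 0 and training = false, and its buffer is then widened to the whole data set, so a single LoadBatchToBuf() call pulls in every sentence before the loop starts. In short (names as used above):

/* validation-time setup, as done in Trainer::Validate above */
batchLoader.Init(fn, wBatchSize, /*bucketSize=*/0, /*training=*/false);
batchLoader.bufferSize = batchLoader.totalSampleNum;   /* hold the entire validation set */
batchLoader.LoadBatchToBuf();                          /* one fill, then iterate with GetBatchSimple */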