Commit 7b6840d4 by huchi

Refactor the dataloader in transformer

parent 3852f15a
@@ -31,7 +31,7 @@ using namespace nmt;
namespace nts {
/* get the maximum source sentence length in a range */
int TrainDataSet::MaxSrcLen(XList* buf, int begin, int end) {
int TrainDataSet::MaxSrcLen(int begin, int end) {
CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
int maxLen = 0;
for (int i = begin; i < end; i++) {
@@ -42,7 +42,7 @@ int TrainDataSet::MaxSrcLen(XList* buf, int begin, int end) {
}
/* get the maximum target sentence length in a range */
int TrainDataSet::MaxTgtLen(XList* buf, int begin, int end) {
int TrainDataSet::MaxTgtLen(int begin, int end) {
CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
int maxLen = 0;
for (int i = begin; i < end; i++) {
@@ -53,28 +53,28 @@ int TrainDataSet::MaxTgtLen(XList* buf, int begin, int end) {
}
/* sort the buffer by source sentence length (in ascending order) */
void TrainDataSet::SortBySrcLength(XList* buf) {
void TrainDataSet::SortBySrcLength() {
stable_sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->srcSent->Size() <
((TrainExample*)(b))->srcSent->Size();
});
[](void* a, void* b) {
return ((TrainExample*)(a))->srcSent->Size() <
((TrainExample*)(b))->srcSent->Size();
});
}
/* sort the buffer by target sentence length (in ascending order) */
void TrainDataSet::SortByTgtLength(XList* buf) {
void TrainDataSet::SortByTgtLength() {
stable_sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->tgtSent->Size() <
((TrainExample*)(b))->tgtSent->Size();
});
[](void* a, void* b) {
return ((TrainExample*)(a))->tgtSent->Size() <
((TrainExample*)(b))->tgtSent->Size();
});
}
/* sort buckets by key (in ascending order) */
void TrainDataSet::SortBuckets(XList* buf) {
void TrainDataSet::SortBuckets() {
sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->bucketKey <
((TrainExample*)(b))->bucketKey;
return ((TrainExample*)(a))->bucketKey <
((TrainExample*)(b))->bucketKey;
});
}
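Since XList stores its elements as untyped void* items, each comparator casts back to TrainExample* before comparing, and the `<` comparisons make all of these sorts ascending. A minimal standalone sketch of this cast-inside-the-comparator pattern (Example, SortByKey, and the std::vector container are illustrative assumptions, not names from this codebase):

    #include <algorithm>
    #include <vector>

    /* simplified stand-in for TrainExample */
    struct Example { int bucketKey; };

    /* ascending stable sort over untyped pointers, mirroring the sorts above */
    void SortByKey(std::vector<void*>& items) {
        std::stable_sort(items.begin(), items.end(),
            [](void* a, void* b) {
                return ((Example*)a)->bucketKey < ((Example*)b)->bucketKey;
            });
    }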
@@ -87,11 +87,13 @@ source sentence length (4 bytes)
target sentence length (4 bytes)
source tokens (4 bytes per token)
target tokens (4 bytes per token)
>> buf - the buffer (list) of samples
*/
bool TrainDataSet::LoadBatchToBuf(XList* buf)
bool TrainDataSet::LoadBatchToBuf()
{
ClearSamples(buf);
/* reset the buffer and index */
curIdx = 0;
ClearSamples();
int sampleNum = 0;
@@ -130,69 +132,50 @@ bool TrainDataSet::LoadBatchToBuf(XList* buf)
}
/* group samples in the buffer into buckets */
SortByTgtLength(buf);
SortByTgtLength();
SortBySrcLength(buf);
SortBySrcLength();
if (isTraining)
BuildBucket(buf);
BuildBucket();
return true;
}
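For reference, a hedged sketch of reading a single sample in the binary layout documented above, assuming each field is a 4-byte int as the fread calls into int-typed variables elsewhere in this file suggest (RawSample and ReadOneSample are illustrative names, not part of the codebase):

    #include <cstdio>
    #include <vector>

    struct RawSample {
        std::vector<int> src;  /* source token ids */
        std::vector<int> tgt;  /* target token ids */
    };

    /* read one length-prefixed sentence pair; returns false on EOF or error */
    bool ReadOneSample(std::FILE* f, RawSample& s) {
        int srcLen = 0, tgtLen = 0;
        if (std::fread(&srcLen, sizeof(int), 1, f) != 1) return false;
        if (std::fread(&tgtLen, sizeof(int), 1, f) != 1) return false;
        s.src.resize(srcLen);
        s.tgt.resize(tgtLen);
        if (std::fread(s.src.data(), sizeof(int), srcLen, f) != (size_t)srcLen) return false;
        if (std::fread(s.tgt.data(), sizeof(int), tgtLen, f) != (size_t)tgtLen) return false;
        return true;
    }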
/*
load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the index of the buffer
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
>> paddingDec - a tensor to store the batch of decoder paddings
>> label - a tensor to store the label of input
>> minSentBatch - the minimum number of sentences in a batch
>> batchSize - the maximum number of words in a batch
>> devID - the device id, -1 for the CPU
>> wc - number of target words in a batch
>> sc - number of samples in a batch
>> inputs - the list to store input tensors
>> golds - the list to store gold tensors
*/
bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
int minSentBatch, int batchSize, int devID, int& wc, int& sc)
bool TrainDataSet::GetBatchSimple(XList* inputs, XList* golds)
{
int srcTokenNum = 0;
int tgtTokenNum = 0;
int realBatchSize = 0;
/* dynamic batching for sentences */
int bucketKey = ((TrainExample*)(buf->Get(curIdx)))->bucketKey;
while ((realBatchSize < (int(buf->Size()) - curIdx)) &&
(((TrainExample*)(buf->Get(curIdx + realBatchSize)))->bucketKey == bucketKey)) {
realBatchSize++;
if (curIdx == 0 || curIdx == buf->Size()) {
LoadBatchToBuf();
}
realBatchSize = MIN(realBatchSize, (int(buf->Size()) - curIdx));
CheckNTErrors(realBatchSize > 0, "Invalid batch size");
wc = 0;
GetSentNum();
/* get the maximum target sentence length in a mini-batch */
int maxSrcLen = MaxSrcLen(buf, curIdx, curIdx + realBatchSize);
int maxTgtLen = MaxTgtLen(buf, curIdx, curIdx + realBatchSize);
int maxSrcLen = MaxSrcLen(curIdx, curIdx + sc);
int maxTgtLen = MaxTgtLen(curIdx, curIdx + sc);
CheckNTErrors(maxSrcLen > 0, "Invalid source length for batching");
CheckNTErrors(maxTgtLen > 0, "Invalid target length for batching");
int* batchEncValues = new int[realBatchSize * maxSrcLen];
float* paddingEncValues = new float[realBatchSize * maxSrcLen];
int* batchEncValues = new int[sc * maxSrcLen];
float* paddingEncValues = new float[sc * maxSrcLen];
int* labelVaues = new int[realBatchSize * maxTgtLen];
int* batchDecValues = new int[realBatchSize * maxTgtLen];
float* paddingDecValues = new float[realBatchSize * maxTgtLen];
int* labelVaues = new int[sc * maxTgtLen];
int* batchDecValues = new int[sc * maxTgtLen];
float* paddingDecValues = new float[sc * maxTgtLen];
for (int i = 0; i < realBatchSize * maxSrcLen; i++) {
for (int i = 0; i < sc * maxSrcLen; i++) {
batchEncValues[i] = 1;
paddingEncValues[i] = 1.0F;
}
for (int i = 0; i < realBatchSize * maxTgtLen; i++) {
for (int i = 0; i < sc * maxTgtLen; i++) {
batchDecValues[i] = 1;
labelVaues[i] = 1;
paddingDecValues[i] = 1.0F;
@@ -206,11 +189,10 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
batchDec: begin with SOS (right padding)
label: end with EOS (right padding)
*/
for (int i = 0; i < realBatchSize; ++i) {
for (int i = 0; i < sc; ++i) {
TrainExample* sample = (TrainExample*)(buf->Get(curIdx + i));
srcTokenNum += int(sample->srcSent->Size());
tgtTokenNum += int(sample->tgtSent->Size());
wc += int(sample->tgtSent->Size());
curSrc = maxSrcLen * i;
for (int j = 0; j < sample->srcSent->Size(); j++) {
@@ -230,13 +212,25 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
paddingDecValues[curTgt++] = 0;
}
InitTensor2D(batchEnc, realBatchSize, maxSrcLen, X_INT, devID);
InitTensor2D(paddingEnc, realBatchSize, maxSrcLen, X_FLOAT, devID);
InitTensor2D(batchDec, realBatchSize, maxTgtLen, X_INT, devID);
InitTensor2D(paddingDec, realBatchSize, maxTgtLen, X_FLOAT, devID);
InitTensor2D(label, realBatchSize, maxTgtLen, X_INT, devID);
XTensor * batchEnc = ((TensorList*)(inputs))->Get(0);
XTensor * paddingEnc = ((TensorList*)(inputs))->Get(1);
XTensor * batchDec = ((TensorList*)(golds))->Get(0);
XTensor * paddingDec = ((TensorList*)(golds))->Get(1);
XTensor * label = ((TensorList*)(golds))->Get(2);
curIdx += realBatchSize;
InitTensor2D(batchEnc, sc, maxSrcLen, X_INT);
InitTensor2D(paddingEnc, sc, maxSrcLen, X_FLOAT);
InitTensor2D(batchDec, sc, maxTgtLen, X_INT);
InitTensor2D(paddingDec, sc, maxTgtLen, X_FLOAT);
InitTensor2D(label, sc, maxTgtLen, X_INT);
inputs->Add(batchEnc);
inputs->Add(paddingEnc);
golds->Add(batchDec);
golds->Add(paddingDec);
golds->Add(label);
curIdx += sc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
paddingEnc->SetData(paddingEncValues, paddingEnc->unitNum);
@@ -250,8 +244,6 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
delete[] paddingDecValues;
delete[] labelVaues;
wc = tgtTokenNum;
sc = realBatchSize;
return true;
}
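For intuition, a worked example of the encoder-side buffers this builds (the token ids are made up; per the initialization above, 1 doubles as the padding id, the masks start at 1.0, and the elided filling code is assumed to zero the trailing mask entries the way the visible decoder code does). With sc = 2 and source sentences [21 37 95] and [40 52], maxSrcLen = 3 and the row-major buffers come out as:

    batchEncValues   = 21 37 95 | 40 52  1
    paddingEncValues =  1  1  1 |  1  1  0

The `|` only marks the row boundary; each row is one sentence of the mini-batch.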
@@ -259,7 +251,7 @@ bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
clear the buffer
>> buf - the buffer (list) of samples
*/
void TrainDataSet::ClearSamples(XList* buf)
void TrainDataSet::ClearSamples()
{
for (int i = 0; i < buf->count; i++) {
TrainExample* sample = (TrainExample*)buf->Get(i);
@@ -274,7 +266,7 @@ the constructor of DataSet
>> myBatchSize - the batch size (in number of words)
>> bucketSize - size of the bucket to keep similar length sentence pairs
>> training - indicates whether it is used for training
*/
void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize, bool training)
{
fp = fopen(dataFile, "rb");
CheckNTErrors(fp, "can not open the training file");
@@ -289,12 +281,15 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
CheckNTErrors(totalSampleNum > 0, "Invalid sentence pairs number");
batchSize = myBatchSize;
bucketSize = myBucketSize;
isTraining = training;
buf = new XList;
}
/* group data with similar length into buckets */
void TrainDataSet::BuildBucket(XList * buf)
/* group samples with similar length into buckets */
void TrainDataSet::BuildBucket()
{
int idx = 0;
@@ -305,8 +300,8 @@ void TrainDataSet::BuildBucket(XList * buf)
int sentNum = 1;
/* get the maximum source sentence length in a bucket */
int maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
int maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
int maxSrcLen = MaxSrcLen(idx, idx + sentNum);
int maxTgtLen = MaxTgtLen(idx, idx + sentNum);
int maxLen = MAX(maxSrcLen, maxTgtLen);
/* the maximum sentence number in a bucket */
@@ -316,8 +311,8 @@ void TrainDataSet::BuildBucket(XList * buf)
&& (sentNum < MAX_SENT_NUM)
&& (sentNum * maxLen <= bucketSize)) {
sentNum++;
maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
maxSrcLen = MaxSrcLen(idx, idx + sentNum);
maxTgtLen = MaxTgtLen(idx, idx + sentNum);
maxLen = MAX(maxSrcLen, maxTgtLen);
}
@@ -339,12 +334,42 @@ void TrainDataSet::BuildBucket(XList * buf)
}
/* sort buckets by their keys */
SortBuckets(buf);
SortBuckets();
}
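The growth condition sentNum * maxLen <= bucketSize makes bucketSize a budget in padded words rather than in sentences. A worked example under that reading: with bucketSize = 1024 and the longest sentence seen so far at maxLen = 32, up to 1024 / 32 = 32 sentences fit in the bucket; if admitting one more raises maxLen to 40, the padded cost 33 * 40 = 1320 exceeds the budget and the bucket is closed around the previous size (the code after the loop, elided here, handles the exact boundary).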
/* get the number of sentences in a mini-batch */
inline int TrainDataSet::GetSentNum()
{
sc = 0;
/* dynamic batching for sentences */
int bucketKey = ((TrainExample*)(buf->Get(curIdx)))->bucketKey;
while ((sc < (int(buf->Size()) - curIdx)) &&
(((TrainExample*)(buf->Get(curIdx + sc)))->bucketKey == bucketKey)) {
sc++;
}
sc = MIN(sc, (int(buf->Size()) - curIdx));
CheckNTErrors(sc > 0, "Invalid batch size");
return sc;
}
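For example, if the sorted buffer holds bucket keys [3, 3, 3, 7, 7, 7] and curIdx is 0, the loop stops at the first key of 7, giving sc = 3, so the whole mini-batch comes from a single bucket; GetBatchSimple then advances curIdx by sc, and the next call picks up the run of 7s.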
/* start the process */
bool TrainDataSet::Start()
{
return false;
}
/* end the process */
bool TrainDataSet::End()
{
return true;
}
/* destructor */
TrainDataSet::~TrainDataSet()
{
ClearSamples();
delete buf;
fclose(fp);
}
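Taken together, the refactor moves all batching state (buf, curIdx, wc, sc) into the loader, so a caller only supplies the tensors through the two lists. A minimal usage sketch mirroring the Trainer code further below (the file name and the batch/bucket sizes are made-up values):

    TrainDataSet loader;
    loader.Init("train.data", 4096, 10240, /*training=*/true);

    XTensor batchEnc, paddingEnc, batchDec, paddingDec, label;
    TensorList inputs, golds;
    inputs.Add(&batchEnc);
    inputs.Add(&paddingEnc);
    golds.Add(&batchDec);
    golds.Add(&paddingDec);
    golds.Add(&label);

    /* fills the tensors with one mini-batch, refilling the buffer when it is exhausted */
    loader.GetBatchSimple((XList*)&inputs, (XList*)&golds);
    /* loader.wc and loader.sc now hold the word and sentence counts of this batch */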
@@ -83,13 +83,13 @@ public:
FILE* fp;
/* number of training samples */
size_t totalSampleNum;
int totalSampleNum;
/* buffer size */
size_t bufferSize;
int bufferSize;
/* size of the bucket used for grouping sentences */
size_t bucketSize;
int bucketSize;
/* indicates whether it is used for training */
bool isTraining;
@@ -112,44 +112,63 @@ public:
/* the maximum length for a target sentence */
int maxTgtLen;
/* batch size (number of words) */
int batchSize;
/* word-counter */
int wc;
/* sentence-counter */
int sc;
/* current index of the buffer */
int curIdx;
/* the buffer (a list) of samples */
XList * buf;
public:
/* get the maximum source sentence length in a range */
static
int MaxSrcLen(XList* buf, int begin, int end);
int MaxSrcLen(int begin, int end);
/* get the maximum target sentence length in a range */
static
int MaxTgtLen(XList* buf, int begin, int end);
int MaxTgtLen(int begin, int end);
/* sort the input by source sentence length (in ascending order) */
void SortBySrcLength(XList* buf);
void SortBySrcLength();
/* sort the input by target sentence length (in ascending order) */
void SortByTgtLength(XList* buf);
void SortByTgtLength();
/* sort buckets by key (in ascending order) */
void SortBuckets(XList* buf);
void SortBuckets();
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
int &wc, int &sc);
bool LoadBatchToBuf();
/* release the samples in a buffer */
static
void ClearSamples(XList * buf);
void ClearSamples();
/* initialization function */
void Init(const char* dataFile, int bucketSize, bool training);
void Init(const char * dataFile, int myBatchSize, int myBucketSize, bool training);
/* group samples with similar length into buckets */
void BuildBucket(XList * buf);
void BuildBucket();
/* get the number of sentences in a mini-batch */
int GetSentNum();
public:
/* start the process */
bool Start();
/* end the process */
bool End();
/* load the samples into tensors from the buffer */
bool GetBatchSimple(XList* inputs, XList* golds);
/* destructor */
~TrainDataSet();
@@ -179,9 +179,8 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true);
batchLoader.Init(fn, wBatchSize, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) {
@@ -204,13 +203,16 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc;
XTensor paddingDec;
if (curIdx == 0 || curIdx == buf->Size()) {
curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
TensorList inputs;
TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
@@ -303,9 +305,6 @@ void Trainer::Train(const char* fn, const char* validFN,
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
@@ -341,16 +340,15 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
float loss = 0;
/* data files */
batchLoader.Init(fn, 0, false);
batchLoader.Init(fn, wBatchSize, 0, false);
int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of the validation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
batchLoader.LoadBatchToBuf();
while (curIdx < buf->count)
while (batchLoader.curIdx < batchLoader.buf->count)
{
/* batch of input sequences */
XTensor batchEnc;
@@ -370,8 +368,16 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot;
XTensor lossTensor;
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID, wc, sc);
TensorList inputs;
TensorList golds;
inputs.Add(&batchEnc);
inputs.Add(&paddingEnc);
golds.Add(&batchDec);
golds.Add(&paddingDec);
golds.Add(&label);
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
@@ -404,10 +410,6 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
model->decoder->history->ClearHistory(/*reset=*/false);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
ENABLE_GRAD;