Commit 98a9130d by hello

Refactor class `TrainDataSet`

parent 4bbd6a27
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -16,13 +16,10 @@ ...@@ -16,13 +16,10 @@
/* /*
* $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09 * $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
* TODO: refactor the data loader class and references * $Updated by: CAO Hang and Wu Siming 2020-12-13
*/ */
#include <string>
#include <vector>
#include <cstdlib> #include <cstdlib>
#include <fstream>
#include <algorithm> #include <algorithm>
#include "TrainDataSet.h" #include "TrainDataSet.h"
...@@ -33,37 +30,56 @@ using namespace nmt; ...@@ -33,37 +30,56 @@ using namespace nmt;
namespace nts { namespace nts {
/* sort the dataset by length (in descending order) */ /* get the maximum source sentence length in a range */
void TrainDataSet::SortByLength() { int TrainDataSet::MaxSrcLen(XList* buf, int begin, int end) {
sort(buffer.items, buffer.items + buffer.count, CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
[](TrainExample* a, TrainExample* b) { int maxLen = 0;
return (a->srcSent.Size() + a->tgtSent.Size()) for (int i = begin; i < end; i++) {
> (b->srcSent.Size() + b->tgtSent.Size()); IntList* srcSent = ((TrainExample*)buf->Get(i))->srcSent;
}); maxLen = MAX(int(srcSent->Size()), maxLen);
}
return maxLen;
} }
/* sort buckets by key (in descending order) */ /* get the maximum target sentence length in a range */
void TrainDataSet::SortBucket() { int TrainDataSet::MaxTgtLen(XList* buf, int begin, int end) {
sort(buffer.items, buffer.items + buffer.count, CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
[](TrainExample* a, TrainExample* b) { int maxLen = 0;
return a->bucketKey > b->bucketKey; for (int i = begin; i < end; i++) {
}); IntList* tgtSent = ((TrainExample*)buf->Get(i))->tgtSent;
maxLen = MAX(int(tgtSent->Size()), maxLen);
}
return maxLen;
} }
/* /* sort the buffer by source sentence length (in ascending order) */
sort the output by key in a range (in descending order) void TrainDataSet::SortBySrcLength(XList* buf) {
>> begin - the first index of the range stable_sort(buf->items, buf->items + buf->count,
>> end - the last index of the range [](void* a, void* b) {
*/ return ((TrainExample*)(a))->srcSent->Size() <
void TrainDataSet::SortInBucket(int begin, int end) { ((TrainExample*)(b))->srcSent->Size();
sort(buffer.items + begin, buffer.items + end, });
[](TrainExample* a, TrainExample* b) { }
return (a->key > b->key); /* sort the buffer by target sentence length (in ascending order) */
}); void TrainDataSet::SortByTgtLength(XList* buf) {
stable_sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->tgtSent->Size() <
((TrainExample*)(b))->tgtSent->Size();
});
}
/* sort buckets by key (in ascending order) */
void TrainDataSet::SortBuckets(XList* buf) {
sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->bucketKey <
((TrainExample*)(b))->bucketKey;
});
} }
/* /*
load all data from a file to the buffer load samples from a file into the buffer
training data format (binary): training data format (binary):
first 8 bytes: number of sentence pairs first 8 bytes: number of sentence pairs
subsequent segments: subsequent segments:
...@@ -71,52 +87,63 @@ source sentence length (4 bytes) ...@@ -71,52 +87,63 @@ source sentence length (4 bytes)
target sentence length (4 bytes) target sentence length (4 bytes)
source tokens (4 bytes per token) source tokens (4 bytes per token)
target tokens (4 bytes per token) target tokens (4 bytes per token)
>> buf - the buffer (list) of samples
*/ */
void TrainDataSet::LoadDataToBuffer() bool TrainDataSet::LoadBatchToBuf(XList* buf)
{ {
buffer.Clear(); ClearSamples(buf);
curIdx = 0;
int id = 0;
uint64_t sentNum = 0;
int srcVocabSize = 0; int sampleNum = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
fread(&sentNum, sizeof(uint64_t), 1, fp); while ((sampleNum < bufferSize)) {
CheckNTErrors(sentNum > 0, "Invalid sentence pairs number");
while (id < sentNum) {
int srcLen = 0; int srcLen = 0;
int tgtLen = 0; int tgtLen = 0;
fread(&srcLen, sizeof(int), 1, fp);
size_t n = fread(&srcLen, sizeof(int), 1, fp);
if (n == 0)
break;
fread(&tgtLen, sizeof(int), 1, fp); fread(&tgtLen, sizeof(int), 1, fp);
CheckNTErrors(srcLen > 0, "Invalid source sentence length"); CheckNTErrors(srcLen > 0, "Invalid source sentence length");
CheckNTErrors(tgtLen > 0, "Invalid target sentence length"); CheckNTErrors(tgtLen > 0, "Invalid target sentence length");
IntList srcSent; IntList *srcSent = new IntList(srcLen);
IntList tgtSent; IntList *tgtSent = new IntList(tgtLen);
srcSent.ReadFromFile(fp, srcLen); srcSent->ReadFromFile(fp, srcLen);
tgtSent.ReadFromFile(fp, tgtLen); tgtSent->ReadFromFile(fp, tgtLen);
TrainExample* example = new TrainExample(sampleNum++, 0, srcSent, tgtSent);
buf->Add(example);
}
/* reset the file pointer to the beginning */
if (feof(fp) && isTraining) {
TrainExample* example = new TrainExample; rewind(fp);
example->id = id++;
example->key = id;
example->srcSent = srcSent;
example->tgtSent = tgtSent;
buffer.Add(example); int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
} }
fclose(fp); /* group samples in the buffer into buckets */
SortByTgtLength(buf);
SortBySrcLength(buf);
if (isTraining)
BuildBucket(buf);
XPRINT1(0, stderr, "[INFO] loaded %d sentences\n", id); return true;
} }
/* /*
load a mini-batch to the device (for training) load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the index of the buffer
>> batchEnc - a tensor to store the batch of encoder input >> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings >> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input >> batchDec - a tensor to store the batch of decoder input
...@@ -125,57 +152,34 @@ load a mini-batch to the device (for training) ...@@ -125,57 +152,34 @@ load a mini-batch to the device (for training)
>> minSentBatch - the minimum number of sentence batch >> minSentBatch - the minimum number of sentence batch
>> batchSize - the maximum number of words in a batch >> batchSize - the maximum number of words in a batch
>> devID - the device id, -1 for the CPU >> devID - the device id, -1 for the CPU
<< return - number of target tokens and sentences >> wc - number of target words in a batch
>> sc - number of samples in a batch
*/ */
UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, XTensor* batchEnc, XTensor* paddingEnc,
size_t minSentBatch, size_t batchSize, int devID) XTensor* batchDec, XTensor* paddingDec, XTensor* label,
int minSentBatch, int batchSize, int devID, int& wc, int& sc)
{ {
UInt64List info; int srcTokenNum = 0;
size_t srcTokenNum = 0; int tgtTokenNum = 0;
size_t tgtTokenNum = 0; int realBatchSize = 0;
size_t realBatchSize = 1;
/* dynamic batching for sentences */
if (!isTraining) int bucketKey = ((TrainExample*)(buf->Get(curIdx)))->bucketKey;
realBatchSize = minSentBatch; while ((realBatchSize < (int(buf->Size()) - curIdx)) &&
(((TrainExample*)(buf->Get(curIdx + realBatchSize)))->bucketKey == bucketKey)) {
/* get the maximum source sentence length in a mini-batch */ realBatchSize++;
size_t maxSrcLen = buffer[(int)curIdx]->srcSent.Size();
/* max batch size */
const int MAX_BATCH_SIZE = 512;
/* dynamic batching for sentences, enabled when the dataset is used for training */
if (isTraining) {
while ((realBatchSize < (buffer.Size() - curIdx))
&& (realBatchSize * maxSrcLen < batchSize)
&& (realBatchSize < MAX_BATCH_SIZE)
&& (realBatchSize * buffer[(int)(curIdx + realBatchSize)]->srcSent.Size() < batchSize)) {
if (maxSrcLen < buffer[(int)(curIdx + realBatchSize)]->srcSent.Size())
maxSrcLen = buffer[(int)(curIdx + realBatchSize)]->srcSent.Size();
realBatchSize++;
}
}
/* real batch size */
if ((buffer.Size() - curIdx) < realBatchSize) {
realBatchSize = buffer.Size() - curIdx;
} }
realBatchSize = MIN(realBatchSize, (int(buf->Size()) - curIdx));
CheckNTErrors(realBatchSize > 0, "Invalid batch size"); CheckNTErrors(realBatchSize > 0, "Invalid batch size");
/* get the maximum target sentence length in a mini-batch */ /* get the maximum source and target sentence lengths in a mini-batch */
size_t maxTgtLen = buffer[(int)curIdx]->tgtSent.Size(); int maxSrcLen = MaxSrcLen(buf, curIdx, curIdx + realBatchSize);
for (size_t i = 0; i < realBatchSize; i++) { int maxTgtLen = MaxTgtLen(buf, curIdx, curIdx + realBatchSize);
if (maxTgtLen < buffer[(int)(curIdx + i)]->tgtSent.Size())
maxTgtLen = buffer[(int)(curIdx + i)]->tgtSent.Size();
}
for (size_t i = 0; i < realBatchSize; i++) {
if (maxSrcLen < buffer[(int)(curIdx + i)]->srcSent.Size())
maxSrcLen = buffer[(int)(curIdx + i)]->srcSent.Size();
}
CheckNTErrors(maxSrcLen != 0, "Invalid source length for batching"); CheckNTErrors(maxSrcLen > 0, "Invalid source length for batching");
CheckNTErrors(maxTgtLen > 0, "Invalid target length for batching");
int* batchEncValues = new int[realBatchSize * maxSrcLen]; int* batchEncValues = new int[realBatchSize * maxSrcLen];
float* paddingEncValues = new float[realBatchSize * maxSrcLen]; float* paddingEncValues = new float[realBatchSize * maxSrcLen];
...@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, ...@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
float* paddingDecValues = new float[realBatchSize * maxTgtLen]; float* paddingDecValues = new float[realBatchSize * maxTgtLen];
for (int i = 0; i < realBatchSize * maxSrcLen; i++) { for (int i = 0; i < realBatchSize * maxSrcLen; i++) {
batchEncValues[i] = PAD; batchEncValues[i] = 1;
paddingEncValues[i] = 1; paddingEncValues[i] = 1.0F;
} }
for (int i = 0; i < realBatchSize * maxTgtLen; i++) { for (int i = 0; i < realBatchSize * maxTgtLen; i++) {
batchDecValues[i] = PAD; batchDecValues[i] = 1;
labelVaues[i] = PAD; labelVaues[i] = 1;
paddingDecValues[i] = 1.0F; paddingDecValues[i] = 1.0F;
} }
size_t curSrc = 0; int curSrc = 0;
size_t curTgt = 0; int curTgt = 0;
/* /*
batchEnc: end with EOS (left padding) batchEnc: end with EOS (left padding)
...@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, ...@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
*/ */
for (int i = 0; i < realBatchSize; ++i) { for (int i = 0; i < realBatchSize; ++i) {
srcTokenNum += buffer[(int)(curIdx + i)]->srcSent.Size(); TrainExample* sample = (TrainExample*)(buf->Get(curIdx + i));
tgtTokenNum += buffer[(int)(curIdx + i)]->tgtSent.Size(); srcTokenNum += int(sample->srcSent->Size());
tgtTokenNum += int(sample->tgtSent->Size());
curSrc = maxSrcLen * i; curSrc = maxSrcLen * i;
for (int j = 0; j < buffer[(int)(curIdx + i)]->srcSent.Size(); j++) { for (int j = 0; j < sample->srcSent->Size(); j++) {
batchEncValues[curSrc++] = buffer[(int)(curIdx + i)]->srcSent[j]; batchEncValues[curSrc++] = sample->srcSent->Get(j);
} }
curTgt = maxTgtLen * i; curTgt = maxTgtLen * i;
for (int j = 0; j < buffer[(int)(curIdx + i)]->tgtSent.Size(); j++) { for (int j = 0; j < sample->tgtSent->Size(); j++) {
if (j > 0) if (j > 0)
labelVaues[curTgt - 1] = buffer[(int)(curIdx + i)]->tgtSent[j]; labelVaues[curTgt - 1] = sample->tgtSent->Get(j);
batchDecValues[curTgt++] = buffer[(int)(curIdx + i)]->tgtSent[j]; batchDecValues[curTgt++] = sample->tgtSent->Get(j);
} }
labelVaues[curTgt - 1] = EOS; labelVaues[curTgt - 1] = 2;
while (curSrc < maxSrcLen * (i + 1)) while (curSrc < maxSrcLen * (i + 1))
paddingEncValues[curSrc++] = 0; paddingEncValues[curSrc++] = 0;
while (curTgt < maxTgtLen * (i + 1)) while (curTgt < maxTgtLen * (i + 1))
paddingDecValues[curTgt++] = 0; paddingDecValues[curTgt++] = 0;
} }
int rbs = (int)realBatchSize; InitTensor2D(batchEnc, realBatchSize, maxSrcLen, X_INT, devID);
int msl = (int)maxSrcLen; InitTensor2D(paddingEnc, realBatchSize, maxSrcLen, X_FLOAT, devID);
InitTensor2D(batchEnc, rbs, msl, X_INT, devID); InitTensor2D(batchDec, realBatchSize, maxTgtLen, X_INT, devID);
InitTensor2D(paddingEnc, rbs, msl, X_FLOAT, devID); InitTensor2D(paddingDec, realBatchSize, maxTgtLen, X_FLOAT, devID);
InitTensor2D(batchDec, rbs, msl, X_INT, devID); InitTensor2D(label, realBatchSize, maxTgtLen, X_INT, devID);
InitTensor2D(paddingDec, rbs, msl, X_FLOAT, devID);
InitTensor2D(label, rbs, msl, X_INT, devID);
curIdx += realBatchSize; curIdx += realBatchSize;
...@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, ...@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
delete[] paddingDecValues; delete[] paddingDecValues;
delete[] labelVaues; delete[] labelVaues;
info.Add(tgtTokenNum); wc = tgtTokenNum;
info.Add(realBatchSize); sc = realBatchSize;
return info; return true;
}
/*
clear the buffer
>> buf - the buffer (list) of samples
*/
void TrainDataSet::ClearSamples(XList* buf)
{
for (int i = 0; i < buf->count; i++) {
TrainExample* sample = (TrainExample*)buf->Get(i);
delete sample;
}
buf->Clear();
} }
/* /*
...@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training) ...@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
{ {
fp = fopen(dataFile, "rb"); fp = fopen(dataFile, "rb");
CheckNTErrors(fp, "can not open the training file"); CheckNTErrors(fp, "can not open the training file");
curIdx = 0;
bucketSize = myBucketSize;
isTraining = training;
LoadDataToBuffer();
SortByLength();
if (isTraining)
BuildBucket();
}
/* check if the buffer is empty */
bool TrainDataSet::IsEmpty() {
if (curIdx < buffer.Size())
return false;
return true;
}
/* reset the buffer */ int srcVocabSize = 0;
void TrainDataSet::ClearBuf() int tgtVocabSize = 0;
{ fread(&srcVocabSize, sizeof(int), 1, fp);
curIdx = 0; fread(&tgtVocabSize, sizeof(int), 1, fp);
CheckNTErrors(srcVocabSize > 0, "Invalid source vocabulary size");
CheckNTErrors(tgtVocabSize > 0, "Invalid target vocabulary size");
/* make different batches in different epochs */ fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
SortByLength(); CheckNTErrors(totalSampleNum > 0, "Invalid sentence pairs number");
if (isTraining) bucketSize = myBucketSize;
BuildBucket(); isTraining = training;
} }
/* group data into buckets with similar length */ /* group data with similar length into buckets */
void TrainDataSet::BuildBucket() void TrainDataSet::BuildBucket(XList * buf)
{ {
size_t idx = 0; int idx = 0;
/* build and shuffle buckets */ /* build buckets by the length of source and target sentences */
while (idx < buffer.Size()) { while (idx < int(buf->Size())) {
/* sentence number in a bucket */ /* sentence number in a bucket */
size_t sentNum = 1; int sentNum = 1;
/* get the maximum source sentence length in a bucket */ /* get the maximum source sentence length in a bucket */
size_t maxSrcLen = buffer[(int)idx]->srcSent.Size(); int maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
int maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
/* bucketing for sentences */ int maxLen = MAX(maxSrcLen, maxTgtLen);
while ((sentNum < (buffer.Size() - idx))
&& (sentNum * maxSrcLen < bucketSize) /* the maximum sentence number in a bucket */
&& (sentNum * buffer[(int)(curIdx + sentNum)]->srcSent.Size() < bucketSize)) { const int MAX_SENT_NUM = 5120;
if (maxSrcLen < buffer[(int)(idx + sentNum)]->srcSent.Size())
maxSrcLen = buffer[(int)(idx + sentNum)]->srcSent.Size(); while ((sentNum < (buf->count - idx))
&& (sentNum < MAX_SENT_NUM)
&& (sentNum * maxLen <= bucketSize)) {
sentNum++; sentNum++;
maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
maxLen = MAX(maxSrcLen, maxTgtLen);
} }
/* make sure the number is valid */ /* make sure the number is valid */
if ((buffer.Size() - idx) < sentNum) { if ((sentNum) * maxLen > bucketSize || sentNum >= MAX_SENT_NUM) {
sentNum = buffer.Size() - idx; sentNum--;
sentNum = max(8 * (sentNum / 8), sentNum % 8);
} }
if ((int(buf->Size()) - idx) < sentNum)
sentNum = int(buf->Size()) - idx;
/* assign the same key for items in a bucket */
int randomKey = rand(); int randomKey = rand();
for (int i = 0; i < sentNum; i++) {
/* shuffle items in a bucket */ ((TrainExample*)(buf->Get(idx + i)))->bucketKey = randomKey;
for (size_t i = 0; i < sentNum; i++) {
buffer[(int)(idx + i)]->bucketKey = randomKey;
} }
idx += sentNum; idx += sentNum;
} }
SortBucket();
/* sort buckets by their keys */
/* sort items in a bucket */ SortBuckets(buf);
idx = 0;
while (idx < buffer.Size()) {
size_t sentNum = 0;
int bucketKey = buffer[(int)(idx + sentNum)]->bucketKey;
while (sentNum < (buffer.Size() - idx)
&& buffer[(int)(idx + sentNum)]->bucketKey == bucketKey) {
buffer[(int)(idx + sentNum)]->key = (int)buffer[(int)(idx + sentNum)]->srcSent.Size();
sentNum++;
}
SortInBucket((int)idx, (int)(idx + sentNum));
idx += sentNum;
}
} }
/* de-constructor */ /* de-constructor */
TrainDataSet::~TrainDataSet() TrainDataSet::~TrainDataSet()
{ {
fclose(fp);
}
/* constructor */
/*
build a sentence pair from two token lists
>> myID - id of the sentence pair
>> myKey - the key used to shuffle buckets
>> s - the source sentence; only the pointer is stored, the list is not copied
>> t - the target sentence; only the pointer is stored, the list is not copied
*/
TrainExample::TrainExample(int myID, int myKey, IntList* s, IntList* t)
    : id(myID), srcSent(s), tgtSent(t), bucketKey(myKey)
{
}
/* release the buffer */ /* de-constructor */
for (int i = 0; i < buffer.Size(); i++) TrainExample::~TrainExample()
delete buffer[i]; {
delete srcSent;
delete tgtSent;
} }
} }
\ No newline at end of file
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h" #include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h" #include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std; using namespace std;
...@@ -39,39 +38,54 @@ namespace nts { ...@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */ /* a class of sentence pairs for training */
struct TrainExample { struct TrainExample {
public:
/* id of the sentence pair */ /* id of the sentence pair */
int id; int id;
/* source language sentence (tokenized) */ /* source language sentence (tokenized) */
IntList srcSent; IntList * srcSent;
/* target language sentence (tokenized) */ /* target language sentence (tokenized) */
IntList tgtSent; IntList * tgtSent;
/* the key used to shuffle items in a bucket */
int key;
/* the key used to shuffle buckets */ /* the key used to shuffle buckets */
int bucketKey; int bucketKey;
public:
/* constructor */
TrainExample(int myID, int myKey, IntList* s, IntList* t);
/* de-constructor */
~TrainExample();
};
struct ReservedIDs {
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
}; };
/* A `TrainDataSet` is associated with a file which contains training data. */ /* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet { struct TrainDataSet {
public:
/* the data buffer */
TrainBufferType buffer;
/* a list of empty line number */ public:
IntList emptyLines;
/* the pointer to file stream */ /* the pointer to file stream */
FILE* fp; FILE* fp;
/* current index in the buffer */ /* number of training samples */
size_t curIdx; size_t totalSampleNum;
/* size of used data in the buffer */ /* buffer size */
size_t bufferUsed; size_t bufferSize;
/* size of the bucket used for grouping sentences */ /* size of the bucket used for grouping sentences */
size_t bucketSize; size_t bucketSize;
...@@ -79,34 +93,51 @@ public: ...@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */ /* indicates whether it is used for training */
bool isTraining; bool isTraining;
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
/* the maximum length for a source sentence */
int maxSrcLen;
/* the maximum length for a target sentence */
int maxTgtLen;
public: public:
/* sort the input by length (in descending order) */ /* get the maximum source sentence length in a range */
void SortByLength(); static
int MaxSrcLen(XList* buf, int begin, int end);
/* sort buckets by key (in descending order) */ /* get the maximum target sentence length in a range */
void SortBucket(); static
int MaxTgtLen(XList* buf, int begin, int end);
/* sort the output by key (in descending order) */ /* sort the input by source sentence length (in ascending order) */
void SortInBucket(int begin, int end); void SortBySrcLength(XList* buf);
/* load data from a file to the buffer */ /* sort the input by target sentence length (in ascending order) */
void LoadDataToBuffer(); void SortByTgtLength(XList* buf);
/* generate a mini-batch */ /* sort buckets by key (in ascending order) */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, void SortBuckets(XList* buf);
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID);
/* load the samples into the buffer (a list) */ /* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf); bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */ /* load the samples into tensors from the buffer */
static static
bool LoadBatch(XList * buf, bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc, XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
size_t minSentBatch, size_t batchSize, int devID,
int &wc, int &sc); int &wc, int &sc);
/* release the samples in a buffer */ /* release the samples in a buffer */
...@@ -116,14 +147,8 @@ public: ...@@ -116,14 +147,8 @@ public:
/* initialization function */ /* initialization function */
void Init(const char* dataFile, int bucketSize, bool training); void Init(const char* dataFile, int bucketSize, bool training);
/* check if the buffer is empty */
bool IsEmpty();
/* reset the buffer */
void ClearBuf();
/* group data into buckets with similar length */ /* group data into buckets with similar length */
void BuildBucket(); void BuildBucket(XList * buf);
/* de-constructor */ /* de-constructor */
~TrainDataSet(); ~TrainDataSet();
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -39,6 +39,41 @@ namespace nmt ...@@ -39,6 +39,41 @@ namespace nmt
Trainer::Trainer() Trainer::Trainer()
{ {
cfg = NULL; cfg = NULL;
lrate = 0.0F;
lrbias = 0.0F;
sBatchSize = 0;
wBatchSize = 0;
bucketSize = 0;
nstep = 0;
nepoch = 0;
logInterval = 0;
maxCheckpoint = 0;
d = 0;
nwarmup = 0;
vSize = 0;
vSizeTgt = 0;
useAdam = false;
adamBeta1 = 0.0F;
adamBeta2 = 0.0F;
adamDelta = 0.0F;
isShuffled = false;
labelSmoothingP = 0.0F;
nStepCheckpoint = 0;
useEpochCheckpoint = false;
updateStep = 0;
isLenSorted = 0;
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = 0;
batchLoader.endID = 0;
batchLoader.unkID = 0;
batchLoader.padID = 0;
batchLoader.maxSrcLen = 0;
batchLoader.maxTgtLen = 0;
batchLoader.bufferSize = 0;
} }
/* de-constructor */ /* de-constructor */
...@@ -62,13 +97,15 @@ initialization ...@@ -62,13 +97,15 @@ initialization
void Trainer::Init(Config& config) void Trainer::Init(Config& config)
{ {
cfg = &config; cfg = &config;
lrate = config.lrate; lrate = config.lrate;
lrbias = config.lrbias; lrbias = config.lrbias;
sBatchSize = config.sBatchSize; sBatchSize = config.sBatchSize;
wBatchSize = config.wBatchSize; wBatchSize = config.wBatchSize;
bucketSize = config.bucketSize; bucketSize = config.bucketSize;
nepoch = config.nepoch;
nstep = config.nstep; nstep = config.nstep;
nepoch = config.nepoch;
logInterval = config.logInterval;
maxCheckpoint = config.maxCheckpoint; maxCheckpoint = config.maxCheckpoint;
d = config.modelSize; d = config.modelSize;
nwarmup = config.nwarmup; nwarmup = config.nwarmup;
...@@ -87,6 +124,14 @@ void Trainer::Init(Config& config) ...@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T = 1.0F; adamBeta1T = 1.0F;
adamBeta2T = 1.0F; adamBeta2T = 1.0F;
batchLoader.startID = config.startID;
batchLoader.endID = config.endID;
batchLoader.unkID = config.unkID;
batchLoader.padID = config.padID;
batchLoader.maxSrcLen = config.maxSrcLen;
batchLoader.maxTgtLen = config.maxTgtLen;
batchLoader.bufferSize = config.bufSize;
} }
/* /*
...@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
} }
int step = 0; int step = 0;
int wc = 0; int wc = 0;
int ws = 0; int sc = 0;
int wordCount = 0; int wordCount = 0;
int wordCountTotal = 0; int wordCountTotal = 0;
int batchCountTotal = 0; int batchCountTotal = 0;
...@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec(); double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true); batchLoader.Init(fn, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) { for (epoch = 1; epoch <= nepoch; epoch++) {
...@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount = 0; wordCount = 0;
loss = 0; loss = 0;
/* reset the batch loader */ while (step++ < nstep)
batchLoader.ClearBuf();
while (!batchLoader.IsEmpty())
{ {
XNet net; XNet net;
net.Clear(); net.Clear();
...@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc; XTensor paddingEnc;
XTensor paddingDec; XTensor paddingDec;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, if (curIdx == 0 || curIdx == buf->Size()) {
sBatchSize, wBatchSize, devID); curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
/* make the network */ /* make the network */
if (model->isLM) if (model->isLM) {
model->MakeLM(batchEnc, output, paddingEnc, true); model->MakeLM(batchEnc, output, paddingEnc, true);
else if (model->isMT) }
else if (model->isMT) {
model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true); model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
}
else { else {
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
...@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE lossLocal = lossBatch / wc; DTYPE lossLocal = lossBatch / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F); bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
net.isGradEfficient = true;
bool debug(false);
if (debug) {
LOG("after forward:");
batchEnc.mem->ShowMemUsage(stderr);
exit(0);
}
if (doUpdate) { if (doUpdate) {
/* back-propagation */
net.Backward(lossTensor); net.Backward(lossTensor);
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
gradStep += 1; gradStep += 1;
loss += lossBatch; loss += lossBatch;
wordCount += wc; wordCount += wc;
wordCountTotal += wc; wordCountTotal += wc;
batchCountTotal += ws; batchCountTotal += sc;
/* update the parameters */ /* update the parameters */
if (gradStep == updateStep) { if (gradStep == updateStep) {
...@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else else
nSkipped++; nSkipped++;
if (++step >= nstep) { if (step % logInterval == 0) {
isEnd = true;
break;
}
if (step == 10) {
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
LOG("elapsed=%.1fs, step=%d, epoch=%d, " LOG("elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e", "total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e",
...@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
} }
} }
if (isEnd)
break;
if (useEpochCheckpoint) if (useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch); MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch); epoch = MIN(epoch, nepoch);
...@@ -287,8 +340,12 @@ test the model ...@@ -287,8 +340,12 @@ test the model
*/ */
void Trainer::Validate(const char* fn, const char* ofn, Model* model) void Trainer::Validate(const char* fn, const char* ofn, Model* model)
{ {
double startT = GetClockSec();
DISABLE_GRAD;
int wc = 0; int wc = 0;
int ws = 0; int sc = 0;
int wordCount = 0; int wordCount = 0;
int sentCount = 0; int sentCount = 0;
float loss = 0; float loss = 0;
...@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */ /* data files */
batchLoader.Init(fn, 0, false); batchLoader.Init(fn, 0, false);
double startT = GetClockSec(); int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of validation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
while (!batchLoader.IsEmpty()) while (curIdx < buf->count)
{ {
/* batch of input sequences */ /* batch of input sequences */
XTensor batchEnc; XTensor batchEnc;
...@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot; XTensor labelOnehot;
XTensor lossTensor; XTensor lossTensor;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID); sBatchSize, 0, model->devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
/* make the network */ /* make the network */
...@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int length = output.GetDim(1); int length = output.GetDim(1);
labelOnehot = IndexToOnehot(label, vSizeTgt, 0); labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
lossTensor = CrossEntropy(output, labelOnehot, paddingDec); lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
float lossBatch = ReduceSumAllValue(lossTensor); float lossBatch = ReduceSumAllValue(lossTensor);
loss += lossBatch; loss += lossBatch;
wordCount += wc; wordCount += wc;
sentCount += bSize; sentCount += bSize;
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
LOG("test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)", ENABLE_GRAD;
LOG("validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount)); elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount));
} }
...@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr) ...@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0); _ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* v2 = m / (sqrt(v) + delta) */ /* v2 = m / (sqrt(v) + delta) */
XTensor* v2 = NewTensorBuf(v, v->devID); XTensor* v2 = NewTensorBufV2(v, v->devID, v->mem);
_Power(v, v2, 0.5F); _Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d); _ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2); _Div(m, v2, v2);
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -70,6 +70,9 @@ public: ...@@ -70,6 +70,9 @@ public:
/* training step number */ /* training step number */
int nstep; int nstep;
/* interval step for logging */
int logInterval;
/* the maximum number of saved checkpoints */ /* the maximum number of saved checkpoints */
int maxCheckpoint; int maxCheckpoint;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论