/* NiuTrans.NMT - an open-source neural machine translation system.
 * Copyright (C) 2020 NiuTrans Research. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
 * TODO: refactor the data loader class and references
 */

#include <string>
#include <vector>
#include <cstdlib>
#include <fstream>
#include <algorithm>

#include "TrainDataSet.h"
#include "../Utility.h"
#include "../translate/Vocab.h"

using namespace nmt;

namespace nts {

/* sort the dataset by length (in descending order) */
void TrainDataSet::SortByLength() {
    sort(buffer.items, buffer.items + buffer.count,
        [](TrainExample* a, TrainExample* b) {
            return (a->srcSent.Size() + a->tgtSent.Size())
                 > (b->srcSent.Size() + b->tgtSent.Size());
        });
}

/* sort buckets by key (in descending order) */
void TrainDataSet::SortBucket() {
    sort(buffer.items, buffer.items + buffer.count,
        [](TrainExample* a, TrainExample* b) {
            return a->bucketKey > b->bucketKey;
        });
}

/*
sort the output by key in a range (in descending order)
>> begin - the first index of the range
>> end - the index one past the last element of the range
*/
void TrainDataSet::SortInBucket(int begin, int end) {
    sort(buffer.items + begin, buffer.items + end,
        [](TrainExample* a, TrainExample* b) {
            return (a->key > b->key);
        });
}

/*
load all data from a file to the buffer
training data format (binary):
first 4 bytes: source vocabulary size
next 4 bytes: target vocabulary size
next 8 bytes: number of sentence pairs
subsequent segments:
source sentence length (4 bytes)
target sentence length (4 bytes)
source tokens (4 bytes per token)
target tokens (4 bytes per token)
*/
void TrainDataSet::LoadDataToBuffer()
{
    buffer.Clear();
    curIdx = 0;

    int id = 0;
    uint64_t sentNum = 0;

    int srcVocabSize = 0;
    int tgtVocabSize = 0;
    fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
    fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);

    fread(&sentNum, sizeof(uint64_t), 1, fp);
    CheckNTErrors(sentNum > 0, "Invalid number of sentence pairs");

    while ((uint64_t)id < sentNum) {
        int srcLen = 0;
        int tgtLen = 0;
        fread(&srcLen, sizeof(int), 1, fp);
        fread(&tgtLen, sizeof(int), 1, fp);
        CheckNTErrors(srcLen > 0, "Invalid source sentence length");
        CheckNTErrors(tgtLen > 0, "Invalid target sentence length");

        IntList srcSent;
        IntList tgtSent;
        srcSent.ReadFromFile(fp, srcLen);
        tgtSent.ReadFromFile(fp, tgtLen);

        TrainExample* example = new TrainExample;
        example->id = id++;
        example->key = id;
        example->srcSent = srcSent;
        example->tgtSent = tgtSent;

        buffer.Add(example);
    }

    fclose(fp);

    XPRINT1(0, stderr, "[INFO] loaded %d sentence pairs\n", id);
}
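
/* A minimal sketch (illustrative only, not part of the system) of how a
   training file in the binary format described above could be produced.
   The token ids and vocabulary sizes below are made-up placeholders, and
   writeToyTrainFile is a hypothetical helper; real data files come from
   the preprocessing tools. Kept inside #if 0 so it is never compiled. */
#if 0
#include <cstdio>
#include <cstdint>
#include <vector>

void writeToyTrainFile(const char* path)
{
    /* one toy sentence pair with hypothetical token ids */
    std::vector<int> src = { 7, 42, 99, 2 };   /* ends with EOS */
    std::vector<int> tgt = { 5, 13, 2 };       /* ends with EOS */

    FILE* f = fopen(path, "wb");

    int srcVocabSize = 100;                    /* assumed vocabulary sizes */
    int tgtVocabSize = 100;
    uint64_t sentNum = 1;                      /* number of sentence pairs */
    fwrite(&srcVocabSize, sizeof(int), 1, f);
    fwrite(&tgtVocabSize, sizeof(int), 1, f);
    fwrite(&sentNum, sizeof(uint64_t), 1, f);

    /* one segment: the two lengths, then the tokens themselves */
    int srcLen = (int)src.size();
    int tgtLen = (int)tgt.size();
    fwrite(&srcLen, sizeof(int), 1, f);
    fwrite(&tgtLen, sizeof(int), 1, f);
    fwrite(src.data(), sizeof(int), src.size(), f);
    fwrite(tgt.data(), sizeof(int), tgt.size(), f);

    fclose(f);
}
#endif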
/*
load a mini-batch to the device (for training)
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
>> paddingDec - a tensor to store the batch of decoder paddings
>> label - a tensor to store the labels of the input
>> minSentBatch - the minimum number of sentences in a batch
>> batchSize - the maximum number of words in a batch
>> devID - the device id, -1 for the CPU
<< return - the number of target tokens and the number of sentences
*/
UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
                                   XTensor* batchDec, XTensor* paddingDec,
                                   XTensor* label,
                                   size_t minSentBatch, size_t batchSize, int devID)
{
    UInt64List info;
    size_t srcTokenNum = 0;
    size_t tgtTokenNum = 0;
    size_t realBatchSize = 1;

    if (!isTraining)
        realBatchSize = minSentBatch;

    /* get the maximum source sentence length in a mini-batch */
    size_t maxSrcLen = buffer[(int)curIdx]->srcSent.Size();

    /* the maximum number of sentences in a batch */
    const int MAX_BATCH_SIZE = 512;

    /* dynamic batching for sentences, enabled when the dataset is used for training */
    if (isTraining) {
        while ((realBatchSize < (buffer.Size() - curIdx))
            && (realBatchSize * maxSrcLen < batchSize)
            && (realBatchSize < MAX_BATCH_SIZE)
            && (realBatchSize * buffer[(int)(curIdx + realBatchSize)]->srcSent.Size() < batchSize)) {
            if (maxSrcLen < buffer[(int)(curIdx + realBatchSize)]->srcSent.Size())
                maxSrcLen = buffer[(int)(curIdx + realBatchSize)]->srcSent.Size();
            realBatchSize++;
        }
    }

    /* clip the batch size to the number of remaining sentences */
    if ((buffer.Size() - curIdx) < realBatchSize) {
        realBatchSize = buffer.Size() - curIdx;
    }

    CheckNTErrors(realBatchSize > 0, "Invalid batch size");

    /* get the maximum target sentence length in a mini-batch */
    size_t maxTgtLen = buffer[(int)curIdx]->tgtSent.Size();
    for (size_t i = 0; i < realBatchSize; i++) {
        if (maxTgtLen < buffer[(int)(curIdx + i)]->tgtSent.Size())
            maxTgtLen = buffer[(int)(curIdx + i)]->tgtSent.Size();
    }
    for (size_t i = 0; i < realBatchSize; i++) {
        if (maxSrcLen < buffer[(int)(curIdx + i)]->srcSent.Size())
            maxSrcLen = buffer[(int)(curIdx + i)]->srcSent.Size();
    }

    CheckNTErrors(maxSrcLen != 0, "Invalid source length for batching");

    int* batchEncValues = new int[realBatchSize * maxSrcLen];
    float* paddingEncValues = new float[realBatchSize * maxSrcLen];

    int* labelValues = new int[realBatchSize * maxTgtLen];
    int* batchDecValues = new int[realBatchSize * maxTgtLen];
    float* paddingDecValues = new float[realBatchSize * maxTgtLen];

    for (size_t i = 0; i < realBatchSize * maxSrcLen; i++) {
        batchEncValues[i] = PAD;
        paddingEncValues[i] = 1.0F;
    }
    for (size_t i = 0; i < realBatchSize * maxTgtLen; i++) {
        batchDecValues[i] = PAD;
        labelValues[i] = PAD;
        paddingDecValues[i] = 1.0F;
    }

    size_t curSrc = 0;
    size_t curTgt = 0;

    /*
    batchEnc: ends with EOS (right part is padding)
    batchDec: begins with SOS (right part is padding)
    label: ends with EOS (right part is padding)
    */
    for (size_t i = 0; i < realBatchSize; ++i) {
        srcTokenNum += buffer[(int)(curIdx + i)]->srcSent.Size();
        tgtTokenNum += buffer[(int)(curIdx + i)]->tgtSent.Size();

        curSrc = maxSrcLen * i;
        for (int j = 0; j < buffer[(int)(curIdx + i)]->srcSent.Size(); j++) {
            batchEncValues[curSrc++] = buffer[(int)(curIdx + i)]->srcSent[j];
        }

        curTgt = maxTgtLen * i;
        for (int j = 0; j < buffer[(int)(curIdx + i)]->tgtSent.Size(); j++) {
            /* the label sequence is the decoder input shifted left by one */
            if (j > 0)
                labelValues[curTgt - 1] = buffer[(int)(curIdx + i)]->tgtSent[j];
            batchDecValues[curTgt++] = buffer[(int)(curIdx + i)]->tgtSent[j];
        }
        labelValues[curTgt - 1] = EOS;

        while (curSrc < maxSrcLen * (i + 1))
            paddingEncValues[curSrc++] = 0;
        while (curTgt < maxTgtLen * (i + 1))
            paddingDecValues[curTgt++] = 0;
    }

    int rbs = (int)realBatchSize;
    int msl = (int)maxSrcLen;
    int mtl = (int)maxTgtLen;
    InitTensor2D(batchEnc, rbs, msl, X_INT, devID);
    InitTensor2D(paddingEnc, rbs, msl, X_FLOAT, devID);

    /* decoder-side tensors are shaped by the maximum target length,
       not the maximum source length */
    InitTensor2D(batchDec, rbs, mtl, X_INT, devID);
    InitTensor2D(paddingDec, rbs, mtl, X_FLOAT, devID);
    InitTensor2D(label, rbs, mtl, X_INT, devID);

    curIdx += realBatchSize;

    batchEnc->SetData(batchEncValues, batchEnc->unitNum);
    paddingEnc->SetData(paddingEncValues, paddingEnc->unitNum);
    batchDec->SetData(batchDecValues, batchDec->unitNum);
    paddingDec->SetData(paddingDecValues, paddingDec->unitNum);
    label->SetData(labelValues, label->unitNum);

    delete[] batchEncValues;
    delete[] paddingEncValues;
    delete[] batchDecValues;
    delete[] paddingDecValues;
    delete[] labelValues;

    info.Add(tgtTokenNum);
    info.Add(realBatchSize);
    return info;
}
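
/* A minimal sketch (assumptions marked below) of how LoadBatch might be
   driven from a training loop. trainOneEpoch and the batch-size constants
   are hypothetical; only TrainDataSet, XTensor, UInt64List and the methods
   of this class come from the actual code. Kept inside #if 0 so it is
   never compiled. */
#if 0
void trainOneEpoch(TrainDataSet& dataSet, int devID)
{
    const size_t minSentBatch = 8;     /* assumed lower bound on sentences */
    const size_t batchSize = 4096;     /* assumed word budget per batch */

    while (!dataSet.IsEmpty()) {
        XTensor batchEnc, paddingEnc, batchDec, paddingDec, label;

        /* returns {number of target tokens, number of sentences} */
        UInt64List info = dataSet.LoadBatch(&batchEnc, &paddingEnc,
                                            &batchDec, &paddingDec, &label,
                                            minSentBatch, batchSize, devID);

        /* ... run the forward/backward pass on this batch here ... */
    }

    /* re-shuffle the buckets so the next epoch sees different batches */
    dataSet.ClearBuf();
}
#endif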
/*
initialization of the dataset
>> dataFile - path of the data file
>> myBucketSize - size of the bucket used to group sentence pairs of similar length
>> training - indicates whether the dataset is used for training
*/
void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
{
    fp = fopen(dataFile, "rb");
    CheckNTErrors(fp, "cannot open the training file");

    curIdx = 0;

    bucketSize = myBucketSize;
    isTraining = training;

    LoadDataToBuffer();

    SortByLength();

    if (isTraining)
        BuildBucket();
}

/* check if the buffer is empty */
bool TrainDataSet::IsEmpty() {
    if (curIdx < buffer.Size())
        return false;
    return true;
}

/* reset the buffer */
void TrainDataSet::ClearBuf() {
    curIdx = 0;

    /* make different batches in different epochs */
    SortByLength();

    if (isTraining)
        BuildBucket();
}

/* group data into buckets with similar length */
void TrainDataSet::BuildBucket()
{
    size_t idx = 0;

    /* build and shuffle buckets */
    while (idx < buffer.Size()) {

        /* the number of sentences in a bucket */
        size_t sentNum = 1;

        /* get the maximum source sentence length in a bucket */
        size_t maxSrcLen = buffer[(int)idx]->srcSent.Size();

        /* bucketing for sentences */
        while ((sentNum < (buffer.Size() - idx))
            && (sentNum * maxSrcLen < bucketSize)
            && (sentNum * buffer[(int)(idx + sentNum)]->srcSent.Size() < bucketSize)) {
            if (maxSrcLen < buffer[(int)(idx + sentNum)]->srcSent.Size())
                maxSrcLen = buffer[(int)(idx + sentNum)]->srcSent.Size();
            sentNum++;
        }

        /* make sure the number is valid */
        if ((buffer.Size() - idx) < sentNum) {
            sentNum = buffer.Size() - idx;
        }

        /* assign the same random key to all items in the bucket, so that
           SortBucket() shuffles buckets as a whole */
        int randomKey = rand();
        for (size_t i = 0; i < sentNum; i++) {
            buffer[(int)(idx + i)]->bucketKey = randomKey;
        }

        idx += sentNum;
    }

    SortBucket();

    /* sort the items in each bucket by source sentence length */
    idx = 0;
    while (idx < buffer.Size()) {
        size_t sentNum = 0;
        int bucketKey = buffer[(int)(idx + sentNum)]->bucketKey;
        while (sentNum < (buffer.Size() - idx)
            && buffer[(int)(idx + sentNum)]->bucketKey == bucketKey) {
            buffer[(int)(idx + sentNum)]->key = (int)buffer[(int)(idx + sentNum)]->srcSent.Size();
            sentNum++;
        }
        SortInBucket((int)idx, (int)(idx + sentNum));
        idx += sentNum;
    }
}

/* destructor */
TrainDataSet::~TrainDataSet() {

    /* release the buffer */
    for (int i = 0; i < buffer.Size(); i++)
        delete buffer[i];
}

}
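
/* A self-contained sketch (illustrative only) of the bucketing idea used by
   BuildBucket() above, written with std::vector instead of the NiuTensor
   containers: sort by length, cut into buckets under a size budget, tag every
   item in a bucket with one random key, and sort by key so whole buckets are
   shuffled while similar lengths stay together. toyBuildBucket and the
   (length, bucketKey) pair layout are made up for this sketch. Kept inside
   #if 0 so it is never compiled. */
#if 0
#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>

/* each item is (length, bucketKey) */
void toyBuildBucket(std::vector<std::pair<int, int>>& items, int bucketSize)
{
    /* sort by length in descending order */
    std::sort(items.begin(), items.end(),
              [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                  return a.first > b.first;
              });

    /* cut into buckets: a bucket grows while (#sentences * max length)
       stays under the budget, then all of its items get one random key */
    size_t idx = 0;
    while (idx < items.size()) {
        size_t num = 1;
        int maxLen = items[idx].first;
        while (idx + num < items.size() && (int)num * maxLen < bucketSize) {
            maxLen = std::max(maxLen, items[idx + num].first);
            num++;
        }
        int key = rand();
        for (size_t i = 0; i < num; i++)
            items[idx + i].second = key;
        idx += num;
    }

    /* sorting by key reorders buckets randomly from epoch to epoch */
    std::sort(items.begin(), items.end(),
              [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                  return a.second > b.second;
              });
}
#endif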