Commit 98a9130d by hello

Refactor class `TrainDataSet`

parent 4bbd6a27
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,13 +16,10 @@
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2020-08-09
* TODO: refactor the data loader class and references
* $Updated by: CAO Hang and Wu Siming 2020-12-13
*/
#include <string>
#include <vector>
#include <cstdlib>
#include <fstream>
#include <algorithm>
#include "TrainDataSet.h"
......@@ -33,37 +30,56 @@ using namespace nmt;
namespace nts {
/* sort the dataset by length (in descending order) */
void TrainDataSet::SortByLength() {
sort(buffer.items, buffer.items + buffer.count,
[](TrainExample* a, TrainExample* b) {
return (a->srcSent.Size() + a->tgtSent.Size())
> (b->srcSent.Size() + b->tgtSent.Size());
});
/* get the maximum source sentence length in a range */
int TrainDataSet::MaxSrcLen(XList* buf, int begin, int end) {
CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
int maxLen = 0;
for (int i = begin; i < end; i++) {
IntList* srcSent = ((TrainExample*)buf->Get(i))->srcSent;
maxLen = MAX(int(srcSent->Size()), maxLen);
}
return maxLen;
}
/* sort buckets by key (in descending order) */
void TrainDataSet::SortBucket() {
sort(buffer.items, buffer.items + buffer.count,
[](TrainExample* a, TrainExample* b) {
return a->bucketKey > b->bucketKey;
});
/* get the maximum target sentence length in a range */
int TrainDataSet::MaxTgtLen(XList* buf, int begin, int end) {
CheckNTErrors((end > begin) && (begin >= 0) && (end <= buf->count), "Invalid range");
int maxLen = 0;
for (int i = begin; i < end; i++) {
IntList* tgtSent = ((TrainExample*)buf->Get(i))->tgtSent;
maxLen = MAX(int(tgtSent->Size()), maxLen);
}
return maxLen;
}
/*
sort a range of the buffer by key (in descending order)
>> begin - the first index of the range
>> end - the index after the last element of the range
*/
void TrainDataSet::SortInBucket(int begin, int end) {
sort(buffer.items + begin, buffer.items + end,
[](TrainExample* a, TrainExample* b) {
return (a->key > b->key);
});
/* sort the buffer by source sentence length (in ascending order) */
void TrainDataSet::SortBySrcLength(XList* buf) {
stable_sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->srcSent->Size() <
((TrainExample*)(b))->srcSent->Size();
});
}
/* sort the buffer by target sentence length (in ascending order) */
void TrainDataSet::SortByTgtLength(XList* buf) {
stable_sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->tgtSent->Size() <
((TrainExample*)(b))->tgtSent->Size();
});
}
/* sort buckets by key (in ascending order) */
void TrainDataSet::SortBuckets(XList* buf) {
sort(buf->items, buf->items + buf->count,
[](void* a, void* b) {
return ((TrainExample*)(a))->bucketKey <
((TrainExample*)(b))->bucketKey;
});
}
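Because the two length sorts above use stable_sort, applying them in sequence composes into a lexicographic order: LoadBatchToBuf below sorts by target length first, then by source length, which leaves the buffer ordered by source length with target length as the tie-breaker. A self-contained illustration of that pattern with plain pairs (not repository code):
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>
int main() {
    /* (srcLen, tgtLen) pairs */
    std::vector<std::pair<int, int>> lens = { { 5, 9 }, { 3, 4 }, { 5, 2 }, { 3, 7 } };
    /* sort by the secondary key first */
    std::stable_sort(lens.begin(), lens.end(),
                     [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                         return a.second < b.second;
                     });
    /* then by the primary key; stability preserves the secondary order within ties */
    std::stable_sort(lens.begin(), lens.end(),
                     [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                         return a.first < b.first;
                     });
    for (const std::pair<int, int>& p : lens)
        printf("(%d, %d)\n", p.first, p.second);  /* (3,4) (3,7) (5,2) (5,9) */
    return 0;
}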
/*
load all data from a file to the buffer
load samples from a file into the buffer
training data format (binary):
first 8 bytes: number of sentence pairs
subsequent segments:
......@@ -71,52 +87,63 @@ source sentence length (4 bit)
target sentence length (4 bytes)
source tokens (4 bytes per token)
target tokens (4 bytes per token)
>> buf - the buffer (list) of samples
*/
void TrainDataSet::LoadDataToBuffer()
bool TrainDataSet::LoadBatchToBuf(XList* buf)
{
buffer.Clear();
curIdx = 0;
int id = 0;
uint64_t sentNum = 0;
ClearSamples(buf);
int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
int sampleNum = 0;
fread(&sentNum, sizeof(uint64_t), 1, fp);
CheckNTErrors(sentNum > 0, "Invalid number of sentence pairs");
while ((sampleNum < bufferSize)) {
while (id < sentNum) {
int srcLen = 0;
int tgtLen = 0;
fread(&srcLen, sizeof(int), 1, fp);
size_t n = fread(&srcLen, sizeof(int), 1, fp);
if (n == 0)
break;
fread(&tgtLen, sizeof(int), 1, fp);
CheckNTErrors(srcLen > 0, "Invalid source sentence length");
CheckNTErrors(tgtLen > 0, "Invalid target sentence length");
IntList srcSent;
IntList tgtSent;
srcSent.ReadFromFile(fp, srcLen);
tgtSent.ReadFromFile(fp, tgtLen);
IntList *srcSent = new IntList(srcLen);
IntList *tgtSent = new IntList(tgtLen);
srcSent->ReadFromFile(fp, srcLen);
tgtSent->ReadFromFile(fp, tgtLen);
TrainExample* example = new TrainExample(sampleNum++, 0, srcSent, tgtSent);
buf->Add(example);
}
/* reset the file pointer to the beginning */
if (feof(fp) && isTraining) {
TrainExample* example = new TrainExample;
example->id = id++;
example->key = id;
example->srcSent = srcSent;
example->tgtSent = tgtSent;
rewind(fp);
buffer.Add(example);
int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
}
fclose(fp);
/* sort by target length, then stably by source length, and group the sorted samples into buckets */
SortByTgtLength(buf);
SortBySrcLength(buf);
if (isTraining)
BuildBucket(buf);
XPRINT1(0, stderr, "[INFO] loaded %d sentences\n", id);
return true;
}
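As a reference for the binary layout read above (the two vocabulary sizes, then a 64-bit sample count, then one length-prefixed block of token ids per pair), here is a minimal sketch of a writer that produces a compatible toy file; the function name and the sample data are hypothetical, not part of the repository:
#include <cstdint>
#include <cstdio>
#include <vector>
void WriteToyTrainingFile(const char* path) {
    std::vector<std::vector<int>> src = { { 5, 6, 7, 2 }, { 8, 9, 2 } };
    std::vector<std::vector<int>> tgt = { { 5, 7, 2 }, { 9, 10, 11, 2 } };
    FILE* f = fopen(path, "wb");
    int srcVocabSize = 32000;
    int tgtVocabSize = 32000;
    fwrite(&srcVocabSize, sizeof(int), 1, f);           /* 4 bytes */
    fwrite(&tgtVocabSize, sizeof(int), 1, f);           /* 4 bytes */
    uint64_t pairNum = src.size();
    fwrite(&pairNum, sizeof(uint64_t), 1, f);           /* 8 bytes */
    for (size_t i = 0; i < src.size(); i++) {
        int srcLen = (int)src[i].size();
        int tgtLen = (int)tgt[i].size();
        fwrite(&srcLen, sizeof(int), 1, f);
        fwrite(&tgtLen, sizeof(int), 1, f);
        fwrite(src[i].data(), sizeof(int), srcLen, f);  /* source token ids */
        fwrite(tgt[i].data(), sizeof(int), tgtLen, f);  /* target token ids */
    }
    fclose(f);
}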
/*
load a mini-batch to the device (for training)
load a mini-batch to a device
>> buf - the buffer (list) of samples
>> curIdx - the current index in the buffer
>> batchEnc - a tensor to store the batch of encoder input
>> paddingEnc - a tensor to store the batch of encoder paddings
>> batchDec - a tensor to store the batch of decoder input
......@@ -125,57 +152,34 @@ load a mini-batch to the device (for training)
>> minSentBatch - the minimum number of sentences in a batch
>> batchSize - the maximum number of words in a batch
>> devID - the device id, -1 for the CPU
<< return - number of target tokens and sentences
>> wc - number of target words in a batch
>> sc - number of samples in a batch
*/
UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID)
bool TrainDataSet::LoadBatch(XList* buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
int minSentBatch, int batchSize, int devID, int& wc, int& sc)
{
UInt64List info;
size_t srcTokenNum = 0;
size_t tgtTokenNum = 0;
size_t realBatchSize = 1;
if (!isTraining)
realBatchSize = minSentBatch;
/* get the maximum source sentence length in a mini-batch */
size_t maxSrcLen = buffer[(int)curIdx]->srcSent.Size();
/* max batch size */
const int MAX_BATCH_SIZE = 512;
/* dynamic batching for sentences, enabled when the dataset is used for training */
if (isTraining) {
while ((realBatchSize < (buffer.Size() - curIdx))
&& (realBatchSize * maxSrcLen < batchSize)
&& (realBatchSize < MAX_BATCH_SIZE)
&& (realBatchSize * buffer[(int)(curIdx + realBatchSize)]->srcSent.Size() < batchSize)) {
if (maxSrcLen < buffer[(int)(curIdx + realBatchSize)]->srcSent.Size())
maxSrcLen = buffer[(int)(curIdx + realBatchSize)]->srcSent.Size();
realBatchSize++;
}
}
/* real batch size */
if ((buffer.Size() - curIdx) < realBatchSize) {
realBatchSize = buffer.Size() - curIdx;
int srcTokenNum = 0;
int tgtTokenNum = 0;
int realBatchSize = 0;
/* dynamic batching for sentences */
int bucketKey = ((TrainExample*)(buf->Get(curIdx)))->bucketKey;
while ((realBatchSize < (int(buf->Size()) - curIdx)) &&
(((TrainExample*)(buf->Get(curIdx + realBatchSize)))->bucketKey == bucketKey)) {
realBatchSize++;
}
realBatchSize = MIN(realBatchSize, (int(buf->Size()) - curIdx));
CheckNTErrors(realBatchSize > 0, "Invalid batch size");
/* get the maximum target sentence length in a mini-batch */
size_t maxTgtLen = buffer[(int)curIdx]->tgtSent.Size();
for (size_t i = 0; i < realBatchSize; i++) {
if (maxTgtLen < buffer[(int)(curIdx + i)]->tgtSent.Size())
maxTgtLen = buffer[(int)(curIdx + i)]->tgtSent.Size();
}
for (size_t i = 0; i < realBatchSize; i++) {
if (maxSrcLen < buffer[(int)(curIdx + i)]->srcSent.Size())
maxSrcLen = buffer[(int)(curIdx + i)]->srcSent.Size();
}
int maxSrcLen = MaxSrcLen(buf, curIdx, curIdx + realBatchSize);
int maxTgtLen = MaxTgtLen(buf, curIdx, curIdx + realBatchSize);
CheckNTErrors(maxSrcLen != 0, "Invalid source length for batching");
CheckNTErrors(maxSrcLen > 0, "Invalid source length for batching");
CheckNTErrors(maxTgtLen > 0, "Invalid target length for batching");
int* batchEncValues = new int[realBatchSize * maxSrcLen];
float* paddingEncValues = new float[realBatchSize * maxSrcLen];
......@@ -185,17 +189,17 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
float* paddingDecValues = new float[realBatchSize * maxTgtLen];
for (int i = 0; i < realBatchSize * maxSrcLen; i++) {
batchEncValues[i] = PAD;
paddingEncValues[i] = 1;
batchEncValues[i] = 1;
paddingEncValues[i] = 1.0F;
}
for (int i = 0; i < realBatchSize * maxTgtLen; i++) {
batchDecValues[i] = PAD;
labelVaues[i] = PAD;
batchDecValues[i] = 1;
labelVaues[i] = 1;
paddingDecValues[i] = 1.0F;
}
size_t curSrc = 0;
size_t curTgt = 0;
int curSrc = 0;
int curTgt = 0;
/*
batchEnc: end with EOS (left padding)
......@@ -204,35 +208,33 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
*/
for (int i = 0; i < realBatchSize; ++i) {
srcTokenNum += buffer[(int)(curIdx + i)]->srcSent.Size();
tgtTokenNum += buffer[(int)(curIdx + i)]->tgtSent.Size();
TrainExample* sample = (TrainExample*)(buf->Get(curIdx + i));
srcTokenNum += int(sample->srcSent->Size());
tgtTokenNum += int(sample->tgtSent->Size());
curSrc = maxSrcLen * i;
for (int j = 0; j < buffer[(int)(curIdx + i)]->srcSent.Size(); j++) {
batchEncValues[curSrc++] = buffer[(int)(curIdx + i)]->srcSent[j];
for (int j = 0; j < sample->srcSent->Size(); j++) {
batchEncValues[curSrc++] = sample->srcSent->Get(j);
}
curTgt = maxTgtLen * i;
for (int j = 0; j < buffer[(int)(curIdx + i)]->tgtSent.Size(); j++) {
for (int j = 0; j < sample->tgtSent->Size(); j++) {
if (j > 0)
labelVaues[curTgt - 1] = buffer[(int)(curIdx + i)]->tgtSent[j];
batchDecValues[curTgt++] = buffer[(int)(curIdx + i)]->tgtSent[j];
labelVaues[curTgt - 1] = sample->tgtSent->Get(j);
batchDecValues[curTgt++] = sample->tgtSent->Get(j);
}
labelVaues[curTgt - 1] = EOS;
labelVaues[curTgt - 1] = 2;
while (curSrc < maxSrcLen * (i + 1))
paddingEncValues[curSrc++] = 0;
while (curTgt < maxTgtLen * (i + 1))
paddingDecValues[curTgt++] = 0;
}
int rbs = (int)realBatchSize;
int msl = (int)maxSrcLen;
InitTensor2D(batchEnc, rbs, msl, X_INT, devID);
InitTensor2D(paddingEnc, rbs, msl, X_FLOAT, devID);
InitTensor2D(batchDec, rbs, msl, X_INT, devID);
InitTensor2D(paddingDec, rbs, msl, X_FLOAT, devID);
InitTensor2D(label, rbs, msl, X_INT, devID);
InitTensor2D(batchEnc, realBatchSize, maxSrcLen, X_INT, devID);
InitTensor2D(paddingEnc, realBatchSize, maxSrcLen, X_FLOAT, devID);
InitTensor2D(batchDec, realBatchSize, maxTgtLen, X_INT, devID);
InitTensor2D(paddingDec, realBatchSize, maxTgtLen, X_FLOAT, devID);
InitTensor2D(label, realBatchSize, maxTgtLen, X_INT, devID);
curIdx += realBatchSize;
......@@ -248,9 +250,22 @@ UInt64List TrainDataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
delete[] paddingDecValues;
delete[] labelVaues;
info.Add(tgtTokenNum);
info.Add(realBatchSize);
return info;
wc = tgtTokenNum;
sc = realBatchSize;
return true;
}
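Taken together, LoadBatchToBuf and LoadBatch form a two-level loop: refill the buffer when it is exhausted, then consume one bucket (a run of samples sharing the same bucketKey) per call, with curIdx advanced inside LoadBatch. A condensed sketch of that contract, mirroring the trainer code further down; the batch-size values are placeholders and error handling is elided:
#include "TrainDataSet.h"
void IterateOneBuffer(TrainDataSet& loader, int devID) {
    XList* buf = new XList;
    int curIdx = 0;
    int wc = 0;
    int sc = 0;
    loader.LoadBatchToBuf(buf);               /* fill the buffer and build buckets */
    while (curIdx < buf->count) {
        XTensor batchEnc, paddingEnc, batchDec, paddingDec, label;
        /* consumes one bucket and advances curIdx by the real batch size */
        loader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc,
                         &batchDec, &paddingDec, &label,
                         /*minSentBatch=*/1, /*batchSize=*/4096, devID, wc, sc);
        /* ... run the model on the batch; wc/sc report target words and sentences ... */
    }
    loader.ClearSamples(buf);                 /* release the samples */
    delete buf;
}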
/*
release the samples in the buffer and clear it
>> buf - the buffer (list) of samples
*/
void TrainDataSet::ClearSamples(XList* buf)
{
for (int i = 0; i < buf->count; i++) {
TrainExample* sample = (TrainExample*)buf->Get(i);
delete sample;
}
buf->Clear();
}
/*
......@@ -263,98 +278,90 @@ void TrainDataSet::Init(const char* dataFile, int myBucketSize, bool training)
{
fp = fopen(dataFile, "rb");
CheckNTErrors(fp, "cannot open the training file");
curIdx = 0;
bucketSize = myBucketSize;
isTraining = training;
LoadDataToBuffer();
SortByLength();
if (isTraining)
BuildBucket();
}
/* check if the buffer is empty */
bool TrainDataSet::IsEmpty() {
if (curIdx < buffer.Size())
return false;
return true;
}
/* reset the buffer */
void TrainDataSet::ClearBuf()
{
curIdx = 0;
int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
CheckNTErrors(srcVocabSize > 0, "Invalid source vocabulary size");
CheckNTErrors(tgtVocabSize > 0, "Invalid target vocabulary size");
/* make different batches in different epochs */
SortByLength();
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
CheckNTErrors(totalSampleNum > 0, "Invalid number of sentence pairs");
if (isTraining)
BuildBucket();
bucketSize = myBucketSize;
isTraining = training;
}
/* group data into buckets with similar length */
void TrainDataSet::BuildBucket()
/* group data with similar length into buckets */
void TrainDataSet::BuildBucket(XList * buf)
{
size_t idx = 0;
int idx = 0;
/* build and shuffle buckets */
while (idx < buffer.Size()) {
/* build buckets by the length of source and target sentences */
while (idx < int(buf->Size())) {
/* sentence number in a bucket */
size_t sentNum = 1;
int sentNum = 1;
/* get the maximum source sentence length in a bucket */
size_t maxSrcLen = buffer[(int)idx]->srcSent.Size();
/* bucketing for sentences */
while ((sentNum < (buffer.Size() - idx))
&& (sentNum * maxSrcLen < bucketSize)
&& (sentNum * buffer[(int)(curIdx + sentNum)]->srcSent.Size() < bucketSize)) {
if (maxSrcLen < buffer[(int)(idx + sentNum)]->srcSent.Size())
maxSrcLen = buffer[(int)(idx + sentNum)]->srcSent.Size();
int maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
int maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
int maxLen = MAX(maxSrcLen, maxTgtLen);
/* the maximum sentence number in a bucket */
const int MAX_SENT_NUM = 5120;
while ((sentNum < (buf->count - idx))
&& (sentNum < MAX_SENT_NUM)
&& (sentNum * maxLen <= bucketSize)) {
sentNum++;
maxSrcLen = MaxSrcLen(buf, idx, idx + sentNum);
maxTgtLen = MaxTgtLen(buf, idx, idx + sentNum);
maxLen = MAX(maxSrcLen, maxTgtLen);
}
/* make sure the number is valid */
if ((buffer.Size() - idx) < sentNum) {
sentNum = buffer.Size() - idx;
if ((sentNum) * maxLen > bucketSize || sentNum >= MAX_SENT_NUM) {
sentNum--;
sentNum = max(8 * (sentNum / 8), sentNum % 8);
}
if ((int(buf->Size()) - idx) < sentNum)
sentNum = int(buf->Size()) - idx;
/* assign the same key for items in a bucket */
int randomKey = rand();
/* shuffle items in a bucket */
for (size_t i = 0; i < sentNum; i++) {
buffer[(int)(idx + i)]->bucketKey = randomKey;
for (int i = 0; i < sentNum; i++) {
((TrainExample*)(buf->Get(idx + i)))->bucketKey = randomKey;
}
idx += sentNum;
}
SortBucket();
/* sort items in a bucket */
idx = 0;
while (idx < buffer.Size()) {
size_t sentNum = 0;
int bucketKey = buffer[(int)(idx + sentNum)]->bucketKey;
while (sentNum < (buffer.Size() - idx)
&& buffer[(int)(idx + sentNum)]->bucketKey == bucketKey) {
buffer[(int)(idx + sentNum)]->key = (int)buffer[(int)(idx + sentNum)]->srcSent.Size();
sentNum++;
}
SortInBucket((int)idx, (int)(idx + sentNum));
idx += sentNum;
}
/* sort buckets by their keys */
SortBuckets(buf);
}
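The sentNum adjustment above shrinks an overfull bucket and then rounds its size: counts of eight or more are rounded down to a multiple of eight, while smaller counts are kept as-is. Keeping batch dimensions multiple-of-8 is a common choice for GPU efficiency, though that rationale is an assumption; the arithmetic below is exactly the expression used. A standalone check:
#include <algorithm>
#include <cstdio>
/* same arithmetic as the adjustment in BuildBucket */
int RoundBucketCount(int sentNum) {
    return std::max(8 * (sentNum / 8), sentNum % 8);
}
int main() {
    const int tests[] = { 3, 8, 13, 21 };
    for (int n : tests)
        printf("%d -> %d\n", n, RoundBucketCount(n));  /* 3->3, 8->8, 13->8, 21->16 */
    return 0;
}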
/* destructor */
TrainDataSet::~TrainDataSet()
{
fclose(fp);
}
/* constructor */
TrainExample::TrainExample(int myID, int myKey, IntList* s, IntList* t)
{
id = myID;
bucketKey = myKey;
srcSent = s;
tgtSent = t;
}
/* release the buffer */
for (int i = 0; i < buffer.Size(); i++)
delete buffer[i];
/* destructor */
TrainExample::~TrainExample()
{
delete srcSent;
delete tgtSent;
}
}
\ No newline at end of file
}
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std;
......@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
struct TrainExample {
public:
/* id of the sentence pair */
int id;
/* source language sentence (tokenized) */
IntList srcSent;
IntList * srcSent;
/* target language sentence (tokenized) */
IntList tgtSent;
/* the key used to shuffle items in a bucket */
int key;
IntList * tgtSent;
/* the key used to shuffle buckets */
int bucketKey;
public:
/* constructor */
TrainExample(int myID, int myKey, IntList* s, IntList* t);
/* destructor */
~TrainExample();
};
struct ReservedIDs {
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
};
/* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet {
public:
/* the data buffer */
TrainBufferType buffer;
/* a list of empty line numbers */
IntList emptyLines;
public:
/* the pointer to file stream */
FILE* fp;
/* current index in the buffer */
size_t curIdx;
/* number of training samples */
size_t totalSampleNum;
/* size of used data in the buffer */
size_t bufferUsed;
/* buffer size */
size_t bufferSize;
/* size of the bucket used for grouping sentences */
size_t bucketSize;
......@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
bool isTraining;
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
/* the maximum length for a source sentence */
int maxSrcLen;
/* the maximum length for a target sentence */
int maxTgtLen;
public:
/* sort the input by length (in descending order) */
void SortByLength();
/* get the maximum source sentence length in a range */
static
int MaxSrcLen(XList* buf, int begin, int end);
/* sort buckets by key (in descending order) */
void SortBucket();
/* get the maximum target sentence length in a range */
static
int MaxTgtLen(XList* buf, int begin, int end);
/* sort the output by key (in descending order) */
void SortInBucket(int begin, int end);
/* sort the input by source sentence length (in ascending order) */
void SortBySrcLength(XList* buf);
/* load data from a file to the buffer */
void LoadDataToBuffer();
/* sort the input by target sentence length (in ascending order) */
void SortByTgtLength(XList* buf);
/* generate a mini-batch */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID);
/* sort buckets by key (in ascending order) */
void SortBuckets(XList* buf);
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf,
bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
int &wc, int &sc);
/* release the samples in a buffer */
......@@ -116,14 +147,8 @@ public:
/* initialization function */
void Init(const char* dataFile, int bucketSize, bool training);
/* check if the buffer is empty */
bool IsEmpty();
/* reset the buffer */
void ClearBuf();
/* group data into buckets with similar length */
void BuildBucket();
void BuildBucket(XList * buf);
/* destructor */
~TrainDataSet();
......
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -39,6 +39,41 @@ namespace nmt
Trainer::Trainer()
{
cfg = NULL;
lrate = 0.0F;
lrbias = 0.0F;
sBatchSize = 0;
wBatchSize = 0;
bucketSize = 0;
nstep = 0;
nepoch = 0;
logInterval = 0;
maxCheckpoint = 0;
d = 0;
nwarmup = 0;
vSize = 0;
vSizeTgt = 0;
useAdam = false;
adamBeta1 = 0.0F;
adamBeta2 = 0.0F;
adamDelta = 0.0F;
isShuffled = false;
labelSmoothingP = 0.0F;
nStepCheckpoint = 0;
useEpochCheckpoint = false;
updateStep = 0;
isLenSorted = 0;
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = 0;
batchLoader.endID = 0;
batchLoader.unkID = 0;
batchLoader.padID = 0;
batchLoader.maxSrcLen = 0;
batchLoader.maxTgtLen = 0;
batchLoader.bufferSize = 0;
}
/* destructor */
......@@ -62,13 +97,15 @@ initialization
void Trainer::Init(Config& config)
{
cfg = &config;
lrate = config.lrate;
lrbias = config.lrbias;
sBatchSize = config.sBatchSize;
wBatchSize = config.wBatchSize;
bucketSize = config.bucketSize;
nepoch = config.nepoch;
nstep = config.nstep;
nepoch = config.nepoch;
logInterval = config.logInterval;
maxCheckpoint = config.maxCheckpoint;
d = config.modelSize;
nwarmup = config.nwarmup;
......@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = config.startID;
batchLoader.endID = config.endID;
batchLoader.unkID = config.unkID;
batchLoader.padID = config.padID;
batchLoader.maxSrcLen = config.maxSrcLen;
batchLoader.maxTgtLen = config.maxTgtLen;
batchLoader.bufferSize = config.bufSize;
}
/*
......@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
int step = 0;
int wc = 0;
int ws = 0;
int sc = 0;
int wordCount = 0;
int wordCountTotal = 0;
int batchCountTotal = 0;
......@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) {
......@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount = 0;
loss = 0;
/* reset the batch loader */
batchLoader.ClearBuf();
while (!batchLoader.IsEmpty())
while (step++ < nstep)
{
XNet net;
net.Clear();
......@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc;
XTensor paddingDec;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID);
if (curIdx == 0 || curIdx == buf->Size()) {
curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
if (model->isLM)
if (model->isLM) {
model->MakeLM(batchEnc, output, paddingEnc, true);
else if (model->isMT)
}
else if (model->isMT) {
model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
}
else {
ShowNTErrors("Illegal model type!");
}
......@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE lossLocal = lossBatch / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
net.isGradEfficient = true;
bool debug(false);
if (debug) {
LOG("after forward:");
batchEnc.mem->ShowMemUsage(stderr);
exit(0);
}
if (doUpdate) {
/* back-propagation */
net.Backward(lossTensor);
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
gradStep += 1;
loss += lossBatch;
wordCount += wc;
wordCountTotal += wc;
batchCountTotal += ws;
batchCountTotal += sc;
/* update the parameters */
if (gradStep == updateStep) {
......@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
nSkipped++;
if (++step >= nstep) {
isEnd = true;
break;
}
if (step == 10) {
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if (step % 100 == 0) {
if (step % logInterval == 0) {
double elapsed = GetClockSec() - startT;
LOG("elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e",
......@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
if (isEnd)
break;
if (useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
......@@ -287,8 +340,12 @@ test the model
*/
void Trainer::Validate(const char* fn, const char* ofn, Model* model)
{
double startT = GetClockSec();
DISABLE_GRAD;
int wc = 0;
int ws = 0;
int sc = 0;
int wordCount = 0;
int sentCount = 0;
float loss = 0;
......@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
batchLoader.Init(fn, 0, false);
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of the validation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
while (!batchLoader.IsEmpty())
while (curIdx < buf->count)
{
/* batch of input sequences */
XTensor batchEnc;
......@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot;
XTensor lossTensor;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID);
wc = (int)info[0];
ws = (int)info[1];
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID, wc, sc);
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
/* make the network */
......@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int length = output.GetDim(1);
labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
float lossBatch = ReduceSumAllValue(lossTensor);
loss += lossBatch;
wordCount += wc;
sentCount += bSize;
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
LOG("test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
ENABLE_GRAD;
LOG("validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount));
}
......@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* v2 = m / (sqrt(v) + delta) */
XTensor* v2 = NewTensorBuf(v, v->devID);
XTensor* v2 = NewTensorBufV2(v, v->devID, v->mem);
_Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2);
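For reference, the surrounding code implements the usual Adam step; below is a scalar sketch of the full recurrence. The tensor version presumably folds the bias corrections (via adamBeta1T/adamBeta2T) into the learning rate in the lines elided from this hunk, which is an assumption:
#include <cmath>
/* one Adam step for a single scalar parameter (sketch) */
void AdamScalarStep(float& theta, float grad, float& m, float& v,
                    float lr, float beta1, float beta2, float delta) {
    m = beta1 * m + (1.0F - beta1) * grad;          /* first moment  */
    v = beta2 * v + (1.0F - beta2) * grad * grad;   /* second moment */
    theta -= lr * m / (std::sqrt(v) + delta);       /* v2 = m / (sqrt(v) + delta) */
}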
......
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -70,6 +70,9 @@ public:
/* training step number */
int nstep;
/* number of steps between log outputs */
int logInterval;
/* the maximum number of saved checkpoints */
int maxCheckpoint;
......