Commit 98a9130d by hello

Refactor class `TrainDataSet`

parent 4bbd6a27
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h" #include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h" #include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std; using namespace std;
...@@ -39,39 +38,54 @@ namespace nts { ...@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */ /* a class of sentence pairs for training */
struct TrainExample { struct TrainExample {
public:
/* id of the sentence pair */ /* id of the sentence pair */
int id; int id;
/* source language setence (tokenized) */ /* source language setence (tokenized) */
IntList srcSent; IntList * srcSent;
/* target language setence (tokenized) */ /* target language setence (tokenized) */
IntList tgtSent; IntList * tgtSent;
/* the key used to shuffle items in a bucket */
int key;
/* the key used to shuffle buckets */ /* the key used to shuffle buckets */
int bucketKey; int bucketKey;
public:
/* constructor */
TrainExample(int myID, int myKey, IntList* s, IntList* t);
/* de-constructor */
~TrainExample();
};
struct ReservedIDs {
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
}; };
/* A `TrainDataSet` is associated with a file which contains training data. */ /* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet { struct TrainDataSet {
public:
/* the data buffer */
TrainBufferType buffer;
/* a list of empty line number */ public:
IntList emptyLines;
/* the pointer to file stream */ /* the pointer to file stream */
FILE* fp; FILE* fp;
/* current index in the buffer */ /* number of training samples */
size_t curIdx; size_t totalSampleNum;
/* size of used data in the buffer */ /* buffer size */
size_t bufferUsed; size_t bufferSize;
/* size of the bucket used for grouping sentences */ /* size of the bucket used for grouping sentences */
size_t bucketSize; size_t bucketSize;
...@@ -79,34 +93,51 @@ public: ...@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */ /* indicates whether it is used for training */
bool isTraining; bool isTraining;
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
/* the maximum length for a source sentence */
int maxSrcLen;
/* the maximum length for a target sentence */
int maxTgtLen;
public: public:
/* sort the input by length (in descending order) */ /* get the maximum source sentence length in a range */
void SortByLength(); static
int MaxSrcLen(XList* buf, int begin, int end);
/* sort buckets by key (in descending order) */ /* get the maximum target sentence length in a range */
void SortBucket(); static
int MaxTgtLen(XList* buf, int begin, int end);
/* sort the output by key (in descending order) */ /* sort the input by source sentence length (in descending order) */
void SortInBucket(int begin, int end); void SortBySrcLength(XList* buf);
/* load data from a file to the buffer */ /* sort the input by target sentence length (in descending order) */
void LoadDataToBuffer(); void SortByTgtLength(XList* buf);
/* generate a mini-batch */ /* sort buckets by key (in descending order) */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, void SortBuckets(XList* buf);
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID);
/* load the samples into the buffer (a list) */ /* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf); bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */ /* load the samples into tensors from the buffer */
static static
bool LoadBatch(XList * buf, bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc, XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
size_t minSentBatch, size_t batchSize, int devID,
int &wc, int &sc); int &wc, int &sc);
/* release the samples in a buffer */ /* release the samples in a buffer */
...@@ -116,14 +147,8 @@ public: ...@@ -116,14 +147,8 @@ public:
/* initialization function */ /* initialization function */
void Init(const char* dataFile, int bucketSize, bool training); void Init(const char* dataFile, int bucketSize, bool training);
/* check if the buffer is empty */
bool IsEmpty();
/* reset the buffer */
void ClearBuf();
/* group data into buckets with similar length */ /* group data into buckets with similar length */
void BuildBucket(); void BuildBucket(XList * buf);
/* de-constructor */ /* de-constructor */
~TrainDataSet(); ~TrainDataSet();
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -39,6 +39,41 @@ namespace nmt ...@@ -39,6 +39,41 @@ namespace nmt
Trainer::Trainer() Trainer::Trainer()
{ {
cfg = NULL; cfg = NULL;
lrate = 0.0F;
lrbias = 0.0F;
sBatchSize = 0;
wBatchSize = 0;
bucketSize = 0;
nstep = 0;
nepoch = 0;
logInterval = 0;
maxCheckpoint = 0;
d = 0;
nwarmup = 0;
vSize = 0;
vSizeTgt = 0;
useAdam = false;
adamBeta1 = 0.0F;
adamBeta2 = 0.0F;
adamDelta = 0.0F;
isShuffled = false;
labelSmoothingP = 0.0F;
nStepCheckpoint = 0;
useEpochCheckpoint = false;
updateStep = 0;
isLenSorted = 0;
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = 0;
batchLoader.endID = 0;
batchLoader.unkID = 0;
batchLoader.padID = 0;
batchLoader.maxSrcLen = 0;
batchLoader.maxTgtLen = 0;
batchLoader.bufferSize = 0;
} }
/* de-constructor */ /* de-constructor */
...@@ -62,13 +97,15 @@ initialization ...@@ -62,13 +97,15 @@ initialization
void Trainer::Init(Config& config) void Trainer::Init(Config& config)
{ {
cfg = &config; cfg = &config;
lrate = config.lrate; lrate = config.lrate;
lrbias = config.lrbias; lrbias = config.lrbias;
sBatchSize = config.sBatchSize; sBatchSize = config.sBatchSize;
wBatchSize = config.wBatchSize; wBatchSize = config.wBatchSize;
bucketSize = config.bucketSize; bucketSize = config.bucketSize;
nepoch = config.nepoch;
nstep = config.nstep; nstep = config.nstep;
nepoch = config.nepoch;
logInterval = config.logInterval;
maxCheckpoint = config.maxCheckpoint; maxCheckpoint = config.maxCheckpoint;
d = config.modelSize; d = config.modelSize;
nwarmup = config.nwarmup; nwarmup = config.nwarmup;
...@@ -87,6 +124,14 @@ void Trainer::Init(Config& config) ...@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T = 1.0F; adamBeta1T = 1.0F;
adamBeta2T = 1.0F; adamBeta2T = 1.0F;
batchLoader.startID = config.startID;
batchLoader.endID = config.endID;
batchLoader.unkID = config.unkID;
batchLoader.padID = config.padID;
batchLoader.maxSrcLen = config.maxSrcLen;
batchLoader.maxTgtLen = config.maxTgtLen;
batchLoader.bufferSize = config.bufSize;
} }
/* /*
...@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
} }
int step = 0; int step = 0;
int wc = 0; int wc = 0;
int ws = 0; int sc = 0;
int wordCount = 0; int wordCount = 0;
int wordCountTotal = 0; int wordCountTotal = 0;
int batchCountTotal = 0; int batchCountTotal = 0;
...@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec(); double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true); batchLoader.Init(fn, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) { for (epoch = 1; epoch <= nepoch; epoch++) {
...@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount = 0; wordCount = 0;
loss = 0; loss = 0;
/* reset the batch loader */ while (step++ < nstep)
batchLoader.ClearBuf();
while (!batchLoader.IsEmpty())
{ {
XNet net; XNet net;
net.Clear(); net.Clear();
...@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc; XTensor paddingEnc;
XTensor paddingDec; XTensor paddingDec;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, if (curIdx == 0 || curIdx == buf->Size()) {
sBatchSize, wBatchSize, devID); curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
/* make the network */ /* make the network */
if (model->isLM) if (model->isLM) {
model->MakeLM(batchEnc, output, paddingEnc, true); model->MakeLM(batchEnc, output, paddingEnc, true);
else if (model->isMT) }
else if (model->isMT) {
model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true); model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
}
else { else {
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
...@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE lossLocal = lossBatch / wc; DTYPE lossLocal = lossBatch / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F); bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
net.isGradEfficient = true;
bool debug(false);
if (debug) {
LOG("after forward:");
batchEnc.mem->ShowMemUsage(stderr);
exit(0);
}
if (doUpdate) { if (doUpdate) {
/* back-propagation */
net.Backward(lossTensor); net.Backward(lossTensor);
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
gradStep += 1; gradStep += 1;
loss += lossBatch; loss += lossBatch;
wordCount += wc; wordCount += wc;
wordCountTotal += wc; wordCountTotal += wc;
batchCountTotal += ws; batchCountTotal += sc;
/* update the parameters */ /* update the parameters */
if (gradStep == updateStep) { if (gradStep == updateStep) {
...@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else else
nSkipped++; nSkipped++;
if (++step >= nstep) { if (step % logInterval == 0) {
isEnd = true;
break;
}
if (step == 10) {
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
LOG("elapsed=%.1fs, step=%d, epoch=%d, " LOG("elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e", "total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e",
...@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
} }
} }
if (isEnd)
break;
if (useEpochCheckpoint) if (useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch); MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch); epoch = MIN(epoch, nepoch);
...@@ -287,8 +340,12 @@ test the model ...@@ -287,8 +340,12 @@ test the model
*/ */
void Trainer::Validate(const char* fn, const char* ofn, Model* model) void Trainer::Validate(const char* fn, const char* ofn, Model* model)
{ {
double startT = GetClockSec();
DISABLE_GRAD;
int wc = 0; int wc = 0;
int ws = 0; int sc = 0;
int wordCount = 0; int wordCount = 0;
int sentCount = 0; int sentCount = 0;
float loss = 0; float loss = 0;
...@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */ /* data files */
batchLoader.Init(fn, 0, false); batchLoader.Init(fn, 0, false);
double startT = GetClockSec(); int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of valiadation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
while (!batchLoader.IsEmpty()) while (curIdx < buf->count)
{ {
/* batch of input sequences */ /* batch of input sequences */
XTensor batchEnc; XTensor batchEnc;
...@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot; XTensor labelOnehot;
XTensor lossTensor; XTensor lossTensor;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label, batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID); sBatchSize, 0, model->devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
/* make the network */ /* make the network */
...@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model) ...@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int length = output.GetDim(1); int length = output.GetDim(1);
labelOnehot = IndexToOnehot(label, vSizeTgt, 0); labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
lossTensor = CrossEntropy(output, labelOnehot, paddingDec); lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
float lossBatch = ReduceSumAllValue(lossTensor); float lossBatch = ReduceSumAllValue(lossTensor);
loss += lossBatch; loss += lossBatch;
wordCount += wc; wordCount += wc;
sentCount += bSize; sentCount += bSize;
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
} }
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
LOG("test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)", ENABLE_GRAD;
LOG("validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount)); elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount));
} }
...@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr) ...@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0); _ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* v2 = m / (sqrt(v) + delta) */ /* v2 = m / (sqrt(v) + delta) */
XTensor* v2 = NewTensorBuf(v, v->devID); XTensor* v2 = NewTensorBufV2(v, v->devID, v->mem);
_Power(v, v2, 0.5F); _Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d); _ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2); _Div(m, v2, v2);
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -70,6 +70,9 @@ public: ...@@ -70,6 +70,9 @@ public:
/* traing step number */ /* traing step number */
int nstep; int nstep;
/* interval step for logging */
int logInterval;
/* the maximum number of saved checkpoints */ /* the maximum number of saved checkpoints */
int maxCheckpoint; int maxCheckpoint;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论