Commit 98a9130d by hello

Refactor class `TrainDataSet`

parent 4bbd6a27
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -30,7 +30,6 @@
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std;
......@@ -39,39 +38,54 @@ namespace nts {
/* a class of sentence pairs for training */
struct TrainExample {
public:
/* id of the sentence pair */
int id;
/* source language setence (tokenized) */
IntList srcSent;
IntList * srcSent;
/* target language setence (tokenized) */
IntList tgtSent;
/* the key used to shuffle items in a bucket */
int key;
IntList * tgtSent;
/* the key used to shuffle buckets */
int bucketKey;
public:
/* constructor */
TrainExample(int myID, int myKey, IntList* s, IntList* t);
/* de-constructor */
~TrainExample();
};
struct ReservedIDs {
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
};
/* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet {
public:
/* the data buffer */
TrainBufferType buffer;
/* a list of empty line number */
IntList emptyLines;
public:
/* the pointer to file stream */
FILE* fp;
/* current index in the buffer */
size_t curIdx;
/* number of training samples */
size_t totalSampleNum;
/* size of used data in the buffer */
size_t bufferUsed;
/* buffer size */
size_t bufferSize;
/* size of the bucket used for grouping sentences */
size_t bucketSize;
......@@ -79,34 +93,51 @@ public:
/* indicates whether it is used for training */
bool isTraining;
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
/* the maximum length for a source sentence */
int maxSrcLen;
/* the maximum length for a target sentence */
int maxTgtLen;
public:
/* sort the input by length (in descending order) */
void SortByLength();
/* get the maximum source sentence length in a range */
static
int MaxSrcLen(XList* buf, int begin, int end);
/* sort buckets by key (in descending order) */
void SortBucket();
/* get the maximum target sentence length in a range */
static
int MaxTgtLen(XList* buf, int begin, int end);
/* sort the output by key (in descending order) */
void SortInBucket(int begin, int end);
/* sort the input by source sentence length (in descending order) */
void SortBySrcLength(XList* buf);
/* load data from a file to the buffer */
void LoadDataToBuffer();
/* sort the input by target sentence length (in descending order) */
void SortByTgtLength(XList* buf);
/* generate a mini-batch */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID);
/* sort buckets by key (in descending order) */
void SortBuckets(XList* buf);
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf,
bool LoadBatch(XList * buf, int & curIdx,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID,
XTensor* batchDec, XTensor* paddingDec, XTensor* label, int minSentBatch, int batchSize, int devID,
int &wc, int &sc);
/* release the samples in a buffer */
......@@ -116,14 +147,8 @@ public:
/* initialization function */
void Init(const char* dataFile, int bucketSize, bool training);
/* check if the buffer is empty */
bool IsEmpty();
/* reset the buffer */
void ClearBuf();
/* group data into buckets with similar length */
void BuildBucket();
void BuildBucket(XList * buf);
/* de-constructor */
~TrainDataSet();
......
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -39,6 +39,41 @@ namespace nmt
Trainer::Trainer()
{
cfg = NULL;
lrate = 0.0F;
lrbias = 0.0F;
sBatchSize = 0;
wBatchSize = 0;
bucketSize = 0;
nstep = 0;
nepoch = 0;
logInterval = 0;
maxCheckpoint = 0;
d = 0;
nwarmup = 0;
vSize = 0;
vSizeTgt = 0;
useAdam = false;
adamBeta1 = 0.0F;
adamBeta2 = 0.0F;
adamDelta = 0.0F;
isShuffled = false;
labelSmoothingP = 0.0F;
nStepCheckpoint = 0;
useEpochCheckpoint = false;
updateStep = 0;
isLenSorted = 0;
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = 0;
batchLoader.endID = 0;
batchLoader.unkID = 0;
batchLoader.padID = 0;
batchLoader.maxSrcLen = 0;
batchLoader.maxTgtLen = 0;
batchLoader.bufferSize = 0;
}
/* de-constructor */
......@@ -62,13 +97,15 @@ initialization
void Trainer::Init(Config& config)
{
cfg = &config;
lrate = config.lrate;
lrbias = config.lrbias;
sBatchSize = config.sBatchSize;
wBatchSize = config.wBatchSize;
bucketSize = config.bucketSize;
nepoch = config.nepoch;
nstep = config.nstep;
nepoch = config.nepoch;
logInterval = config.logInterval;
maxCheckpoint = config.maxCheckpoint;
d = config.modelSize;
nwarmup = config.nwarmup;
......@@ -87,6 +124,14 @@ void Trainer::Init(Config& config)
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
batchLoader.startID = config.startID;
batchLoader.endID = config.endID;
batchLoader.unkID = config.unkID;
batchLoader.padID = config.padID;
batchLoader.maxSrcLen = config.maxSrcLen;
batchLoader.maxTgtLen = config.maxTgtLen;
batchLoader.bufferSize = config.bufSize;
}
/*
......@@ -106,7 +151,7 @@ void Trainer::Train(const char* fn, const char* validFN,
}
int step = 0;
int wc = 0;
int ws = 0;
int sc = 0;
int wordCount = 0;
int wordCountTotal = 0;
int batchCountTotal = 0;
......@@ -134,6 +179,9 @@ void Trainer::Train(const char* fn, const char* validFN,
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
batchLoader.Init(fn, bucketSize, true);
for (epoch = 1; epoch <= nepoch; epoch++) {
......@@ -141,10 +189,7 @@ void Trainer::Train(const char* fn, const char* validFN,
wordCount = 0;
loss = 0;
/* reset the batch loader */
batchLoader.ClearBuf();
while (!batchLoader.IsEmpty())
while (step++ < nstep)
{
XNet net;
net.Clear();
......@@ -160,21 +205,26 @@ void Trainer::Train(const char* fn, const char* validFN,
XTensor paddingEnc;
XTensor paddingDec;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID);
if (curIdx == 0 || curIdx == buf->Size()) {
curIdx = 0;
batchLoader.LoadBatchToBuf(buf);
}
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, wBatchSize, devID, wc, sc);
wc = (int)info[0];
ws = (int)info[1];
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
if (model->isLM)
if (model->isLM) {
model->MakeLM(batchEnc, output, paddingEnc, true);
else if (model->isMT)
}
else if (model->isMT) {
model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
}
else {
ShowNTErrors("Illegal model type!");
}
......@@ -192,15 +242,29 @@ void Trainer::Train(const char* fn, const char* validFN,
DTYPE lossLocal = lossBatch / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
net.isGradEfficient = true;
bool debug(false);
if (debug) {
LOG("after forward:");
batchEnc.mem->ShowMemUsage(stderr);
exit(0);
}
if (doUpdate) {
/* back-propagation */
net.Backward(lossTensor);
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
gradStep += 1;
loss += lossBatch;
wordCount += wc;
wordCountTotal += wc;
batchCountTotal += ws;
batchCountTotal += sc;
/* update the parameters */
if (gradStep == updateStep) {
......@@ -227,18 +291,7 @@ void Trainer::Train(const char* fn, const char* validFN,
else
nSkipped++;
if (++step >= nstep) {
isEnd = true;
break;
}
if (step == 10) {
// LOG("after backward --------");
// lossTensor.mem->ShowMemUsage(stderr);
// exit(0);
}
if (step % 100 == 0) {
if (step % logInterval == 0) {
double elapsed = GetClockSec() - startT;
LOG("elapsed=%.1fs, step=%d, epoch=%d, "
"total word=%d, total batch=%d, loss=%.3f, ppl=%.3f, lr=%.2e",
......@@ -256,13 +309,13 @@ void Trainer::Train(const char* fn, const char* validFN,
}
}
if (isEnd)
break;
if (useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
......@@ -287,8 +340,12 @@ test the model
*/
void Trainer::Validate(const char* fn, const char* ofn, Model* model)
{
double startT = GetClockSec();
DISABLE_GRAD;
int wc = 0;
int ws = 0;
int sc = 0;
int wordCount = 0;
int sentCount = 0;
float loss = 0;
......@@ -296,9 +353,14 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
/* data files */
batchLoader.Init(fn, 0, false);
double startT = GetClockSec();
int curIdx = 0;
XList* buf = new XList;
/* set the buffer size to the size of valiadation set */
batchLoader.bufferSize = batchLoader.totalSampleNum;
batchLoader.LoadBatchToBuf(buf);
while (!batchLoader.IsEmpty())
while (curIdx < buf->count)
{
/* batch of input sequences */
XTensor batchEnc;
......@@ -318,10 +380,9 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
XTensor labelOnehot;
XTensor lossTensor;
UInt64List info = batchLoader.LoadBatch(&batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID);
wc = (int)info[0];
ws = (int)info[1];
batchLoader.LoadBatch(buf, curIdx, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &label,
sBatchSize, 0, model->devID, wc, sc);
CheckNTErrors(batchEnc.order == 2, "Wrong tensor order of the sequence batch");
/* make the network */
......@@ -337,18 +398,31 @@ void Trainer::Validate(const char* fn, const char* ofn, Model* model)
int length = output.GetDim(1);
labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
float lossBatch = ReduceSumAllValue(lossTensor);
loss += lossBatch;
wordCount += wc;
sentCount += bSize;
if (model->encoder->useHistory)
model->encoder->history->ClearHistory(/*reset=*/false);
if (model->decoder->useHistory)
model->decoder->history->ClearHistory(/*reset=*/false);
}
batchLoader.ClearSamples(buf);
delete buf;
double elapsed = GetClockSec() - startT;
LOG("test finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
ENABLE_GRAD;
LOG("validating finished (took %.1fs, sentence=%d, word=%d, loss=%.3f and ppl=%.3f)",
elapsed, sentCount, wordCount, loss / wordCount / log(2.0), exp(loss / wordCount));
}
......@@ -428,7 +502,7 @@ void Trainer::Update(Model* model, const float lr)
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* v2 = m / (sqrt(v) + delta) */
XTensor* v2 = NewTensorBuf(v, v->devID);
XTensor* v2 = NewTensorBufV2(v, v->devID, v->mem);
_Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2);
......
/* NiuTrans.NMT - an open-source neural machine translation system.
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -70,6 +70,9 @@ public:
/* traing step number */
int nstep;
/* interval step for logging */
int logInterval;
/* the maximum number of saved checkpoints */
int maxCheckpoint;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论