Commit 2161f65b by xiaotong

randomize sample batches rarther than loading them is sorted manner

parent 52a50ab2
...@@ -41,12 +41,15 @@ T2TTrainer::T2TTrainer() ...@@ -41,12 +41,15 @@ T2TTrainer::T2TTrainer()
seqLen2 = NULL; seqLen2 = NULL;
nseqBuf = 0; nseqBuf = 0;
nextSeq = -1; nextSeq = -1;
nextBatch = -1;
argNum = 0; argNum = 0;
argArray = NULL; argArray = NULL;
buf = NULL; buf = NULL;
buf2 = NULL; buf2 = NULL;
bufBatch = NULL;
bufSize = 0; bufSize = 0;
bufBatchSize = 0;
seqOffset = NULL; seqOffset = NULL;
} }
...@@ -55,6 +58,7 @@ T2TTrainer::~T2TTrainer() ...@@ -55,6 +58,7 @@ T2TTrainer::~T2TTrainer()
{ {
delete[] buf; delete[] buf;
delete[] buf2; delete[] buf2;
delete[] bufBatch;
delete[] seqLen; delete[] seqLen;
delete[] seqLen2; delete[] seqLen2;
delete[] seqOffset; delete[] seqOffset;
...@@ -117,9 +121,11 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -117,9 +121,11 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true); LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false); LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false); LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize]; buf2 = new int[bufSize];
bufBatch = new BatchNode[bufSize];
seqLen = new int[bufSize]; seqLen = new int[bufSize];
seqLen2 = new int[bufSize]; seqLen2 = new int[bufSize];
seqOffset = new int[bufSize]; seqOffset = new int[bufSize];
...@@ -768,6 +774,12 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -768,6 +774,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
return sc; return sc;
} }
int CompareBatchNode(const void * a, const void * b)
{
return ((BatchNode*)b)->key - ((BatchNode*)a)->key;
}
/* /*
load a batch of sequences (for MT) load a batch of sequences (for MT)
>> file - the handle to the data file >> file - the handle to the data file
...@@ -797,10 +809,70 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -797,10 +809,70 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining) bool isTraining)
{ {
if(nextSeq < 0 || nextSeq >= nseqBuf) //if (nextSeq < 0 || nextSeq >= nseqBuf)
// LoadBuf(file, isSorted, 2);
if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2); LoadBuf(file, isSorted, 2);
int seq = MAX(nextSeq, 0); int seq = 0;
bufBatchSize = 0;
nextBatch = 0;
/* we segment the buffer into batches */
while (seq < nseqBuf) {
int wcEnc = 0;
int wcDec = 0;
int wnEnc = 0;
int wnDec = 0;
int maxEnc = 0;
int maxDec = 0;
int sc = 0;
while (seq + sc < nseqBuf) {
/* source-side sequence */
wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc) : MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec) : MAX(maxDec, wnDec) * (sc + 2) / 2;
if (sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if (maxEnc < wnEnc)
maxEnc = wnEnc;
wcDec += wnDec;
sc += 1;
if (maxDec < wnDec)
maxDec = wnDec;
}
BatchNode & batch = bufBatch[bufBatchSize];
batch.beg = seq;
batch.end = seq + sc;
batch.maxEnc = maxEnc;
batch.maxDec = maxDec;
batch.key = rand();
bufBatchSize++;
seq = seq + sc;
}
if(isRandomBatch)
qsort(bufBatch, bufBatchSize, sizeof(BatchNode), CompareBatchNode);
}
/*int seq = MAX(nextSeq, 0);
int wcEnc = 0; int wcEnc = 0;
int wcDec = 0; int wcDec = 0;
int wnEnc = 0; int wnEnc = 0;
...@@ -813,10 +885,8 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -813,10 +885,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
while(seq + sc < nseqBuf){ while(seq + sc < nseqBuf){
/* source-side sequence */
wnEnc = seqLen[seq + sc]; wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1; wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2; int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
...@@ -841,7 +911,15 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -841,7 +911,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
nextSeq = seq + sc; nextSeq = seq + sc;
if(sc <= 0) if(sc <= 0)
return 0; return 0;*/
BatchNode & batch = bufBatch[nextBatch++];
int seq = batch.beg;
int sc = batch.end - batch.beg;
int maxEnc = batch.maxEnc;
int maxDec = batch.maxDec;
CheckNTErrors(sc % 2 == 0, "The input samples must be paired");
int sCount = sc/2; int sCount = sc/2;
int seqSize = 0; int seqSize = 0;
......
...@@ -33,6 +33,25 @@ using namespace nts; ...@@ -33,6 +33,25 @@ using namespace nts;
namespace transformer namespace transformer
{ {
/* node to keep batch information */
struct BatchNode
{
/* begining position */
int beg;
/* end position */
int end;
/* maximum word number on the encoder side */
int maxEnc;
/* maximum word number on the decoder side */
int maxDec;
/* a key for sorting */
int key;
};
/* trainer of the T2T model */ /* trainer of the T2T model */
class T2TTrainer class T2TTrainer
{ {
...@@ -49,9 +68,15 @@ public: ...@@ -49,9 +68,15 @@ public:
/* another buffer */ /* another buffer */
int * buf2; int * buf2;
/* batch buf */
BatchNode * bufBatch;
/* buffer size */ /* buffer size */
int bufSize; int bufSize;
/* size of batch buffer */
int bufBatchSize;
/* length of each sequence */ /* length of each sequence */
int * seqLen; int * seqLen;
...@@ -66,6 +91,9 @@ public: ...@@ -66,6 +91,9 @@ public:
/* offset for next sequence in the buffer */ /* offset for next sequence in the buffer */
int nextSeq; int nextSeq;
/* offset for next batch */
int nextBatch;
/* indicates whether the sequence is sorted by length */ /* indicates whether the sequence is sorted by length */
bool isLenSorted; bool isLenSorted;
...@@ -142,6 +170,9 @@ public: ...@@ -142,6 +170,9 @@ public:
/* counterpart of "isSmallBatch" */ /* counterpart of "isSmallBatch" */
bool isBigBatch; bool isBigBatch;
/* randomize batches */
bool isRandomBatch;
/* indicates whether we intend to debug the net */ /* indicates whether we intend to debug the net */
bool isDebugged; bool isDebugged;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论