Commit 2161f65b by xiaotong

randomize sample batches rarther than loading them is sorted manner

parent 52a50ab2
......@@ -41,12 +41,15 @@ T2TTrainer::T2TTrainer()
seqLen2 = NULL;
nseqBuf = 0;
nextSeq = -1;
nextBatch = -1;
argNum = 0;
argArray = NULL;
buf = NULL;
buf2 = NULL;
bufBatch = NULL;
bufSize = 0;
bufBatchSize = 0;
seqOffset = NULL;
}
......@@ -55,6 +58,7 @@ T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] buf2;
delete[] bufBatch;
delete[] seqLen;
delete[] seqLen2;
delete[] seqOffset;
......@@ -117,9 +121,11 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
buf = new int[bufSize];
buf2 = new int[bufSize];
bufBatch = new BatchNode[bufSize];
seqLen = new int[bufSize];
seqLen2 = new int[bufSize];
seqOffset = new int[bufSize];
......@@ -768,6 +774,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
return sc;
}
int CompareBatchNode(const void * a, const void * b)
{
return ((BatchNode*)b)->key - ((BatchNode*)a)->key;
}
/*
load a batch of sequences (for MT)
>> file - the handle to the data file
......@@ -797,10 +809,70 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int devID, XMem * mem,
bool isTraining)
{
if(nextSeq < 0 || nextSeq >= nseqBuf)
//if (nextSeq < 0 || nextSeq >= nseqBuf)
// LoadBuf(file, isSorted, 2);
if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2);
int seq = MAX(nextSeq, 0);
int seq = 0;
bufBatchSize = 0;
nextBatch = 0;
/* we segment the buffer into batches */
while (seq < nseqBuf) {
int wcEnc = 0;
int wcDec = 0;
int wnEnc = 0;
int wnDec = 0;
int maxEnc = 0;
int maxDec = 0;
int sc = 0;
while (seq + sc < nseqBuf) {
/* source-side sequence */
wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc) : MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec) : MAX(maxDec, wnDec) * (sc + 2) / 2;
if (sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if (maxEnc < wnEnc)
maxEnc = wnEnc;
wcDec += wnDec;
sc += 1;
if (maxDec < wnDec)
maxDec = wnDec;
}
BatchNode & batch = bufBatch[bufBatchSize];
batch.beg = seq;
batch.end = seq + sc;
batch.maxEnc = maxEnc;
batch.maxDec = maxDec;
batch.key = rand();
bufBatchSize++;
seq = seq + sc;
}
if(isRandomBatch)
qsort(bufBatch, bufBatchSize, sizeof(BatchNode), CompareBatchNode);
}
/*int seq = MAX(nextSeq, 0);
int wcEnc = 0;
int wcDec = 0;
int wnEnc = 0;
......@@ -813,10 +885,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
while(seq + sc < nseqBuf){
/* source-side sequence */
wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
......@@ -841,7 +911,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
nextSeq = seq + sc;
if(sc <= 0)
return 0;
return 0;*/
BatchNode & batch = bufBatch[nextBatch++];
int seq = batch.beg;
int sc = batch.end - batch.beg;
int maxEnc = batch.maxEnc;
int maxDec = batch.maxDec;
CheckNTErrors(sc % 2 == 0, "The input samples must be paired");
int sCount = sc/2;
int seqSize = 0;
......
......@@ -33,6 +33,25 @@ using namespace nts;
namespace transformer
{
/* node to keep batch information */
struct BatchNode
{
/* begining position */
int beg;
/* end position */
int end;
/* maximum word number on the encoder side */
int maxEnc;
/* maximum word number on the decoder side */
int maxDec;
/* a key for sorting */
int key;
};
/* trainer of the T2T model */
class T2TTrainer
{
......@@ -49,9 +68,15 @@ public:
/* another buffer */
int * buf2;
/* batch buf */
BatchNode * bufBatch;
/* buffer size */
int bufSize;
/* size of batch buffer */
int bufBatchSize;
/* length of each sequence */
int * seqLen;
......@@ -66,6 +91,9 @@ public:
/* offset for next sequence in the buffer */
int nextSeq;
/* offset for next batch */
int nextBatch;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
......@@ -142,6 +170,9 @@ public:
/* counterpart of "isSmallBatch" */
bool isBigBatch;
/* randomize batches */
bool isRandomBatch;
/* indicates whether we intend to debug the net */
bool isDebugged;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论