Commit 411cff4c by xiaotong

use smaller batches for training

parent 383759ae
......@@ -114,7 +114,8 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false);
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, false);
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
buf = new int[bufSize];
buf2 = new int[bufSize];
......@@ -692,7 +693,7 @@ int T2TTrainer::LoadBatchLM(FILE * file,
if(max < wn)
max = wn;
int tc = isSmallBatch ? max * sc : wc;
int tc = isBigBatch ? wc : max * sc;
if(sc >= sBatch && tc >= wBatch)
break;
}
......@@ -836,8 +837,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
if(maxDec < wnDec)
maxDec = wnDec;
int tc = isSmallBatch ? maxEnc * sc / 2 : wcEnc;
if(sc >= sBatch * 2 && tc >= wBatch)
int tcEnc = isBigBatch ? wcEnc : maxEnc * sc / 2;
int tcDec = isBigBatch ? wcDec : maxDec * sc / 2;
if(sc >= sBatch * 2 && (tcEnc >= wBatch || tcDec >= wBatch))
break;
}
......
......@@ -143,6 +143,9 @@ public:
length and sc is the sentence number */
bool isSmallBatch;
/* counterpart of "isSmallBatch" */
bool isBigBatch;
public:
/* constructor */
T2TTrainer();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论