Commit 4a609624 by xiaotong

local machine version

parent c6f50a22
...@@ -114,6 +114,7 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -114,6 +114,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false); LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
LoadParamInt(argc, argv, "updatestep", &updateStep, 1); LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false); LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false);
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, false);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize]; buf2 = new int[bufSize];
...@@ -648,7 +649,8 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -648,7 +649,8 @@ int T2TTrainer::LoadBatchLM(FILE * file,
if(max < wn) if(max < wn)
max = wn; max = wn;
if(sc >= sBatch && wc >= wBatch) int tc = isSmallBatch ? max * sc : wc;
if(sc >= sBatch && tc >= wBatch)
break; break;
} }
...@@ -773,7 +775,8 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -773,7 +775,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
if(maxDec < wnDec) if(maxDec < wnDec)
maxDec = wnDec; maxDec = wnDec;
if(sc >= sBatch * 2 && wcEnc >= wBatch) int tc = isSmallBatch ? maxEnc * sc / 2 : wcEnc;
if(sc >= sBatch * 2 && tc >= wBatch)
break; break;
} }
......
...@@ -134,6 +134,11 @@ public: ...@@ -134,6 +134,11 @@ public:
/* indicates whether we double the </s> symble for the output of lms */ /* indicates whether we double the </s> symble for the output of lms */
bool isDoubledEnd; bool isDoubledEnd;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool isSmallBatch;
public: public:
/* constructor */ /* constructor */
T2TTrainer(); T2TTrainer();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论