Commit 1739a22f by xiaotong

better definition of "batching"

parent 32e78bda
......@@ -815,25 +815,27 @@ int T2TTrainer::LoadBatchMT(FILE * file,
/* source-side sequence */
wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if(maxEnc < wnEnc)
maxEnc = wnEnc;
/* target-side sequence */
int len = isDoubledEnd ? seqLen[seq + sc] : seqLen[seq + sc] - 1;
wnDec = len;
wcDec += wnDec;
sc += 1;
if(maxDec < wnDec)
maxDec = wnDec;
int tcEnc = isBigBatch ? wcEnc : maxEnc * sc / 2;
int tcDec = isBigBatch ? wcDec : maxDec * sc / 2;
if(sc >= sBatch * 2 && (tcEnc >= wBatch || tcDec >= wBatch))
break;
}
nextSeq = seq + sc;
......
......@@ -24,7 +24,7 @@
#include "Sum.h"
#include "SumDim.h"
#include "SumDim.cuh"
#include "../Shape/Unsqueeze.h"
#include "../shape/Unsqueeze.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论