Commit 1739a22f by xiaotong

better definition of "batching"

parent 32e78bda
...@@ -815,25 +815,27 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -815,25 +815,27 @@ int T2TTrainer::LoadBatchMT(FILE * file,
/* source-side sequence */ /* source-side sequence */
wnEnc = seqLen[seq + sc]; wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc; wcEnc += wnEnc;
sc += 1; sc += 1;
if(maxEnc < wnEnc) if(maxEnc < wnEnc)
maxEnc = wnEnc; maxEnc = wnEnc;
/* target-side sequence */
int len = isDoubledEnd ? seqLen[seq + sc] : seqLen[seq + sc] - 1;
wnDec = len;
wcDec += wnDec; wcDec += wnDec;
sc += 1; sc += 1;
if(maxDec < wnDec) if(maxDec < wnDec)
maxDec = wnDec; maxDec = wnDec;
int tcEnc = isBigBatch ? wcEnc : maxEnc * sc / 2;
int tcDec = isBigBatch ? wcDec : maxDec * sc / 2;
if(sc >= sBatch * 2 && (tcEnc >= wBatch || tcDec >= wBatch))
break;
} }
nextSeq = seq + sc; nextSeq = seq + sc;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "Sum.h" #include "Sum.h"
#include "SumDim.h" #include "SumDim.h"
#include "SumDim.cuh" #include "SumDim.cuh"
#include "../Shape/Unsqueeze.h" #include "../shape/Unsqueeze.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../movement/CopyValues.h" #include "../movement/CopyValues.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论