Commit a400b619 by xuchen

gather implemention in mt

parent 7809ed05
......@@ -838,9 +838,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int dimsEnc[3] = {sCount, maxEnc, vsEnc};
int dimsDec[3] = {sCount, maxDec, vsDec};
InitTensor(batchEnc, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
InitTensor(batchEnc, 2, dimsEnc, X_INT, 1.0F, -1);
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor(batchDec, 2, dimsDec, X_INT, 1.0F, -1);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
......@@ -857,7 +857,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s];
int sent = (s - seq)/2;
for(int w = 0; w < len; w++){
batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, sent, w);
wCount++;
}
......@@ -869,8 +870,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
CheckNTErrors(len <= maxDec, "Something is wrong!");
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingDec->Set2D(1.0F, sent, w);
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
if(w > 0)
gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论