Commit c260abfb by xiaotong

bug fix

parent 2161f65b
...@@ -908,10 +908,10 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -908,10 +908,10 @@ int T2TTrainer::LoadBatchMT(FILE * file,
maxDec = wnDec; maxDec = wnDec;
} }
nextSeq = seq + sc; nextSeq = seq + sc;*/
if(sc <= 0) if(bufBatchSize <= 0)
return 0;*/ return 0;
BatchNode & batch = bufBatch[nextBatch++]; BatchNode & batch = bufBatch[nextBatch++];
int seq = batch.beg; int seq = batch.beg;
...@@ -944,8 +944,6 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -944,8 +944,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum]; int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum]; int * batchDecValues = new int[batchDec->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
//MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2]; MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum); memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
...@@ -958,13 +956,11 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -958,13 +956,11 @@ int T2TTrainer::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w]; int num = buf[seqOffset[s] + w];
batchEncValues[batchEnc->GetOffset2D(sent, w)] = num; batchEncValues[batchEnc->GetOffset2D(sent, w)] = num;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc++; wCountEnc++;
} }
} }
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem); XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp); _ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0); _NotEqual(tmp, paddingEnc, 0);
...@@ -978,7 +974,6 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -978,7 +974,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w]; int num = buf[seqOffset[s] + w];
batchDecValues[batchDec->GetOffset2D(sent, w)] = num; batchDecValues[batchDec->GetOffset2D(sent, w)] = num;
//paddingDecOffsets[wCountDec] = paddingDec->GetOffset2D(sent, w);
if (w > 0) if (w > 0)
goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]); goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
...@@ -1002,7 +997,6 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -1002,7 +997,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
} }
batchDec->SetData(batchDecValues, batchDec->unitNum); batchDec->SetData(batchDecValues, batchDec->unitNum);
//paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountDec);
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem); XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
_ConvertDataType(batchDec, tmp2); _ConvertDataType(batchDec, tmp2);
...@@ -1013,8 +1007,6 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -1013,8 +1007,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
delete[] batchEncValues; delete[] batchEncValues;
delete[] batchDecValues; delete[] batchDecValues;
//delete[] paddingEncOffsets;
//delete[] paddingDecOffsets;
delete[] goldOffsets; delete[] goldOffsets;
return sc; return sc;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论