Commit d0c724f6 by xiaotong

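Bug fixes in T2TBatchLoader::LoadBatchMT (the encoder padding is again filled from explicit padding offsets) and a new method, T2TBatchLoader::SetRandomBatch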

parent 2f91e04e
...@@ -251,6 +251,15 @@ void T2TBatchLoader::ClearBuf() ...@@ -251,6 +251,15 @@ void T2TBatchLoader::ClearBuf()
nseqBuf = 0; nseqBuf = 0;
nextSeq = -1; nextSeq = -1;
} }
/*
set the random batch flag
>> flag - as it is
*/
void T2TBatchLoader::SetRandomBatch(bool flag)
{
isRandomBatch = flag;
}
/* /*
load a batch of sequences load a batch of sequences
...@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum]; int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum]; int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum]; int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2]; MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2]; MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2]; //MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
...@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w]; int num = buf[seqOffset[s] + w];
batchEncValues[batchEnc->GetOffset2D(sent, w)] = num; batchEncValues[batchEnc->GetOffset2D(sent, w)] = num;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w); paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc++; wCountEnc++;
} }
} }
ws = wCountEnc; ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc); paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem); //XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp); //_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0); //tmp->Dump(stderr, "tmp:");
DelTensorBuf(tmp); //_NotEqual(tmp, paddingEnc, 0);
//DelTensorBuf(tmp);
/* batch of the target-side sequences */ /* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){ for(int s = seq + 1; s < seq + sc; s += 2){
...@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
delete[] batchEncValues; delete[] batchEncValues;
delete[] batchDecValues; delete[] batchDecValues;
delete[] labelValues; delete[] labelValues;
//delete[] paddingEncOffsets; delete[] paddingEncOffsets;
delete[] paddingDecOffsets; delete[] paddingDecOffsets;
//delete[] goldOffsets; //delete[] goldOffsets;
...@@ -685,4 +695,4 @@ void T2TBatchLoader::Shuffle(const char * srcFile, const char * tgtFile) ...@@ -685,4 +695,4 @@ void T2TBatchLoader::Shuffle(const char * srcFile, const char * tgtFile)
} }
} }
\ No newline at end of file
...@@ -119,6 +119,9 @@ public: ...@@ -119,6 +119,9 @@ public:
/* clear data buffer */ /* clear data buffer */
void ClearBuf(); void ClearBuf();
/* set the random batch flag (calling it with no argument passes flag = true) */
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */ /* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM, int LoadBatch(FILE * file, bool isLM,
...@@ -156,4 +159,4 @@ public: ...@@ -156,4 +159,4 @@ public:
}; };
} }
#endif #endif
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论