Commit d0c724f6 by xiaotong

bug fixes and a new method of the class

parent 2f91e04e
...@@ -253,6 +253,15 @@ void T2TBatchLoader::ClearBuf() ...@@ -253,6 +253,15 @@ void T2TBatchLoader::ClearBuf()
} }
/* /*
set the random batch flag
>> flag - the value to store in isRandomBatch
*/
void T2TBatchLoader::SetRandomBatch(bool flag)
{
isRandomBatch = flag;
}
/*
load a batch of sequences load a batch of sequences
>> file - the handle to the data file >> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms >> isLM - indicates whether the data is used for training lms
...@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum]; int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum]; int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum]; int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2]; MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2]; MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2]; //MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
...@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w]; int num = buf[seqOffset[s] + w];
batchEncValues[batchEnc->GetOffset2D(sent, w)] = num; batchEncValues[batchEnc->GetOffset2D(sent, w)] = num;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w); paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc++; wCountEnc++;
} }
} }
ws = wCountEnc; ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc); paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem); //XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp); //_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0); //tmp->Dump(stderr, "tmp:");
DelTensorBuf(tmp); //_NotEqual(tmp, paddingEnc, 0);
//DelTensorBuf(tmp);
/* batch of the target-side sequences */ /* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){ for(int s = seq + 1; s < seq + sc; s += 2){
...@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
delete[] batchEncValues; delete[] batchEncValues;
delete[] batchDecValues; delete[] batchDecValues;
delete[] labelValues; delete[] labelValues;
//delete[] paddingEncOffsets; delete[] paddingEncOffsets;
delete[] paddingDecOffsets; delete[] paddingDecOffsets;
//delete[] goldOffsets; //delete[] goldOffsets;
......
...@@ -120,6 +120,9 @@ public: ...@@ -120,6 +120,9 @@ public:
/* clear data buffer */ /* clear data buffer */
void ClearBuf(); void ClearBuf();
/* set the random batch flag; the argument defaults to true when omitted */
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */ /* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM, int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论