Commit d0c724f6 by xiaotong

bug fixes and a new method of the class

parent 2f91e04e
......@@ -253,6 +253,15 @@ void T2TBatchLoader::ClearBuf()
}
/*
set the random batch flag
>> flag - as it is
*/
void T2TBatchLoader::SetRandomBatch(bool flag)
{
isRandomBatch = flag;
}
/*
load a batch of sequences
>> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms
......@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
......@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w];
batchEncValues[batchEnc->GetOffset2D(sent, w)] = num;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc++;
}
}
ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0);
DelTensorBuf(tmp);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
//XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
//_ConvertDataType(batchEnc, tmp);
//tmp->Dump(stderr, "tmp:");
//_NotEqual(tmp, paddingEnc, 0);
//DelTensorBuf(tmp);
/* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){
......@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
delete[] batchEncValues;
delete[] batchDecValues;
delete[] labelValues;
//delete[] paddingEncOffsets;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
//delete[] goldOffsets;
......
......@@ -120,6 +120,9 @@ public:
/* clear data buffer */
void ClearBuf();
/* set the random batch flag */
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论