Commit 67096d88 by huchi

no message

parent 7b6840d4
......@@ -124,10 +124,10 @@ bool TrainDataSet::LoadBatchToBuf()
rewind(fp);
int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
/*int srcVocabSize = 0;
int tgtVocabSize = 0;*/
fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
}
......@@ -224,12 +224,6 @@ bool TrainDataSet::GetBatchSimple(XList* inputs, XList* golds)
InitTensor2D(paddingDec, sc, maxTgtLen, X_FLOAT);
InitTensor2D(label, sc, maxTgtLen, X_INT);
inputs->Add(batchEnc);
inputs->Add(paddingEnc);
golds->Add(batchDec);
golds->Add(paddingDec);
golds->Add(label);
curIdx += sc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
......@@ -271,10 +265,8 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize,
fp = fopen(dataFile, "rb");
CheckNTErrors(fp, "can not open the training file");
int srcVocabSize = 0;
int tgtVocabSize = 0;
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
CheckNTErrors(srcVocabSize > 0, "Invalid source vocabulary size");
CheckNTErrors(tgtVocabSize > 0, "Invalid target vocabulary size");
......@@ -286,6 +278,7 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize,
isTraining = training;
buf = new XList;
curIdx = 0;
}
/* group samples with similar length into buckets */
......
......@@ -127,6 +127,11 @@ public:
/* the buffer (a list) of samples */
XList * buf;
/* */
int srcVocabSize;
int tgtVocabSize;
public:
/* get the maximum source sentence length in a range */
......
......@@ -214,6 +214,12 @@ void Trainer::Train(const char* fn, const char* validFN,
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
batchEnc.SetDevice(model->devID);
paddingEnc.SetDevice(model->devID);
batchDec.SetDevice(model->devID);
paddingDec.SetDevice(model->devID);
label.SetDevice(model->devID);
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论