Commit 67096d88 by huchi

no message

parent 7b6840d4
...@@ -124,10 +124,10 @@ bool TrainDataSet::LoadBatchToBuf() ...@@ -124,10 +124,10 @@ bool TrainDataSet::LoadBatchToBuf()
rewind(fp); rewind(fp);
int srcVocabSize = 0; /*int srcVocabSize = 0;
int tgtVocabSize = 0; int tgtVocabSize = 0;*/
fread(&srcVocabSize, sizeof(int), 1, fp); fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp); fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp); fread(&totalSampleNum, sizeof(totalSampleNum), 1, fp);
} }
...@@ -224,12 +224,6 @@ bool TrainDataSet::GetBatchSimple(XList* inputs, XList* golds) ...@@ -224,12 +224,6 @@ bool TrainDataSet::GetBatchSimple(XList* inputs, XList* golds)
InitTensor2D(paddingDec, sc, maxTgtLen, X_FLOAT); InitTensor2D(paddingDec, sc, maxTgtLen, X_FLOAT);
InitTensor2D(label, sc, maxTgtLen, X_INT); InitTensor2D(label, sc, maxTgtLen, X_INT);
inputs->Add(batchEnc);
inputs->Add(paddingEnc);
golds->Add(batchDec);
golds->Add(paddingDec);
golds->Add(label);
curIdx += sc; curIdx += sc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
...@@ -271,10 +265,8 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize, ...@@ -271,10 +265,8 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize,
fp = fopen(dataFile, "rb"); fp = fopen(dataFile, "rb");
CheckNTErrors(fp, "can not open the training file"); CheckNTErrors(fp, "can not open the training file");
int srcVocabSize = 0; fread(&srcVocabSize, sizeof(srcVocabSize), 1, fp);
int tgtVocabSize = 0; fread(&tgtVocabSize, sizeof(tgtVocabSize), 1, fp);
fread(&srcVocabSize, sizeof(int), 1, fp);
fread(&tgtVocabSize, sizeof(int), 1, fp);
CheckNTErrors(srcVocabSize > 0, "Invalid source vocabulary size"); CheckNTErrors(srcVocabSize > 0, "Invalid source vocabulary size");
CheckNTErrors(tgtVocabSize > 0, "Invalid target vocabulary size"); CheckNTErrors(tgtVocabSize > 0, "Invalid target vocabulary size");
...@@ -286,6 +278,7 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize, ...@@ -286,6 +278,7 @@ void TrainDataSet::Init(const char* dataFile, int myBatchSize, int myBucketSize,
isTraining = training; isTraining = training;
buf = new XList; buf = new XList;
curIdx = 0;
} }
/* group samples with similar length into buckets */ /* group samples with similar length into buckets */
......
...@@ -127,6 +127,11 @@ public: ...@@ -127,6 +127,11 @@ public:
/* the buffer (a list) of samples */ /* the buffer (a list) of samples */
XList * buf; XList * buf;
/* vocabulary sizes of the source and target sides, read from the head of the data file */
int srcVocabSize;
int tgtVocabSize;
public: public:
/* get the maximum source sentence length in a range */ /* get the maximum source sentence length in a range */
......
...@@ -214,6 +214,12 @@ void Trainer::Train(const char* fn, const char* validFN, ...@@ -214,6 +214,12 @@ void Trainer::Train(const char* fn, const char* validFN,
batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds)); batchLoader.GetBatchSimple((XList*)(&inputs), (XList*)(&golds));
batchEnc.SetDevice(model->devID);
paddingEnc.SetDevice(model->devID);
batchDec.SetDevice(model->devID);
paddingDec.SetDevice(model->devID);
label.SetDevice(model->devID);
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */ /* output probabilities */
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论