Commit 12195a67 by xiaotong

bug fixes of the input of the decoder for t2t

parent 645c32dc
......@@ -179,8 +179,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
wordCount = 0;
loss = 0;
/* batch of input sequences */
XTensor batch;
/* batch of sequences (on the encoder and decoder sides) */
XTensor batchEnc;
XTensor batchDec;
/* padding */
XTensor paddingEnc;
......@@ -192,21 +193,21 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold,
NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem))
{
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
if(model->isLM)
model->MakeLM(batch, output, paddingEnc, true);
model->MakeLM(batchEnc, output, paddingEnc, true);
else if(model->isMT)
model->MakeMT(batch, gold, output, paddingEnc, true);
model->MakeMT(batchEnc, batchDec, output, paddingEnc, true);
else{
ShowNTErrors("Illegal model type!");
}
......@@ -330,7 +331,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount = 0;
/* batch of input sequences */
XTensor batch;
XTensor batchEnc;
XTensor batchDec;
/* padding */
XTensor paddingEnc;
......@@ -344,27 +346,27 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf();
while(LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold,
seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem))
{
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
if(model->isLM)
model->MakeLM(batch, output, paddingEnc, false);
model->MakeLM(batchEnc, output, paddingEnc, false);
else if(model->isMT)
model->MakeMT(batch, gold, output, paddingEnc, false);
model->MakeMT(batchEnc, batchDec, output, paddingEnc, false);
else{
ShowNTErrors("Illegal model type!");
}
int bSize = batch.GetDim(0);
int length = batch.GetDim(1);
int bSize = batchDec.GetDim(0);
int length = batchDec.GetDim(1);
/* prediction probabilities */
XTensor probs;
......@@ -589,6 +591,7 @@ load a batch of sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary
......@@ -602,19 +605,20 @@ load a batch of sequences
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem)
{
if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, sBatch, wBatch,
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem);
}
else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, vsDec, sBatch, wBatch,
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem);
}
}
......@@ -627,6 +631,7 @@ load a batch of sequences (for LM)
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array
>> vs - vocabulary size
>> sBatch - batch size of sequences
......@@ -639,6 +644,7 @@ load a batch of sequences (for LM)
int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs,
int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -680,21 +686,21 @@ int T2TTrainer::LoadBatchLM(FILE * file,
InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
XNoder::MakeGrad(batchEnc);
XNoder::MakeGrad(paddingEnc);
XNoder::MakeGrad(batchDec);
XNoder::MakeGrad(gold);
XNoder::MakeGrad(paddingDec);
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
gold->SetZeroAll();
paddingDec->SetZeroAll();
batchEnc->grad->SetZeroAll();
paddingEnc->grad->SetZeroAll();
batchDec->grad->SetZeroAll();
gold->grad->SetZeroAll();
paddingDec->grad->SetZeroAll();
int seqSize = 0;
......@@ -710,13 +716,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
paddingEnc->Set2D(1.0F, s - seq, w);
paddingDec->Set2D(1.0F, s - seq, w);
if (w > 0)
batchDec->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
gold->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) {
if (isDoubledEnd)
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
else
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
}
wCount++;
......@@ -747,6 +753,7 @@ load a batch of sequences (for MT)
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary
......@@ -760,6 +767,7 @@ load a batch of sequences (for MT)
int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -817,11 +825,13 @@ int T2TTrainer::LoadBatchMT(FILE * file,
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
paddingDec->SetZeroAll();
gold->SetZeroAll();
wCount = 0;
......@@ -843,13 +853,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
paddingDec->Set2D(1.0F, sent, w);
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
if(w > 0)
batchDec->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) {
if(isDoubledEnd)
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
else
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]);
gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]);
}
wCount++;
......
......@@ -168,6 +168,7 @@ public:
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -177,6 +178,7 @@ public:
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
......@@ -185,6 +187,7 @@ public:
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论