Commit 12195a67 by xiaotong

bug fixes of the input of the decoder for t2t

parent 645c32dc
...@@ -179,8 +179,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -179,8 +179,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
wordCount = 0; wordCount = 0;
loss = 0; loss = 0;
/* batch of input sequences */ /* batch of sequences (on the encoder and decoder sides) */
XTensor batch; XTensor batchEnc;
XTensor batchDec;
/* padding */ /* padding */
XTensor paddingEnc; XTensor paddingEnc;
...@@ -192,21 +193,21 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -192,21 +193,21 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */ /* label smoothed gold standard (if needed) */
XTensor goldSmoothed; XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec, while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold,
NULL, vSize, vSizeTgt, NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)) sBatchSize, wBatchSize, isLenSorted, wc, devID, mem))
{ {
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
/* make the network */ /* make the network */
if(model->isLM) if(model->isLM)
model->MakeLM(batch, output, paddingEnc, true); model->MakeLM(batchEnc, output, paddingEnc, true);
else if(model->isMT) else if(model->isMT)
model->MakeMT(batch, gold, output, paddingEnc, true); model->MakeMT(batchEnc, batchDec, output, paddingEnc, true);
else{ else{
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
...@@ -330,7 +331,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -330,7 +331,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount = 0; wordCount = 0;
/* batch of input sequences */ /* batch of input sequences */
XTensor batch; XTensor batchEnc;
XTensor batchDec;
/* padding */ /* padding */
XTensor paddingEnc; XTensor paddingEnc;
...@@ -344,27 +346,27 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -344,27 +346,27 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf(); ClearBuf();
while(LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec, while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold,
seqs, vSize, vSizeTgt, seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem)) 1, 1, false, wc, devID, mem))
{ {
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
/* make the network */ /* make the network */
if(model->isLM) if(model->isLM)
model->MakeLM(batch, output, paddingEnc, false); model->MakeLM(batchEnc, output, paddingEnc, false);
else if(model->isMT) else if(model->isMT)
model->MakeMT(batch, gold, output, paddingEnc, false); model->MakeMT(batchEnc, batchDec, output, paddingEnc, false);
else{ else{
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
int bSize = batch.GetDim(0); int bSize = batchDec.GetDim(0);
int length = batch.GetDim(1); int length = batchDec.GetDim(1);
/* prediction probabilities */ /* prediction probabilities */
XTensor probs; XTensor probs;
...@@ -589,6 +591,7 @@ load a batch of sequences ...@@ -589,6 +591,7 @@ load a batch of sequences
>> paddingEnc - padding of the input sequences >> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences >> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary >> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary >> vsDec - size of the decoder vocabulary
...@@ -602,19 +605,20 @@ load a batch of sequences ...@@ -602,19 +605,20 @@ load a batch of sequences
int T2TTrainer::LoadBatch(FILE * file, bool isLM, int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem) int devID, XMem * mem)
{ {
if(isLM){ if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs, return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
vsEnc, sBatch, wBatch, seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem); isSorted, wCount, devID, mem);
} }
else{ else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs, return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
vsEnc, vsDec, sBatch, wBatch, seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem); isSorted, wCount, devID, mem);
} }
} }
...@@ -627,6 +631,7 @@ load a batch of sequences (for LM) ...@@ -627,6 +631,7 @@ load a batch of sequences (for LM)
>> paddingEnc - padding of the input sequences >> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences >> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vs - vocabulary size >> vs - vocabulary size
>> sBatch - batch size of sequences >> sBatch - batch size of sequences
...@@ -639,6 +644,7 @@ load a batch of sequences (for LM) ...@@ -639,6 +644,7 @@ load a batch of sequences (for LM)
int T2TTrainer::LoadBatchLM(FILE * file, int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int * seqs,
int vs, int sBatch, int wBatch, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -680,21 +686,21 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -680,21 +686,21 @@ int T2TTrainer::LoadBatchLM(FILE * file,
InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem); InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
XNoder::MakeGrad(batchEnc); XNoder::MakeGrad(batchEnc);
XNoder::MakeGrad(paddingEnc); XNoder::MakeGrad(paddingEnc);
XNoder::MakeGrad(batchDec); XNoder::MakeGrad(gold);
XNoder::MakeGrad(paddingDec); XNoder::MakeGrad(paddingDec);
batchEnc->SetZeroAll(); batchEnc->SetZeroAll();
paddingEnc->SetZeroAll(); paddingEnc->SetZeroAll();
batchDec->SetZeroAll(); gold->SetZeroAll();
paddingDec->SetZeroAll(); paddingDec->SetZeroAll();
batchEnc->grad->SetZeroAll(); batchEnc->grad->SetZeroAll();
paddingEnc->grad->SetZeroAll(); paddingEnc->grad->SetZeroAll();
batchDec->grad->SetZeroAll(); gold->grad->SetZeroAll();
paddingDec->grad->SetZeroAll(); paddingDec->grad->SetZeroAll();
int seqSize = 0; int seqSize = 0;
...@@ -710,13 +716,13 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -710,13 +716,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
paddingEnc->Set2D(1.0F, s - seq, w); paddingEnc->Set2D(1.0F, s - seq, w);
paddingDec->Set2D(1.0F, s - seq, w); paddingDec->Set2D(1.0F, s - seq, w);
if (w > 0) if (w > 0)
batchDec->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]); gold->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) { if (w == len - 1) {
if (isDoubledEnd) if (isDoubledEnd)
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
else else
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]); gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
} }
wCount++; wCount++;
...@@ -747,6 +753,7 @@ load a batch of sequences (for MT) ...@@ -747,6 +753,7 @@ load a batch of sequences (for MT)
>> paddingEnc - padding of the input sequences >> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences >> paddingDec - padding of the output sequences
>> gold - gold standard
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary >> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary >> vsDec - size of the decoder vocabulary
...@@ -760,6 +767,7 @@ load a batch of sequences (for MT) ...@@ -760,6 +767,7 @@ load a batch of sequences (for MT)
int T2TTrainer::LoadBatchMT(FILE * file, int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -817,11 +825,13 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -817,11 +825,13 @@ int T2TTrainer::LoadBatchMT(FILE * file,
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem); InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem); InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
batchEnc->SetZeroAll(); batchEnc->SetZeroAll();
paddingEnc->SetZeroAll(); paddingEnc->SetZeroAll();
batchDec->SetZeroAll(); batchDec->SetZeroAll();
paddingDec->SetZeroAll(); paddingDec->SetZeroAll();
gold->SetZeroAll();
wCount = 0; wCount = 0;
...@@ -843,13 +853,14 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -843,13 +853,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int sent = (s - seq - 1)/2; int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
paddingDec->Set2D(1.0F, sent, w); paddingDec->Set2D(1.0F, sent, w);
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
if(w > 0) if(w > 0)
batchDec->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]); gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) { if (w == len - 1) {
if(isDoubledEnd) if(isDoubledEnd)
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]); gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
else else
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]); gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]);
} }
wCount++; wCount++;
......
...@@ -168,6 +168,7 @@ public: ...@@ -168,6 +168,7 @@ public:
int LoadBatch(FILE * file, bool isLM, int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -177,6 +178,7 @@ public: ...@@ -177,6 +178,7 @@ public:
int LoadBatchLM(FILE * file, int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vs, int sBatch, int wBatch, int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem); int devID, XMem * mem);
...@@ -185,6 +187,7 @@ public: ...@@ -185,6 +187,7 @@ public:
int LoadBatchMT(FILE * file, int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch, int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem); int devID, XMem * mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论