Commit 9b0d4df8 by xiaotong

Fix the bug of summing the word counts of both the source and target sides for the…

Fix the bug of summing the word counts of both the source and target sides for the loss computation in transformer MT models
parent 1c26ff5b
...@@ -221,9 +221,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -221,9 +221,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
//if (output.GetDim(0) > 1) //if (output.GetDim(0) > 1)
// PadOutput(&output, &gold, &paddingDec); // PadOutput(&output, &gold, &paddingDec);
//output.Dump(tmpFILE, "output: ");
//fflush(tmpFILE);
/* get probabilities */ /* get probabilities */
float prob = GetProb(&output, &gold, NULL); float prob = GetProb(&output, &gold, NULL);
...@@ -275,19 +272,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -275,19 +272,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XPRINT(0, stderr, "\n"); XPRINT(0, stderr, "\n");
} }
//XMem * mem = model->mem;
//MTYPE used = 0;
//MTYPE total = 0;
//for(int i = 0; i < mem->blockNum; i++){
// if(mem->blocks[i].mem != NULL){
// used += mem->blocks[i].used;
// total += mem->blocks[i].size;
// }
//}
//fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
// paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){ if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
MakeCheckpoint(model, validFN, modelFN, "step", step); MakeCheckpoint(model, validFN, modelFN, "step", step);
nStepCheck = 0; nStepCheck = 0;
...@@ -884,6 +868,9 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -884,6 +868,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
paddingDec->SetZeroAll(); paddingDec->SetZeroAll();
gold->SetZeroAll(); gold->SetZeroAll();
int wCountEnc = 0;
int wCountDec = 0;
int wGold = 0;
wCount = 0; wCount = 0;
MTYPE * batchEncOffsets = new MTYPE[batchEnc->unitNum]; MTYPE * batchEncOffsets = new MTYPE[batchEnc->unitNum];
...@@ -911,8 +898,8 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -911,8 +898,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
batchEnc->SetDataBatched(batchEncOffsets, batchEncValues, wCount); batchEnc->SetDataBatched(batchEncOffsets, batchEncValues, wCount);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount); paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
int wCountDec = 0; wCountEnc = wCount;
int wGold = 0; wCount = 0;
/* batch of the target-side sequences */ /* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){ for(int s = seq + 1; s < seq + sc; s += 2){
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论