Commit 9b0d4df8 by xiaotong

fix the bug of summing the word numbers of both source and target sides for the loss computation in transformer MT models
parent 1c26ff5b
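
Background for the change: before this commit, the trainer accumulated the source-side (encoder) and the target-side (decoder) token counts into the single counter wCount, and that combined number fed into the per-word loss, presumably making the averaged loss look smaller than it should be when only target-side words ought to count. The sketch below only illustrates that effect with made-up numbers; apart from wCount, wCountEnc and wCountDec, every name and value is hypothetical and not taken from the repository.

#include <cstdio>

int main()
{
    /* toy token counts for one batch (made-up numbers) */
    int srcTokens = 37;        /* source-side (encoder) words */
    int tgtTokens = 29;        /* target-side (decoder) words */
    float sumLoss = 87.3F;     /* summed cross-entropy over the batch (made up) */

    /* buggy normalization: source-side words are wrongly included */
    int wCount = srcTokens + tgtTokens;
    printf("per-word loss (buggy): %.4f\n", sumLoss / wCount);

    /* fixed normalization: keep the counts apart and divide by decoder words only */
    int wCountEnc = srcTokens;
    int wCountDec = tgtTokens;
    (void)wCountEnc;           /* tracked separately, not used for the loss */
    printf("per-word loss (fixed): %.4f\n", sumLoss / wCountDec);

    return 0;
}
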
@@ -221,9 +221,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
//if (output.GetDim(0) > 1)
// PadOutput(&output, &gold, &paddingDec);
//output.Dump(tmpFILE, "output: ");
//fflush(tmpFILE);
/* get probabilities */
float prob = GetProb(&output, &gold, NULL);
@@ -275,19 +272,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XPRINT(0, stderr, "\n");
}
//XMem * mem = model->mem;
//MTYPE used = 0;
//MTYPE total = 0;
//for(int i = 0; i < mem->blockNum; i++){
// if(mem->blocks[i].mem != NULL){
// used += mem->blocks[i].used;
// total += mem->blocks[i].size;
// }
//}
//fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
// paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
MakeCheckpoint(model, validFN, modelFN, "step", step);
nStepCheck = 0;
@@ -884,6 +868,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
paddingDec->SetZeroAll();
gold->SetZeroAll();
int wCountEnc = 0;
int wCountDec = 0;
int wGold = 0;
wCount = 0;
MTYPE * batchEncOffsets = new MTYPE[batchEnc->unitNum];
@@ -911,8 +898,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
batchEnc->SetDataBatched(batchEncOffsets, batchEncValues, wCount);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
int wCountDec = 0;
int wGold = 0;
wCountEnc = wCount;
wCount = 0;
/* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){
......
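
The last two hunks change the counter bookkeeping in T2TTrainer::LoadBatchMT: wCountEnc, wCountDec and wGold are now declared up front, and once the source-side batch has been built the running counter is saved into wCountEnc and reset before the target-side loop starts. The following is a minimal, hypothetical sketch of that control flow only; the real function reads sequences from a file and fills XTensor batches, none of which is reproduced here, and every name other than wCount, wCountEnc and wCountDec is an assumption.

#include <cstdio>
#include <vector>

/* toy stand-in for one batch: per-sequence word counts on each side */
struct ToyBatch {
    std::vector<int> srcSeqLens;   /* source-side sequence lengths */
    std::vector<int> tgtSeqLens;   /* target-side sequence lengths */
};

/* returns the decoder-side word count, kept separate from the encoder-side one */
static int CountBatchWords(const ToyBatch & batch, int & wCountEnc, int & wCountDec)
{
    int wCount = 0;

    /* batch of the source-side sequences */
    for (int len : batch.srcSeqLens)
        wCount += len;

    /* the fix: remember the encoder count, then restart the counter so the
       target-side loop does not pile its words on top of the source-side ones */
    wCountEnc = wCount;
    wCount = 0;

    /* batch of the target-side sequences */
    for (int len : batch.tgtSeqLens)
        wCount += len;

    wCountDec = wCount;
    return wCountDec;
}

int main()
{
    ToyBatch batch{{5, 7, 6}, {4, 6, 5}};
    int wCountEnc = 0;
    int wCountDec = 0;
    CountBatchWords(batch, wCountEnc, wCountDec);
    printf("encoder words: %d, decoder words: %d\n", wCountEnc, wCountDec);
    return 0;
}
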