Commit 36c80fc7 by xiaotong

bug fixes

parent 699ddac6
......@@ -44,10 +44,13 @@ T2TModel::T2TModel()
/* de-constructor */
T2TModel::~T2TModel()
{
delete mem;
delete encoder;
delete decoder;
delete outputLayer;
/* we delete "mem" at the end because other members are using it and we must
remove the memory space before all tensors are destroyed. */
delete mem;
}
/*
......
......@@ -75,6 +75,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
/* max output-length = 2 * source-length */
maxLength = input->GetDim(-2) * 2;
CheckNTErrors(maxLength > 0, "no max length specified!");
T2TStateBundle * states = new T2TStateBundle[maxLength];
T2TStateBundle * first = states;
......
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/
#include <math.h>
#include "T2TUtility.h"
#include "T2TTester.h"
#include "T2TSearch.h"
......@@ -130,7 +131,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss / wordCount));
elapsed,wordCountTotal, exp(loss/wordCount));
}
}
\ No newline at end of file
}
......@@ -154,7 +154,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32
if(isShuffled)
Shuffle(fn, trainFN);
batchLoader.Shuffle(fn, trainFN);
#endif
FILE * file = fopen(trainFN, "rb");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论