Commit cc6319e4 by xiaotong

update the code of length alpha

parent 5a9327f8
...@@ -35,7 +35,6 @@ XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha) ...@@ -35,7 +35,6 @@ XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
XTensor base; XTensor base;
XTensor lp; XTensor lp;
//base = ScaleAndShift(ScaleAndShift(length, 0, 5.0F), 1.0F/(5 + 1));
base = (length + 5)/(1 + 5); base = (length + 5)/(1 + 5);
lp = Power(base, alpha); lp = Power(base, alpha);
......
...@@ -59,7 +59,7 @@ void T2TSearch::Init(int argc, char ** argv) ...@@ -59,7 +59,7 @@ void T2TSearch::Init(int argc, char ** argv)
{ {
LoadParamInt(argc, argv, "beamsize", &beamSize, 1); LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1); LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.6F);
LoadParamInt(argc, argv, "endid", endSymbols, -1); LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1); LoadParamInt(argc, argv, "startid", &startSymbol, -1);
...@@ -108,8 +108,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -108,8 +108,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* max output-length = 2 * source-length */ /* max output-length = 2 * source-length */
int lengthLimit = input->GetDim(-1) * 2; int lengthLimit = input->GetDim(-1) * 2;
int l = 0;
CheckNTErrors(lengthLimit > 0, "no max length specified!"); CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit;
T2TStateBundle * states = new T2TStateBundle[lengthLimit + 1]; T2TStateBundle * states = new T2TStateBundle[lengthLimit + 1];
T2TStateBundle * first = states; T2TStateBundle * first = states;
...@@ -146,9 +146,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -146,9 +146,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
Collect(next); Collect(next);
/* stop searching when all hypotheses are completed */ /* stop searching when all hypotheses are completed */
if(IsAllCompleted(next)) if(IsAllCompleted(next)){
maxLength = l + 1;
break; break;
} }
}
/* fill the heap with imcomplete hypotheses if neccesary */ /* fill the heap with imcomplete hypotheses if neccesary */
FillHeap(next); FillHeap(next);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论