Commit 8706cdb5 by xiaotong

flag of early stop and length scalar of decoding

parent 7c68f1e7
......@@ -62,6 +62,8 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "maxlenalpha", &scalarMaxLength, 2.0F);
LoadParamBool(argc, argv, "earlystop", &isEarlyStop, false);
if(endSymbols[0] >= 0)
endSymbolNum = 1;
......@@ -108,8 +110,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding,
inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */
int lengthLimit = input->GetDim(-1) * 2;
/* max output-length = scalar * source-length */
int lengthLimit = (int)(input->GetDim(-1) * scalarMaxLength);
CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit;
......
......@@ -62,6 +62,12 @@ private:
/* start symbol */
int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public:
/* constructor */
T2TSearch();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论