Commit 8706cdb5 by xiaotong

flag of early stop and length scalar of decoding

parent 7c68f1e7
...@@ -62,6 +62,8 @@ void T2TSearch::Init(int argc, char ** argv) ...@@ -62,6 +62,8 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, -1); LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1); LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "maxlenalpha", &scalarMaxLength, 2.0F);
LoadParamBool(argc, argv, "earlystop", &isEarlyStop, false);
if(endSymbols[0] >= 0) if(endSymbols[0] >= 0)
endSymbolNum = 1; endSymbolNum = 1;
...@@ -108,8 +110,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, ...@@ -108,8 +110,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding,
inputBeam.ReshapeMerged(inputBeam.order - 3); inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3); paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */ /* max output-length = scalar * source-length */
int lengthLimit = input->GetDim(-1) * 2; int lengthLimit = (int)(input->GetDim(-1) * scalarMaxLength);
CheckNTErrors(lengthLimit > 0, "no max length specified!"); CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit; maxLength = lengthLimit;
......
...@@ -62,6 +62,12 @@ private: ...@@ -62,6 +62,12 @@ private:
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论