Commit 448ecd8b by xiaotong

beam search options

parent b7026578
...@@ -49,6 +49,7 @@ int TransformerMain(int argc, const char ** argv) ...@@ -49,6 +49,7 @@ int TransformerMain(int argc, const char ** argv)
ShowParams(argc, args); ShowParams(argc, args);
bool isBeamSearch = false;
char * trainFN = new char[MAX_LINE_LENGTH]; char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH]; char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH]; char * testFN = new char[MAX_LINE_LENGTH];
...@@ -58,6 +59,7 @@ int TransformerMain(int argc, const char ** argv) ...@@ -58,6 +59,7 @@ int TransformerMain(int argc, const char ** argv)
LoadParamString(argc, args, "model", modelFN, ""); LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "test", testFN, ""); LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, ""); LoadParamString(argc, args, "output", outputFN, "");
LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
srand((unsigned int)time(NULL)); srand((unsigned int)time(NULL));
...@@ -72,19 +74,32 @@ int TransformerMain(int argc, const char ** argv) ...@@ -72,19 +74,32 @@ int TransformerMain(int argc, const char ** argv)
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model); trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */ /* save the final model */
//if(strcmp(modelFN, "") && strcmp(trainFN, "")) if(strcmp(modelFN, "") && strcmp(trainFN, ""))
//model.Dump(modelFN); model.Dump(modelFN);
/* load the model if neccessary */ /* load the model if neccessary */
//if(strcmp(modelFN, "")) if(strcmp(modelFN, ""))
//model.Read(modelFN); model.Read(modelFN);
T2TTrainer tester;
tester.Init(argc, args);
/* test the model on the new data */ /* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, "")) if(strcmp(testFN, "") && strcmp(outputFN, ""))
tester.Test(testFN, outputFN, &model); {
/* beam search */
if(isBeamSearch){
T2TTester searcher;
searcher.Init(argc, args);
searcher.Test(testFN, outputFN, &model);
}
/* forced decoding */
else{
T2TTrainer tester;
tester.Init(argc, args);
tester.Test(testFN, outputFN, &model);
}
}
delete[] trainFN; delete[] trainFN;
delete[] modelFN; delete[] modelFN;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论