Commit 448ecd8b by xiaotong

beam search options

parent b7026578
......@@ -49,6 +49,7 @@ int TransformerMain(int argc, const char ** argv)
ShowParams(argc, args);
bool isBeamSearch = false;
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
......@@ -58,6 +59,7 @@ int TransformerMain(int argc, const char ** argv)
LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
srand((unsigned int)time(NULL));
......@@ -72,19 +74,32 @@ int TransformerMain(int argc, const char ** argv)
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */
//if(strcmp(modelFN, "") && strcmp(trainFN, ""))
//model.Dump(modelFN);
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
model.Dump(modelFN);
/* load the model if neccessary */
//if(strcmp(modelFN, ""))
//model.Read(modelFN);
if(strcmp(modelFN, ""))
model.Read(modelFN);
T2TTrainer tester;
tester.Init(argc, args);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, ""))
tester.Test(testFN, outputFN, &model);
{
/* beam search */
if(isBeamSearch){
T2TTester searcher;
searcher.Init(argc, args);
searcher.Test(testFN, outputFN, &model);
}
/* forced decoding */
else{
T2TTrainer tester;
tester.Init(argc, args);
tester.Test(testFN, outputFN, &model);
}
}
delete[] trainFN;
delete[] modelFN;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论