/* NiuTrans.Tensor - an open-source tensor library * Copyright (C) 2018, Natural Language Processing Lab, Northestern University. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 */ #include <math.h> #include <time.h> #include "Transformer.h" #include "T2TModel.h" #include "T2TUtility.h" #include "T2TPredictor.h" #include "T2TTester.h" #include "../../tensor/XDevice.h" #include "../../tensor/XUtility.h" #include "../../tensor/XGlobal.h" namespace transformer { int TransformerMain(int argc, const char ** argv) { if(argc == 0) return 1; char ** args = new char*[argc]; for(int i = 0; i < argc; i++){ args[i] = new char[strlen(argv[i]) + 1]; strcpy(args[i], argv[i]); } ShowParams(argc, args); bool isBeamSearch = false; char * trainFN = new char[MAX_LINE_LENGTH]; char * modelFN = new char[MAX_LINE_LENGTH]; char * testFN = new char[MAX_LINE_LENGTH]; char * outputFN = new char[MAX_LINE_LENGTH]; char * rawModel = new char[MAX_LINE_LENGTH]; LoadParamString(argc, args, "model", modelFN, ""); LoadParamString(argc, args, "rawmodel", rawModel, ""); LoadParamString(argc, args, "test", testFN, ""); LoadParamString(argc, args, "output", outputFN, ""); LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false); srand((unsigned int)time(NULL)); T2TModel model; model.InitModel(argc, args); /* load the model if neccessary */ if(strcmp(modelFN, "")) model.Read(modelFN); /* test the model on the new data */ if(strcmp(testFN, "") && strcmp(outputFN, "")){ T2TTester searcher; searcher.Init(argc, args); searcher.Test(testFN, outputFN, &model); } delete[] trainFN; delete[] modelFN; delete[] testFN; delete[] outputFN; delete[] rawModel; for(int i = 0; i < argc; i++) delete[] args[i]; delete[] args; return 0; } }