Commit 8b0e06ab by xiaotong

improve the t2t implementation

parent 6fb9ad1b
@@ -53,7 +53,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void T2TAttention::InitModel(int argc, const char ** argv,
+void T2TAttention::InitModel(int argc, char ** argv,
                             bool myIsMasked, int myIgnored,
                             int myDevID, XMem * myMem)
{
......
@@ -84,7 +84,7 @@ public:
    ~T2TAttention();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv,
+    void InitModel(int argc, char ** argv,
                   bool myIsMasked, int myIgnored,
                   int myDevID = -1, XMem * myMem = NULL);
......
@@ -34,7 +34,7 @@ class AttDecoder : T2TDecoder
{
public:
    /* initialize the model */
-    void InitModel(int argc, const char ** argv);
+    void InitModel(int argc, char ** argv);
};
}
......
@@ -48,7 +48,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
+void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
......
@@ -71,7 +71,7 @@ public:
    ~T2TEmbedder();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
+    void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
    /* make positional embeddings */
    void MakePosEmbedding(int eSize, int d, int length);
......
@@ -51,7 +51,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void AttEncoder::InitModel(int argc, const char ** argv,
+void AttEncoder::InitModel(int argc, char ** argv,
                           bool myIsMasked, int myIgnored,
                           int myDevID, XMem * myMem)
{
......
@@ -113,7 +113,7 @@ public:
    ~AttEncoder();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv,
+    void InitModel(int argc, char ** argv,
                   bool myIsMasked, int myIgnored,
                   int myDevID = -1, XMem * myMem = NULL);
......
@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
+void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
......
@@ -69,7 +69,7 @@ public:
    ~T2TFNN();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
+    void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
    /* make the network */
    XTensor Make(XTensor &input);
......
@@ -47,7 +47,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
+void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
......
@@ -54,7 +54,7 @@ public:
    ~T2TLN();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
+    void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
    /* make the network */
    XTensor Make(XTensor &input);
......
@@ -48,7 +48,7 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
-void T2TModel::InitModel(int argc, const char ** argv)
+void T2TModel::InitModel(int argc, char ** argv)
{
    bool useMem = false;
    int memSize = 0;
@@ -64,7 +64,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
    if(useMem){
        delete mem;
-        mem = new XMem(devID, isMemFreeOTF ? FREE_ON_THE_FLY : UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
+        mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
        mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
    }
@@ -144,7 +144,7 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
    //_Sum(&mask, padding3, &mask);
-    encoding = MakeEncoding(input, mask, true, isTraining);
+    encoding = MakeEncoding(input, mask, false, isTraining);
    outputLayer.Make(encoding, output);
    delete[] dims;
......
@@ -66,7 +66,7 @@ public:
    ~T2TModel();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv);
+    void InitModel(int argc, char ** argv);
    /* make the encoding network */
    XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
......
@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
-void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
+void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
......
@@ -59,7 +59,7 @@ public:
    ~T2TOutput();
    /* initialize the model */
-    void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
+    void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
    /* make the network */
    XTensor Make(XTensor &input);
......
@@ -26,6 +26,11 @@
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
+#ifndef WIN32
+#include <sys/time.h>
+#include <unistd.h>
+#endif
namespace transformer
{
@@ -33,8 +38,16 @@ namespace transformer
T2TTrainer::T2TTrainer()
{
    seqLen = NULL;
+    seqLen2 = NULL;
    nseqBuf = 0;
    nextSeq = -1;
+    argNum = 0;
+    argArray = NULL;
+    buf = NULL;
+    buf2 = NULL;
+    bufSize = 0;
+    seqOffset = NULL;
}
/* de-constructor */
@@ -55,6 +68,11 @@ T2TTrainer::~T2TTrainer()
        XTensor * m = (XTensor*)moments2nd.Get(i);
        delete m;
    }
+    for(int i = 0; i < argNum; i++)
+        delete[] argArray[i];
+    delete[] argArray;
}
/*
@@ -62,8 +80,15 @@ initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
-void T2TTrainer::Init(int argc, const char ** argv)
+void T2TTrainer::Init(int argc, char ** argv)
{
+    argNum = argc;
+    argArray = new char*[argc];
+    for(int i = 0; i < argNum; i++){
+        argArray[i] = new char[strlen(argv[i]) + 1];
+        strcpy(argArray[i], argv[i]);
+    }
    bool useMem = false;
    LoadParamBool(argc, argv, "mem", &useMem, useMem);
@@ -82,6 +107,9 @@ void T2TTrainer::Init(int argc, const char ** argv)
    LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
    LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
    LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
+    LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
+    LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
+    LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
    buf = new int[bufSize];
    buf2 = new int[bufSize];
@@ -91,7 +119,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
    adamBeta1T = 1.0F;
    adamBeta2T = 1.0F;
}
int tc = 0;
@@ -99,9 +126,11 @@ int tc = 0;
/*
train the model
>> fn - training data file
+>> validFN - validation data file
+>> modelFN - where we keep the model
>> model - model to train
*/
-void T2TTrainer::Train(const char * fn, T2TModel * model)
+void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
{
    int epoch = 0;
    int step = 0;
@@ -111,32 +140,36 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
    bool isEnd = false;
    float loss = 0;
    float lr = 0;
+    int nStepCheck = 0;
+    int nCheckpoint = 0;
+    char * trainFN = new char[(int)strlen(fn) + 10];
+    strcpy(trainFN, fn);
+#ifndef WIN32
+    if(isShuffled)
+        sprintf(trainFN, "%s.random", fn);
+#endif
    PrepareModel(model);
    int devID = model->devID;
    XMem * mem = model->mem;
-    if(mem != NULL && mem->mode == UNI_FREE)
-        mem->SetPin();
    XNet net;
-    tf = fopen("tmp.xx.txt", "wb");
-    tc = 0;
    double startT = GetClockSec();
    for(epoch = 1; epoch <= nepoch; epoch++){
+#ifndef WIN32
+        if(isShuffled)
+            Shuffle(fn, trainFN);
+#endif
-        FILE * file = fopen(fn, "rb");
+        FILE * file = fopen(trainFN, "rb");
        CheckNTErrors(file, "cannot open training file!");
        wordCount = 0;
        loss = 0;
-        if(mem != NULL)
-            mem->BackToPin();
        /* batch of input sequences */
        XTensor batch;
@@ -186,22 +219,23 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
                        lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
            }
-            if(mem != NULL && mem->mode == UNI_FREE)
-                mem->BackToPin();
+            if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
+                MakeCheckpoint(model, validFN, modelFN, "step", step);
+                nStepCheck = 0;
+                nCheckpoint++;
+            }
        }
        fclose(file);
        if (isEnd)
            break;
+        if(useEpochCheckpoint)
+            MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
    }
-    if(mem != NULL && mem->mode == UNI_FREE)
-        mem->BackToPin();
    double elapsed = GetClockSec() - startT;
-    fclose(tf);
    epoch = MIN(epoch, nepoch);
@@ -209,6 +243,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
            lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
    XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
            elapsed, step, epoch);
+    delete[] trainFN;
}
/*
@@ -234,16 +270,10 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
    XMem * mem = model->mem;
    XNet net;
-    tf = fopen("tmp.xx.txt", "wb");
-    tc = 0;
    double startT = GetClockSec();
    wordCount = 0;
-    if(mem != NULL && mem->mode == UNI_FREE)
-        mem->BackToPin();
    /* batch of input sequences */
    XTensor batch;
@@ -306,13 +336,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
        loss += -prob;
        wordCount += wc;
        wordCountTotal += wc;
-        if(mem != NULL && mem->mode == UNI_FREE)
-            mem->BackToPin();
    }
-    if(mem != NULL && mem->mode == UNI_FREE)
-        mem->BackToPin();
    fclose(file);
    fclose(ofile);
@@ -320,13 +344,37 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
    delete[] seqs;
    double elapsed = GetClockSec() - startT;
-    fclose(tf);
    XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
            elapsed,wordCountTotal, exp(loss / wordCount));
}
+/*
+make a checkpoint
+>> model - the model
+>> validFN - validation data file
+>> modelFN - model data file
+>> label - label of the model
+>> id - id of the checkpoint
+*/
+void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id)
+{
+    char * fn = new char[MAX_LINE_LENGTH];
+    char * fn2 = new char[MAX_LINE_LENGTH];
+    sprintf(fn, "%s.%s.%3d", modelFN, label, id);
+    sprintf(fn2, "%s.%s.%3d.output", modelFN, label, id);
+    //model->Dump(fn);
+    if(validFN != NULL){
+        T2TTrainer trainer;
+        trainer.Init(argNum, argArray);
+        trainer.Test(validFN, fn2, model);
+    }
+    delete[] fn;
+    delete[] fn2;
+}
char line[MAX_SEQUENCE_LENGTH];
struct SampleNode
@@ -583,6 +631,24 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
    return sc;
}
+/*
+shuffle lines of the file
+>> srcFile - the source file to shuffle
+>> tgtFile - the resulting file
+*/
+void T2TTrainer::Shuffle(const char * srcFile, const char * tgtFile)
+{
+    char * line = new char[MAX_LINE_LENGTH];
+#ifndef WIN32
+    sprintf(line, "shuf %s > %s", srcFile, tgtFile);
+    system(line);
+#else
+    ShowNTErrors("Cannot shuffle the file on WINDOWS systems!");
+#endif
+    delete[] line;
+}
/*
get word probabilities for a batch of sequences
......
@@ -37,6 +37,12 @@ namespace transformer
class T2TTrainer
{
public:
+    /* paramter number */
+    int argNum;
+    /* parameter array */
+    char ** argArray;
    /* buffer for loading words */
    int * buf;
@@ -107,6 +113,15 @@ public:
    /* list of the 2nd order moment of the parameter matrics */
    XList moments2nd;
+    /* indicates whether the data file is shuffled for training */
+    bool isShuffled;
+    /* number of steps after which we make a checkpoint */
+    int nStepCheckpoint;
+    /* indicates whether we make a checkpoint after each traing epoch */
+    bool useEpochCheckpoint;
public:
    /* constructor */
    T2TTrainer();
@@ -115,14 +130,17 @@ public:
    ~T2TTrainer();
    /* initialize the trainer */
-    void Init(int argc, const char ** argv);
+    void Init(int argc, char ** argv);
    /* train the model */
-    void Train(const char * fn, T2TModel * model);
+    void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
    /* test the model */
    void Test(const char * fn, const char * ofn, T2TModel * model);
+    /* make a checkpoint */
+    void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
    /* load data to buffer */
    int LoadBuf(FILE * file, bool isSorted, int step);
@@ -136,6 +154,9 @@ public:
                  int step, int vs, int sBatch, int wBatch,
                  bool isSorted, int &wCount,
                  int devID, XMem * mem);
+    /* shuffle the data file */
+    void Shuffle(const char * srcFile, const char * tgtFile);
    /* get word probabilities for a batch of sequences */
    float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
......
@@ -30,7 +30,7 @@ FILE * tmpFILE;
int llnum = 0;
FILE * tf = NULL;
-void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
+void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP)
{
    char vname[128];
    vname[0] = '-';
@@ -47,7 +47,7 @@ void LoadParamString(int argc, const char ** argv, const char * name, char * p,
        strcpy(p, defaultP);
}
-void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
+void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP)
{
    char vname[128];
    vname[0] = '-';
@@ -64,7 +64,7 @@ void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int
        *p = defaultP;
}
-void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
+void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP)
{
    char vname[128];
    vname[0] = '-';
@@ -81,7 +81,7 @@ void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bo
        *p = defaultP;
}
-void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP)
+void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP)
{
    char vname[128];
    vname[0] = '-';
@@ -98,7 +98,7 @@ void LoadParamFloat(int argc, const char ** argv, const char * name, float * p,
        *p = defaultP;
}
-void ShowParams(int argc, const char ** argv)
+void ShowParams(int argc, char ** argv)
{
    fprintf(stderr, "args:\n");
    for(int i = 0; i < argc; i++){
......
@@ -30,13 +30,13 @@ namespace transformer
extern FILE * tmpFILE;
/* load arguments */
-void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
-void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
-void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
-void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
+void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP);
+void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP);
+void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP);
+void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
-void ShowParams(int argc, const char ** argv);
+void ShowParams(int argc, char ** argv);
extern int llnum;
extern FILE * tf;
......
@@ -33,30 +33,36 @@ int TransformerMain(int argc, const char ** argv)
    if(argc == 0)
        return 1;
+    char ** args = new char*[argc];
+    for(int i = 0; i < argc; i++){
+        args[i] = new char[strlen(argv[i]) + 1];
+        strcpy(args[i], argv[i]);
+    }
    tmpFILE = fopen("tmp.txt", "wb");
-    ShowParams(argc, argv);
+    ShowParams(argc, args);
    char * trainFN = new char[MAX_LINE_LENGTH];
    char * modelFN = new char[MAX_LINE_LENGTH];
    char * testFN = new char[MAX_LINE_LENGTH];
    char * outputFN = new char[MAX_LINE_LENGTH];
-    LoadParamString(argc, argv, "train", trainFN, "");
-    LoadParamString(argc, argv, "model", modelFN, "");
-    LoadParamString(argc, argv, "test", testFN, "");
-    LoadParamString(argc, argv, "output", outputFN, "");
+    LoadParamString(argc, args, "train", trainFN, "");
+    LoadParamString(argc, args, "model", modelFN, "");
+    LoadParamString(argc, args, "test", testFN, "");
+    LoadParamString(argc, args, "output", outputFN, "");
    T2TTrainer trainer;
-    trainer.Init(argc, argv);
+    trainer.Init(argc, args);
    T2TModel model;
-    model.InitModel(argc, argv);
+    model.InitModel(argc, args);
    /* learn model parameters */
    if(strcmp(trainFN, ""))
-        trainer.Train(trainFN, &model);
+        trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
    /* save the final model */
    if(strcmp(modelFN, "") && strcmp(trainFN, ""))
@@ -66,15 +72,22 @@ int TransformerMain(int argc, const char ** argv)
    if(strcmp(modelFN, ""))
        model.Read(modelFN);
+    T2TTrainer tester;
+    tester.Init(argc, args);
    /* test the model on the new data */
    if(strcmp(testFN, "") && strcmp(outputFN, ""))
-        trainer.Test(testFN, outputFN, &model);
+        tester.Test(testFN, outputFN, &model);
    delete[] trainFN;
    delete[] modelFN;
    delete[] testFN;
    delete[] outputFN;
+    for(int i = 0; i < argc; i++)
+        delete[] args[i];
+    delete[] args;
    fclose(tmpFILE);
    return 0;
......
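A minimal usage sketch of the new training/validation flow introduced by this commit (not part of the diff; the file names, include paths, and the driver function are assumptions for illustration — the real wiring lives in TransformerMain above):

#include "T2TModel.h"
#include "T2TTrainer.h"

using namespace transformer;

/* hypothetical driver mirroring what TransformerMain now does */
int RunT2T(int argc, char ** argv)
{
    T2TTrainer trainer;
    trainer.Init(argc, argv);          /* the trainer keeps its own copy of argv in argArray */

    T2TModel model;
    model.InitModel(argc, argv);

    /* train on "train.txt" and validate on "valid.txt"; per MakeCheckpoint above,
       checkpoints are written as "<modelFN>.step.<id>" or "<modelFN>.epoch.<id>"
       depending on -nstepcheckpoint / -epochcheckpoint, and -shuffled enables
       per-epoch shuffling via T2TTrainer::Shuffle (non-Windows only) */
    trainer.Train("train.txt", "valid.txt", "checkpoint.model", &model);

    /* evaluate with a fresh trainer, as the updated TransformerMain does */
    T2TTrainer tester;
    tester.Init(argc, argv);
    tester.Test("test.txt", "output.txt", &model);

    return 0;
}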