Commit 8b0e06ab by xiaotong

improve the t2t implementation

parent 6fb9ad1b
......@@ -53,7 +53,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv,
void T2TAttention::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
......
......@@ -84,7 +84,7 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv,
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
......
......@@ -34,7 +34,7 @@ class AttDecoder : T2TDecoder
{
public:
/* initialize the model */
void InitModel(int argc, const char ** argv);
void InitModel(int argc, char ** argv);
};
}
......
......@@ -48,7 +48,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......
......@@ -71,7 +71,7 @@ public:
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
......
......@@ -51,7 +51,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv,
void AttEncoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
......
......@@ -113,7 +113,7 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv,
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
......
......@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......
......@@ -69,7 +69,7 @@ public:
~T2TFNN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -47,7 +47,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......
......@@ -54,7 +54,7 @@ public:
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -48,7 +48,7 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TModel::InitModel(int argc, const char ** argv)
void T2TModel::InitModel(int argc, char ** argv)
{
bool useMem = false;
int memSize = 0;
......@@ -64,7 +64,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
if(useMem){
delete mem;
mem = new XMem(devID, isMemFreeOTF ? FREE_ON_THE_FLY : UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
......@@ -144,7 +144,7 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
//_Sum(&mask, padding3, &mask);
encoding = MakeEncoding(input, mask, true, isTraining);
encoding = MakeEncoding(input, mask, false, isTraining);
outputLayer.Make(encoding, output);
delete[] dims;
......
......@@ -66,7 +66,7 @@ public:
~T2TModel();
/* initialize the model */
void InitModel(int argc, const char ** argv);
void InitModel(int argc, char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
......
......@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......
......@@ -59,7 +59,7 @@ public:
~T2TOutput();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -26,6 +26,11 @@
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
#ifndef WIN32
#include <sys/time.h>
#include <unistd.h>
#endif
namespace transformer
{
......@@ -33,8 +38,16 @@ namespace transformer
T2TTrainer::T2TTrainer()
{
seqLen = NULL;
seqLen2 = NULL;
nseqBuf = 0;
nextSeq = -1;
argNum = 0;
argArray = NULL;
buf = NULL;
buf2 = NULL;
bufSize = 0;
seqOffset = NULL;
}
/* destructor */
......@@ -55,6 +68,11 @@ T2TTrainer::~T2TTrainer()
XTensor * m = (XTensor*)moments2nd.Get(i);
delete m;
}
for(int i = 0; i < argNum; i++)
delete[] argArray[i];
delete[] argArray;
}
/*
......@@ -62,8 +80,15 @@ initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, const char ** argv)
void T2TTrainer::Init(int argc, char ** argv)
{
argNum = argc;
argArray = new char*[argc];
for(int i = 0; i < argNum; i++){
argArray[i] = new char[strlen(argv[i]) + 1];
strcpy(argArray[i], argv[i]);
}
bool useMem = false;
LoadParamBool(argc, argv, "mem", &useMem, useMem);
......@@ -82,6 +107,9 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
buf = new int[bufSize];
buf2 = new int[bufSize];
......@@ -91,7 +119,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
}
int tc = 0;
......@@ -99,9 +126,11 @@ int tc = 0;
/*
train the model
>> fn - training data file
>> validFN - validation data file
>> modelFN - where we keep the model
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
{
int epoch = 0;
int step = 0;
......@@ -111,32 +140,36 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
bool isEnd = false;
float loss = 0;
float lr = 0;
int nStepCheck = 0;
int nCheckpoint = 0;
char * trainFN = new char[(int)strlen(fn) + 10];
strcpy(trainFN, fn);
#ifndef WIN32
if(isShuffled)
sprintf(trainFN, "%s.random", fn);
#endif
PrepareModel(model);
int devID = model->devID;
XMem * mem = model->mem;
if(mem != NULL && mem->mode == UNI_FREE)
mem->SetPin();
XNet net;
tf = fopen("tmp.xx.txt", "wb");
tc = 0;
double startT = GetClockSec();
for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32
if(isShuffled)
Shuffle(fn, trainFN);
#endif
FILE * file = fopen(fn, "rb");
FILE * file = fopen(trainFN, "rb");
CheckNTErrors(file, "cannot open training file!");
wordCount = 0;
loss = 0;
if(mem != NULL)
mem->BackToPin();
/* batch of input sequences */
XTensor batch;
......@@ -186,22 +219,23 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
}
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
MakeCheckpoint(model, validFN, modelFN, "step", step);
nStepCheck = 0;
nCheckpoint++;
}
}
fclose(file);
if (isEnd)
break;
if(useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
double elapsed = GetClockSec() - startT;
fclose(tf);
epoch = MIN(epoch, nepoch);
......@@ -209,6 +243,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
delete[] trainFN;
}
/*
......@@ -234,16 +270,10 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XMem * mem = model->mem;
XNet net;
tf = fopen("tmp.xx.txt", "wb");
tc = 0;
double startT = GetClockSec();
wordCount = 0;
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
/* batch of input sequences */
XTensor batch;
......@@ -306,13 +336,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
}
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
fclose(file);
fclose(ofile);
......@@ -320,13 +344,37 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
delete[] seqs;
double elapsed = GetClockSec() - startT;
fclose(tf);
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed, wordCountTotal, exp(loss / wordCount));
}
/*
make a checkpoint
>> model - the model
>> validFN - validation data file
>> modelFN - model data file
>> label - label of the model
>> id - id of the checkpoint
*/
void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id)
{
char * fn = new char[MAX_LINE_LENGTH];
char * fn2 = new char[MAX_LINE_LENGTH];
sprintf(fn, "%s.%s.%03d", modelFN, label, id);
sprintf(fn2, "%s.%s.%03d.output", modelFN, label, id);
//model->Dump(fn);
if(validFN != NULL){
T2TTrainer trainer;
trainer.Init(argNum, argArray);
trainer.Test(validFN, fn2, model);
}
delete[] fn;
delete[] fn2;
}
char line[MAX_SEQUENCE_LENGTH];
struct SampleNode
......@@ -583,6 +631,24 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
return sc;
}
/*
shuffle lines of the file
>> srcFile - the source file to shuffle
>> tgtFile - the resulting file
*/
void T2TTrainer::Shuffle(const char * srcFile, const char * tgtFile)
{
char * line = new char[MAX_LINE_LENGTH];
#ifndef WIN32
sprintf(line, "shuf %s > %s", srcFile, tgtFile);
system(line);
#else
ShowNTErrors("Cannot shuffle the file on WINDOWS systems!");
#endif
delete[] line;
}
/*
get word probabilities for a batch of sequences
......
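The new checkpointing works in two modes: inside the training loop a counter triggers MakeCheckpoint every nStepCheckpoint steps (label "step"), and after each pass over the data an optional checkpoint is written (label "epoch"). MakeCheckpoint derives its file names from the model file name and, when a validation file is given, re-initializes a trainer from the saved argument array to score it. Below is a minimal, compilable sketch of that control flow; DummyModel and RunValidation are hypothetical stand-ins for T2TModel and T2TTrainer::Test, and only the trigger conditions and the "<model>.<label>.<id>" naming follow the diff.

```cpp
// Hypothetical stand-ins: DummyModel for T2TModel, RunValidation for the
// Test() call made by the freshly initialized trainer in MakeCheckpoint.
#include <cstdio>

struct DummyModel {};

static void RunValidation(DummyModel * /*model*/, const char * validFN, const char * outFN)
{
    // stand-in for: trainer.Init(argNum, argArray); trainer.Test(validFN, outFN, model);
    printf("validate on %s, write hypotheses to %s\n", validFN, outFN);
}

static void MakeCheckpointSketch(DummyModel * model, const char * validFN,
                                 const char * modelFN, const char * label, int id)
{
    char fn[1024], fn2[1024];
    sprintf(fn, "%s.%s.%03d", modelFN, label, id);          // model dump name
    sprintf(fn2, "%s.%s.%03d.output", modelFN, label, id);  // validation output name
    printf("checkpoint: %s\n", fn);
    if (validFN != NULL)
        RunValidation(model, validFN, fn2);
}

int main()
{
    DummyModel model;
    const int nStepCheckpoint = 3;        // would come from the "nstepcheckpoint" option
    const bool useEpochCheckpoint = true; // would come from the "epochcheckpoint" option
    int nStepCheck = 0;
    int step = 0;

    for (int epoch = 1; epoch <= 2; epoch++) {
        for (int batch = 0; batch < 7; batch++) {
            step++;
            /* step-based checkpoint inside the training loop */
            if (nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint) {
                MakeCheckpointSketch(&model, "valid.txt", "model.bin", "step", step);
                nStepCheck = 0;
            }
        }
        /* optional checkpoint after each training epoch */
        if (useEpochCheckpoint)
            MakeCheckpointSketch(&model, "valid.txt", "model.bin", "epoch", epoch);
    }
    return 0;
}
```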
......@@ -37,6 +37,12 @@ namespace transformer
class T2TTrainer
{
public:
/* number of parameters */
int argNum;
/* parameter array */
char ** argArray;
/* buffer for loading words */
int * buf;
......@@ -107,6 +113,15 @@ public:
/* list of the 2nd order moment of the parameter matrices */
XList moments2nd;
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
/* indicates whether we make a checkpoint after each training epoch */
bool useEpochCheckpoint;
public:
/* constructor */
T2TTrainer();
......@@ -115,14 +130,17 @@ public:
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, const char ** argv);
void Init(int argc, char ** argv);
/* train the model */
void Train(const char * fn, T2TModel * model);
void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* make a checkpoint */
void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
/* load data to buffer */
int LoadBuf(FILE * file, bool isSorted, int step);
......@@ -136,6 +154,9 @@ public:
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
......
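When the "shuffled" option is set, Train reads from "<fn>.random", which Shuffle produces by delegating to the external shuf utility via system(); that is also why the feature is compiled out on Windows. The sketch below reproduces that shell-out pattern as a standalone program; the file names are placeholders and shuf must be available on the PATH.

```cpp
// Standalone reproduction of the shell-out shuffle used by T2TTrainer::Shuffle;
// "train.txt" is a placeholder and "shuf" must be installed.
#include <cstdio>
#include <cstdlib>

void ShuffleSketch(const char * srcFile, const char * tgtFile)
{
#ifndef WIN32
    char cmd[1024];
    sprintf(cmd, "shuf %s > %s", srcFile, tgtFile);  // let the shell do the shuffling
    if (system(cmd) != 0)
        fprintf(stderr, "shuffling %s failed\n", srcFile);
#else
    fprintf(stderr, "cannot shuffle the file on Windows systems!\n");
#endif
}

int main()
{
    char trainFN[1024];
    sprintf(trainFN, "%s.random", "train.txt");  // the name Train() reads when shuffling is enabled
    ShuffleSketch("train.txt", trainFN);
    return 0;
}
```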
......@@ -30,7 +30,7 @@ FILE * tmpFILE;
int llnum = 0;
FILE * tf = NULL;
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -47,7 +47,7 @@ void LoadParamString(int argc, const char ** argv, const char * name, char * p,
strcpy(p, defaultP);
}
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -64,7 +64,7 @@ void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int
*p = defaultP;
}
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -81,7 +81,7 @@ void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bo
*p = defaultP;
}
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP)
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -98,7 +98,7 @@ void LoadParamFloat(int argc, const char ** argv, const char * name, float * p,
*p = defaultP;
}
void ShowParams(int argc, const char ** argv)
void ShowParams(int argc, char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
......
......@@ -30,13 +30,13 @@ namespace transformer
extern FILE * tmpFILE;
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, const char ** argv);
void ShowParams(int argc, char ** argv);
extern int llnum;
extern FILE * tf;
......
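The LoadParam* helpers all follow the same lookup pattern: the option name is prefixed with '-' (vname[0] = '-'), argv is scanned for it, and the default value is used when the option is absent. The following is a minimal sketch of that convention for the integer case; the exact matching and conversion details are assumptions, only the '-' prefix and the default fallback are visible in the diff.

```cpp
// Minimal sketch of the "-name value" lookup; the scan/conversion details
// are assumptions, the "-" prefix and default fallback come from the diff.
#include <cstdio>
#include <cstdlib>
#include <cstring>

void LoadParamIntSketch(int argc, char ** argv, const char * name, int * p, int defaultP)
{
    char vname[128];
    vname[0] = '-';
    strcpy(vname + 1, name);            // options are looked up as "-name"

    bool hit = false;
    for (int i = 0; i < argc - 1; i++) {
        if (!strcmp(argv[i], vname)) {
            *p = atoi(argv[i + 1]);     // the value is the following token
            hit = true;
            break;
        }
    }
    if (!hit)
        *p = defaultP;                  // fall back to the default when the option is absent
}

int main(int argc, char ** argv)
{
    int nStepCheckpoint = 0;
    LoadParamIntSketch(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
    printf("nstepcheckpoint = %d\n", nStepCheckpoint);
    return 0;
}
```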
......@@ -33,30 +33,36 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0)
return 1;
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1];
strcpy(args[i], argv[i]);
}
tmpFILE = fopen("tmp.txt", "wb");
ShowParams(argc, argv);
ShowParams(argc, args);
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
char * outputFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
LoadParamString(argc, argv, "model", modelFN, "");
LoadParamString(argc, argv, "test", testFN, "");
LoadParamString(argc, argv, "output", outputFN, "");
LoadParamString(argc, args, "train", trainFN, "");
LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
T2TTrainer trainer;
trainer.Init(argc, argv);
trainer.Init(argc, args);
T2TModel model;
model.InitModel(argc, argv);
model.InitModel(argc, args);
/* learn model parameters */
if(strcmp(trainFN, ""))
trainer.Train(trainFN, &model);
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......@@ -66,15 +72,22 @@ int TransformerMain(int argc, const char ** argv)
if(strcmp(modelFN, ""))
model.Read(modelFN);
T2TTrainer tester;
tester.Init(argc, args);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, ""))
trainer.Test(testFN, outputFN, &model);
tester.Test(testFN, outputFN, &model);
delete[] trainFN;
delete[] modelFN;
delete[] testFN;
delete[] outputFN;
for(int i = 0; i < argc; i++)
delete[] args[i];
delete[] args;
fclose(tmpFILE);
return 0;
......
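Since Init, InitModel, and the LoadParam* helpers now take char ** rather than const char **, TransformerMain deep-copies the incoming const arguments into a mutable array before handing them on and releases the copy at the end; the same copy is kept inside T2TTrainer so that MakeCheckpoint can spawn a validation trainer later. A self-contained sketch of that copy/free pattern, with placeholder option strings:

```cpp
// Placeholder argument strings; CopyArgs/FreeArgs mirror the copy made at
// the top of TransformerMain and the cleanup at its end.
#include <cstring>

char ** CopyArgs(int argc, const char ** argv)
{
    char ** args = new char*[argc];
    for (int i = 0; i < argc; i++) {
        args[i] = new char[strlen(argv[i]) + 1];  // mutable copy of each argument
        strcpy(args[i], argv[i]);
    }
    return args;
}

void FreeArgs(int argc, char ** args)
{
    for (int i = 0; i < argc; i++)
        delete[] args[i];
    delete[] args;
}

int main()
{
    const char * argv[] = { "-train", "train.txt", "-nstepcheckpoint", "10000" };
    char ** args = CopyArgs(4, argv);
    // ... pass (4, args) to ShowParams, trainer.Init, model.InitModel ...
    FreeArgs(4, args);
    return 0;
}
```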