Commit 59d8b4bf by xiaotong

bug fix

parent 78bdfb45
......@@ -121,10 +121,10 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
scalar = Softmax(Linear(dot, 1.0F/(float)sqrt((float)dk)), -1);
if(ignored > 0)
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
//if(ignored > 0)
// _SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att = BMMul(scalar, vheads);
......
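For reference, this hunk computes the standard scaled dot-product attention, att = softmax(QK^T / sqrt(dk) + mask) V. Below is a minimal self-contained sketch on plain float arrays; the shapes and the function name are hypothetical and this is not the XTensor API:

#include <cmath>

/* scaled dot-product attention for one head on row-major arrays:
   q is lq x dk, k is lk x dk, v is lk x dv, mask is lq x lk, att is lq x dv */
void AttentionSketch(const float * q, const float * k, const float * v,
                     const float * mask, float * att,
                     int lq, int lk, int dk, int dv)
{
    float * row = new float[lk];
    for(int i = 0; i < lq; i++){
        float maxv = -1e30F, sum = 0;
        /* scores: q k^T / sqrt(dk) + mask */
        for(int j = 0; j < lk; j++){
            float dot = 0;
            for(int t = 0; t < dk; t++)
                dot += q[i * dk + t] * k[j * dk + t];
            row[j] = dot / (float)sqrt((float)dk) + mask[i * lk + j];
            if(row[j] > maxv)
                maxv = row[j];
        }
        /* numerically stable softmax over the last dimension */
        for(int j = 0; j < lk; j++){
            row[j] = (float)exp(row[j] - maxv);
            sum += row[j];
        }
        /* weighted sum of the values */
        for(int x = 0; x < dv; x++){
            float acc = 0;
            for(int j = 0; j < lk; j++)
                acc += (row[j] / sum) * v[j * dv + x];
            att[i * dv + x] = acc;
        }
    }
    delete[] row;
}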
......@@ -123,7 +123,8 @@ XTensor T2TEmbedder::Make(XTensor &input)
}
/* we make positional embeddings first */
if(!match){
//if(!match){
if(true){
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
......
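The change above forces posEmbedding to be rebuilt on every call. The diff does not show how the table itself is filled; the sketch below is the standard sinusoidal table from "Attention Is All You Need", an assumption about what the omitted code computes into posTMP:

#include <cmath>

/* sinusoidal positional encodings (assumed, not taken from this diff):
   PE(pos, 2i)   = sin(pos / 10000^(2i/d))
   PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
   pe is a row-major maxLen x d array */
void MakeSinusoidalTable(float * pe, int maxLen, int d)
{
    for(int pos = 0; pos < maxLen; pos++){
        for(int i = 0; i < d; i += 2){
            float freq = (float)pow(10000.0F, -(float)i / (float)d);
            pe[pos * d + i] = (float)sin(pos * freq);
            if(i + 1 < d)
                pe[pos * d + i + 1] = (float)cos(pos * freq);
        }
    }
}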
......@@ -55,7 +55,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "memsize", &memSize, 256);
LoadParamInt(argc, argv, "memsize", &memSize, 1024);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
......@@ -66,7 +66,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
......@@ -104,7 +104,7 @@ void T2TModel::Make(XTensor &input, XTensor &output)
/* an upper triangular matrix whose cells above the diagonal are set to -1e9.
this matrix can be used to prevent attention to the following words in
a given sequence. */
_SetDataLowTri(&mask, 1e9F, -1);
_SetDataLowTri(&mask, 1e9F, 0);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F);
encoding = MakeEncoding(input, mask, true);
......
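The fix keeps the diagonal: with offset 0, _SetDataLowTri writes 1e9 on and below the diagonal, and the shift by -1e9 then leaves 0 there and -1e9 everywhere above, so a position may attend to itself but not to its successors. The same two steps on a plain n x n array:

/* a sketch of the mask construction above on a plain array */
void MakeTriMaskSketch(float * mask, int n)
{
    /* step 1: _SetDataLowTri(&mask, 1e9F, 0) */
    for(int i = 0; i < n; i++)
        for(int j = 0; j < n; j++)
            mask[i * n + j] = (j <= i) ? 1e9F : 0.0F;
    /* step 2: _ScaleAndShiftMe(&mask, 1.0F, -1e9F) */
    for(int i = 0; i < n * n; i++)
        mask[i] = mask[i] * 1.0F - 1e9F;
}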
......@@ -43,6 +43,16 @@ T2TTrainer::~T2TTrainer()
delete[] buf;
delete[] seqLen;
delete[] seqOffset;
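/* free the Adam moment estimates created in PrepareModel */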
for(int i = 0; i < moments.count; i++){
XTensor * m = (XTensor*)moments.Get(i);
delete m;
}
for(int i = 0; i < moments2nd.count; i++){
XTensor * m = (XTensor*)moments2nd.Get(i);
delete m;
}
}
/*
......@@ -66,11 +76,18 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);
LoadParamBool(argc, argv, "adam", &useAdam, false);
LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
buf = new int[bufSize];
seqLen = new int[bufSize];
seqOffset = new int[bufSize];
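/* running products beta_1^t and beta_2^t for Adam bias correction;
   they start at 1 and are multiplied by beta_1/beta_2 once per update */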
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
}
FILE * tf = NULL;
......@@ -113,6 +130,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
CheckNTErrors(file, "cannot open training file!");
wordCount = 0;
loss = 0;
if(mem != NULL)
mem->BackToPin();
......@@ -122,8 +140,11 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
/* padding */
XTensor padding;
/* gold standard */
XTensor gold;
while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
/* output probabilities */
XTensor output;
......@@ -136,16 +157,16 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
PadOutput(&output, &padding);
/* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY);
net.Backward(output, gold, CROSSENTROPY);
/* learning rate */
lr = lrate * (1 / (float)sqrt((float)d)) * (float)MIN(pow(step + 1, -0.5F - lrbias), (step + 1) * pow(nwarmup, -1.5F - lrbias));
lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
/* update the parameters */
Update(model, lr);
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
float prob = GetProb(&output, &gold, NULL);
loss += -prob;
wordCount += wc;
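The learning rate above follows the warmup schedule of the Transformer paper, lrate * d^-0.5 * min((step+1)^-0.5, (step+1) * nwarmup^-1.5), generalized by the lrbias exponent. A standalone sketch of the same formula (lrbias = 0 recovers the standard form):

#include <cmath>

/* the "Noam" learning rate schedule computed in the training loop above */
float NoamLR(float lrate, int d, int step, int nwarmup, float lrbias)
{
    float s = (float)step + 1;
    float decay  = (float)pow(s, -0.5F - lrbias);
    float warmup = s * (float)pow((float)nwarmup, -1.5F - lrbias);
    return lrate * (1.0F / (float)sqrt((float)d)) * (decay < warmup ? decay : warmup);
}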
......@@ -158,8 +179,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f, sppl=%.3f\n",
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
}
if(mem != NULL)
......@@ -171,15 +192,139 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
if (isEnd)
break;
}
if(mem != NULL)
mem->BackToPin();
double elapsed = GetClockSec() - startT;
fclose(tf);
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
elapsed, step, epoch + 1);
}
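The reported perplexity is the exponentiated per-word negative log-likelihood: GetProb returns the summed log-probability of a batch and loss accumulates its negation, so ppl = exp(loss / wordCount) = exp((1/N) * Σ_t −log p(w_t | w_<t)).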
/*
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
{
int step = 0;
int wc = 0;
int wordCount = 0;
int wordCountTotal = 0;
bool isEnd = false;
float loss = 0;
/* data files */
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot read the test file");
FILE * ofile = fopen(ofn, "wb");
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
if(mem != NULL)
mem->SetPin();
XNet net;
tf = fopen("tmp.xx.txt", "wb");
tc = 0;
double startT = GetClockSec();
wordCount = 0;
if(mem != NULL)
mem->BackToPin();
/* batch of input sequences */
XTensor batch;
/* padding */
XTensor padding;
/* gold standard */
XTensor gold;
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
ClearBuf();
while(LoadBatch(file, true, &batch, &padding, &gold, seqs, 1, vSize, 1, 512, isLenSorted, wc, devID, mem)){
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
model->Make(batch, output);
int bSize = batch.GetDim(0);
int length = batch.GetDim(1);
/* prediction probabilities */
XTensor probs;
InitTensor1D(&probs, bSize * length);
/* get probabilities */
float prob = GetProb(&output, &gold, &probs);
/* dump the test result */
for(int s = 0; s < bSize; s++){
DTYPE sum = 0;
int * seq = seqs + s * length;
for(int i = 0; i < length; i++){
if(seq[i] >= 0){
fprintf(ofile, "%d ", seq[i]);
}
else
break;
}
fprintf(ofile, "||| ");
for(int i = 0; i < length; i++){
if(seq[i] >= 0){
DTYPE p = probs.Get1D(s * length + i);
fprintf(ofile, "%.3e ", p);
sum += p;
}
else
break;
}
fprintf(ofile, "||| %e\n", sum);
}
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
if(mem != NULL)
mem->BackToPin();
}
if(mem != NULL)
mem->BackToPin();
fclose(file);
fclose(ofile);
delete[] seqs;
double elapsed = GetClockSec() - startT;
fclose(tf);
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed, wordCountTotal, exp(loss / wordCount));
}
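For each sequence the dump loop writes the token ids, the per-token scores taken from probs, and their sum, separated by "|||". A hypothetical output line with made-up numbers, assuming GetProb fills probs with per-token log-probabilities:

12 7 944 2 ||| -2.314e+00 -1.027e+00 -3.581e-01 -8.812e-02 ||| -3.787220e+00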
char line[MAX_SEQUENCE_LENGTH];
......@@ -257,11 +402,21 @@ int T2TTrainer::LoadBuf(FILE * file)
return lineCount;
}
/* clear the data buffer */
void T2TTrainer::ClearBuf()
{
nseqBuf = 0;
nextSeq = -1;
}
/*
load a batch of sequences
>> file - the handle to the data file
>> batch - the batch
>> isLM - indicates whether the data is used for training language models
>> batch - the batch of the input sequences
>> padding - padding of the input sequences
>> output - the batch of the output sequences
>> seqs - an array for keeping the raw sequences (may be NULL)
>> step - the step size we use when moving to the next sequence
>> vs - vocabulary size
>> sBatch - batch size of sequences
......@@ -271,7 +426,9 @@ load a batch of sequences
>> devID - device id
>> mem - memory pool
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs,
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem)
......@@ -299,7 +456,10 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
wCount = 0;
nextSeq = seq + sc;
if(sc > 0){
if(sc <= 0)
return 0;
if(isLM){
int dims[MAX_TENSOR_DIM_NUM];
dims[0] = sc;
dims[1] = max;
......@@ -307,6 +467,7 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
InitTensor(output, 3, dims, X_FLOAT, 1.0F, devID, mem);
if(batch->grad == NULL)
XNoder::MakeGrad(batch);
......@@ -318,10 +479,19 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
else
InitTensor2D(padding->grad, sc, max, X_FLOAT, devID, mem);
if(output->grad == NULL)
XNoder::MakeGrad(output);
else
InitTensor(output->grad, 3, dims, X_FLOAT, 1.0F, devID, mem);
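/* the tensors are filled sparsely below (only the one-hot positions are written),
   so the data and the gradients are cleared first */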
batch->SetZeroAll();
padding->SetZeroAll();
output->SetZeroAll();
batch->grad->SetZeroAll();
padding->grad->SetZeroAll();
output->grad->SetZeroAll();
int seqSize = 0;
//fprintf(tf, "batch %d(%d)\n", tc++, sc);
......@@ -330,12 +500,23 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
padding->Set2D(1.0F, s - seq, w);
if(w > 0)
output->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
if(w == seqLen[s] - 1)
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
wCount++;
//fprintf(tf, "%d", buf[seqOffset[s] + w]);
//if(w < seqLen[s] - 1)
// fprintf(tf, " ");
//else
// fprintf(tf, "\n");
if(seqs != NULL)
seqs[seqSize++] = buf[seqOffset[s] + w];
}
if(seqs != NULL){
for(int w = seqLen[s]; w < max; w++)
seqs[seqSize++] = -1;
}
}
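The loop above builds the language-model target by shifting the input one position to the left: the gold label at position w-1 is the input token at position w, and the last position is labeled with itself, since the buffer holds no next token. On a toy sequence {A, B, C, D}, batch holds A B C D and output (gold) holds B C D D.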
......@@ -408,8 +589,38 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
CheckNTErrors(para != NULL, "NULL parameter tensor!");
CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");
/* the delta rule */
_Sum(para, paraGrad, para, -lr);
if(useAdam){
adamBeta1T *= adamBeta1;
adamBeta2T *= adamBeta2;
DTYPE e = lr * (DTYPE)sqrt(1 - adamBeta2T) / (1 - adamBeta1T);
DTYPE d = adamDelta * (DTYPE)sqrt(1 - adamBeta2T);
/* m = beta_1 * m + (1 - beta_1) * grad */
XTensor * m = (XTensor*)moments.Get(i);
_ScaleAndShiftMe(m, adamBeta1, 0);
_Sum(m, paraGrad, m, (1.0F - adamBeta1));
/* v = beta_2 * v + (1 - beta_2) * grad * grad */
XTensor * v = (XTensor*)moments2nd.Get(i);
_Multiply(paraGrad, paraGrad, v, adamBeta2/(1.0F - adamBeta2));
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* v2 = m / (sqrt(v) + d), where d = delta * sqrt(1 - beta_2^t) */
XTensor * v2 = NewTensorBuf(v, v->devID, v->mem);
_Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2);
/* the delta rule */
_Sum(para, v2, para, -e);
DelTensorBuf(v2);
}
else{
/* the delta rule */
_Sum(para, paraGrad, para, -lr);
}
/* clear gradient */
paraGrad->SetZeroAll();
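Note how the bias corrections are folded into the scalars rather than applied to m and v directly: with e = lr * sqrt(1 - beta_2^t) / (1 - beta_1^t) and d = delta * sqrt(1 - beta_2^t), the step -e * m / (sqrt(v) + d) equals the textbook update -lr * m_hat / (sqrt(v_hat) + delta) for m_hat = m / (1 - beta_1^t) and v_hat = v / (1 - beta_2^t). A scalar sketch of the same rule:

#include <cmath>

/* scalar version of the Adam step implemented above;
   beta1T and beta2T carry the running products beta1^t and beta2^t */
void AdamStep(float &para, float grad, float &m, float &v,
              float &beta1T, float &beta2T,
              float lr, float beta1, float beta2, float delta)
{
    beta1T *= beta1;
    beta2T *= beta2;
    float e = lr * (float)sqrt(1 - beta2T) / (1 - beta1T);
    float d = delta * (float)sqrt(1 - beta2T);
    m = beta1 * m + (1 - beta1) * grad;          /* 1st moment */
    v = beta2 * v + (1 - beta2) * grad * grad;   /* 2nd moment */
    para -= e * m / ((float)sqrt(v) + d);        /* the delta rule */
}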
......@@ -422,6 +633,9 @@ prepare model for training
*/
void T2TTrainer::PrepareModel(T2TModel * model)
{
moments.Clear();
moments2nd.Clear();
XList ws(100);
model->GetParams(ws);
......@@ -429,7 +643,19 @@ void T2TTrainer::PrepareModel(T2TModel * model)
for(int i = 0; i < ws.count; i++){
XTensor * para = (XTensor*)ws.Get(i);
XNoder::MakeGrad(para);
if(useAdam){
XTensor * m = new XTensor(para);
XTensor * m2 = new XTensor(para);
m->SetZeroAll();
m2->SetZeroAll();
moments.Add(m);
moments2nd.Add(m2);
}
}
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
}
/*
......
......@@ -85,6 +85,22 @@ public:
/* training step number */
int nstep;
/* indicates whether we use adam */
bool useAdam;
/* hyper-parameters of Adam */
float adamBeta1;
float adamBeta2;
float adamDelta;
float adamBeta1T;
float adamBeta2T;
/* list of the 1st-order moments of the parameter matrices */
XList moments;
/* list of the 2nd-order moments of the parameter matrices */
XList moments2nd;
public:
/* constructor */
T2TTrainer();
......@@ -98,11 +114,19 @@ public:
/* train the model */
void Train(const char * fn, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
/* clear data buffer */
void ClearBuf();
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
int LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs,
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
......
......@@ -40,20 +40,23 @@ int TransformerMain(int argc, const char ** argv)
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
char * outputFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
LoadParamString(argc, argv, "model", modelFN, "");
LoadParamString(argc, argv, "test", testFN, "");
LoadParamString(argc, argv, "output", outputFN, "");
T2TTrainer trainer;
trainer.Init(argc, argv);
T2TModel model;
model.InitModel(argc, argv);
/* learn model parameters */
if(strcmp(trainFN, "")){
T2TTrainer trainer;
trainer.Init(argc, argv);
if(strcmp(trainFN, ""))
trainer.Train(trainFN, &model);
}
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......@@ -63,9 +66,14 @@ int TransformerMain(int argc, const char ** argv)
if(strcmp(modelFN, ""))
model.Read(modelFN);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, ""))
trainer.Test(testFN, outputFN, &model);
delete[] trainFN;
delete[] modelFN;
delete[] testFN;
delete[] outputFN;
fclose(tmpFILE);
......
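A hypothetical invocation wired to the parameters read above (the binary name and flag style are assumptions; the parameter names train, model, test and output come from the code):

$ ./t2t -train train.id -model model.bin -test test.id -output result.id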
......@@ -147,6 +147,7 @@ extern bool useCUDA;
#define XPRINT4(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4);FFLUSH(FILEH);}}
#define XPRINT5(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5);FFLUSH(FILEH);}}
#define XPRINT6(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6);FFLUSH(FILEH);}}
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
......