/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
 */

#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/loss/LHeader.h"
#include "../../network/XNoder.h"

#ifndef WIN32
#include <sys/time.h>
#include <unistd.h>
#endif

namespace transformer
{

/* constructor */
T2TTrainer::T2TTrainer()
{
    argNum = 0;
    argArray = NULL;
}

/* destructor */
T2TTrainer::~T2TTrainer()
{
    for(int i = 0; i < moments.count; i++){
        XTensor * m = (XTensor*)moments.Get(i);
        delete m;
    }

    for(int i = 0; i < moments2nd.count; i++){
        XTensor * m = (XTensor*)moments2nd.Get(i);
        delete m;
    }

    for(int i = 0; i < argNum; i++)
        delete[] argArray[i];
    delete[] argArray;
}

/* 
initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, char ** argv)
{
    argNum = argc;
    argArray = new char*[argc];
    for(int i = 0; i < argNum; i++){
        argArray[i] = new char[strlen(argv[i]) + 1];
        strcpy(argArray[i], argv[i]);
    }

    LoadParamFloat(argc, argv, "lrate", &lrate, 1.0F);
    LoadParamFloat(argc, argv, "lrbias", &lrbias, 0);
    LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
    LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
    LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
    LoadParamInt(argc, argv, "nstep", &nstep, 1);
    LoadParamInt(argc, argv, "d", &d, 512);
    LoadParamInt(argc, argv, "nwarmup", &nwarmup, 4000);
    LoadParamInt(argc, argv, "vsize", &vSize, 1);
    LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
    LoadParamBool(argc, argv, "adam", &useAdam, false);
    LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
    LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.98F);
    LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-9F);
    LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
    LoadParamFloat(argc, argv, "labelsmoothing", &labelSmoothingP, 0);
    LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
    LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
    LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
    LoadParamBool(argc, argv, "debug", &isDebugged, false);
    LoadParamBool(argc, argv, "sorted", &isLenSorted, false);

    adamBeta1T = 1.0F;
    adamBeta2T = 1.0F;

    batchLoader.Init(argc, argv);
}

int tc = 0;
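/* Note on the optimizer state: adamBeta1T and adamBeta2T accumulate powers of
   beta_1 and beta_2. They start at 1.0 (see Init above and PrepareModel below)
   and are decayed in T2TTrainer::Update, where they feed the bias-corrected
   Adam step size
       e = lr * sqrt(1 - adamBeta2T) / (1 - adamBeta1T)
   so no separate step counter is needed for the bias correction. */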
/* 
train the model
>> fn - training data file
>> validFN - validation data file
>> modelFN - where we keep the model
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
{
    int step = 0;
    int wc = 0;
    int ws = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    int wordCountBatch = 0;
    bool isEnd = false;
    float loss = 0;
    float lr = 0;
    int nStepCheck = 0;
    int nCheckpoint = 0;
    int nSkipped = 0;
    int gradStep = 0;
    int validStep = 0;
    int epoch = 0;

    char * trainFN = new char[(int)strlen(fn) + 10];
    strcpy(trainFN, fn);

#ifndef WIN32
    if(isShuffled)
        sprintf(trainFN, "%s.random", fn);
#endif

    int devID = model->devID;
    XNet net;

    if(isDebugged)
        net.SetGradEfficientFlag(false);

    PrepareModel(model);

    double startT = GetClockSec();
    double mkinput = 0.0;
    double train_time = 0.0;
    double forward = 0.0;
    double backward = 0.0;
    double update = 0.0;
    double start = 0.0;
    double time = 0.0;

    for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32
        if(isShuffled)
            batchLoader.Shuffle(fn, trainFN);
#endif

        FILE * file = fopen(trainFN, "rb");
        CheckNTErrors(file, "cannot open training file!");

        wordCount = 0;
        loss = 0;

        /* batch of sequences (on the encoder and decoder sides) */
        XTensor batchEnc;
        XTensor batchDec;

        /* labels */
        XTensor label;

        /* padding */
        XTensor paddingEnc;
        XTensor paddingDec;

        /* gold standard */
        XTensor gold;

        /* label smoothed gold standard (if needed) */
        XTensor goldSmoothed;

        //while (batchLoader.LoadBatch(file, model->isLM,
        //                             &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
        //                             NULL, vSize, vSizeTgt,
        //                             sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, true))
        while (true)
        {
            start = GetClockSec();
            int batch = batchLoader.LoadBatch(file, model->isLM,
                                              &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
                                              NULL, vSize, vSizeTgt,
                                              sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, true);
            mkinput += GetClockSec() - start;

            if (!batch) {
                break;
            }

            time = GetClockSec();

            CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");

            /* output probabilities */
            XTensor output;

            start = GetClockSec();

            /* make the network */
            if(model->isLM)
                model->MakeLM(batchEnc, output, paddingEnc, true);
            else if(model->isMT)
                model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
            else{
                ShowNTErrors("Illegal model type!");
            }
            forward += GetClockSec() - start;

            /* back-propagation for obtaining gradients */
            //if (labelSmoothingP > 0)
            //    LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);

            start = GetClockSec();

            XTensor labelOnehot;
            labelOnehot = IndexToOnehot(label, vSizeTgt, labelSmoothingP);

            /* make paddings for the output */
            //if (output.GetDim(0) > 0)
            //    PadOutput(&output, &labelOnehot, &paddingDec);

            /* get probabilities */
            //float prob = GetProb(&output, &labelOnehot, NULL);
            XTensor lossTensor;
            lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
            float prob = ReduceSumAll(lossTensor);

            DTYPE lossLocal = prob / wc;
            bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);

            //XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
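            /* The batch contributes to an update only if its per-word loss is finite
               and below 1e3 (doUpdate); otherwise it is skipped and counted in nSkipped.
               Gradients are accumulated over updateStep batches, and each parameter
               update uses the inverse-square-root warm-up schedule
                   lr = lrate * d^{-0.5} * min((n+1)^{-0.5-lrbias}, (n+1) * nwarmup^{-1.5-lrbias})
               where n is the number of updates performed so far (validStep). */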
            if (doUpdate) {
                /* rescale the output for normalized loss */
                //RescaleOutput(&output, &labelOnehot, &paddingDec);

                /* back-propagation */
                net.Backward(lossTensor);
                //net.Backward(output, labelOnehot, paddingDec, CROSSENTROPY);
                //net.Backward(output, label, labelSmoothingP, CROSSENTROPY);
                backward += GetClockSec() - start;

                start = GetClockSec();
                gradStep += 1;
                loss += prob;
                wordCount += wc;
                wordCountTotal += wc;
                //totalW = wc + ws;
                wordCountBatch += ws;

                /* update the parameters */
                if(gradStep == updateStep){
                    /* learning rate */
                    lr = lrate * (1.0F / (float)sqrt((float)d)) *
                         (float)MIN(pow((float)validStep + 1, -0.5F - lrbias),
                                    ((float)validStep + 1) * pow((float)nwarmup, -1.5F - lrbias));

                    /* model update */
                    Update(model, lr);

                    gradStep = 0;
                    validStep++;
                    update += GetClockSec() - start;
                }
            }
            else
                nSkipped++;

            train_time += GetClockSec() - time;

            if(++step >= nstep){
                isEnd = true;
                break;
            }

            if (step % 100 == 0) {
                double elapsed = GetClockSec() - startT;
                startT = GetClockSec();
                XPRINT6(0, stderr, "[Time] elapsed=%.5lfs, mkinput=%.5lfs, train_time=%.5lfs, forward=%.5lfs, backward=%.5lfs, update=%.5lfs\n",
                        elapsed, mkinput, train_time, forward, backward, update);
                XPRINT8(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, tword=%d, sword=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
                        elapsed, step, epoch, wordCountTotal, wordCountBatch,
                        loss/wordCount, exp(loss/wordCount), exp(prob/wc));
                if (!doUpdate)
                    XPRINT(0, stderr, " (no update)");
                XPRINT(0, stderr, "\n");
                mkinput = 0.0;
                train_time = 0.0;
                forward = 0.0;
                backward = 0.0;
                update = 0.0;
            }

            if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
                MakeCheckpoint(model, validFN, modelFN, "step", step);
                nStepCheck = 0;
                nCheckpoint++;
            }
        }

        fclose(file);

        if (isEnd)
            break;

        if(useEpochCheckpoint)
            MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
    }

    double elapsed = GetClockSec() - startT;

    epoch = MIN(epoch, nepoch);

    XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
            lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
    XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
            elapsed, step, nSkipped, epoch);

    delete[] trainFN;
}
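/* A minimal usage sketch of this class (the file names and the model object are
   placeholders, not part of this file):

       T2TTrainer trainer;
       trainer.Init(argc, argv);                          // parses "lrate", "nepoch", ...
       trainer.Train("train.txt", "valid.txt", "model.bin", &model);
       trainer.Test("test.txt", "output.txt", &model);
*/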
/* 
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
{
    int wc = 0;
    int ws = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    int sentCount = 0;
    float loss = 0;

    /* data files */
    FILE * file = fopen(fn, "rb");
    CheckNTErrors(file, "Cannot read the test file");
    FILE * ofile = fopen(ofn, "wb");
    CheckNTErrors(ofile, "Cannot open the output file");

    int devID = model->devID;
    XNet net;

    double startT = GetClockSec();

    wordCount = 0;

    /* batch of input sequences */
    XTensor batchEnc;
    XTensor batchDec;

    /* label */
    XTensor label;

    /* padding */
    XTensor paddingEnc;
    XTensor paddingDec;

    /* gold standard */
    XTensor gold;

    /* an array that keeps the sequences */
    int * seqs = new int[MILLION];

    batchLoader.ClearBuf();

    while(batchLoader.LoadBatch(file, model->isLM,
                                &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
                                seqs, vSize, vSizeTgt,
                                1, 1, false, ws, wc, devID, false))
    {
        CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");

        /* output probabilities */
        XTensor output;

        /* make the network */
        if(model->isLM)
            model->MakeLM(batchEnc, output, paddingEnc, false);
        else if(model->isMT)
            model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, false);
        else{
            ShowNTErrors("Illegal model type!");
        }

        int bSize = output.GetDim(0);
        int length = output.GetDim(1);

        /* prediction probabilities */
        XTensor probs;
        InitTensor1DV2(&probs, bSize * length);

        XTensor labelOnehot;
        labelOnehot = IndexToOnehot(label, vSizeTgt, 0);

        /* get probabilities */
        float prob = GetProb(&output, &labelOnehot, &probs);

        /* dump the test result */
        for(int s = 0; s < bSize; s++){
            DTYPE sum = 0;
            int * seq = seqs + s * length;
            for(int i = 0; i < length; i++){
                if(seq[i] >= 0){
                    fprintf(ofile, "%d ", seq[i]);
                }
                else
                    break;
            }
            fprintf(ofile, "||| ");
            for(int i = 0; i < length; i++){
                if(seq[i] >= 0){
                    DTYPE p = probs.Get1D(s * length + i);
                    fprintf(ofile, "%.3e ", p);
                    sum += p;
                }
                else
                    break;
            }
            fprintf(ofile, "||| %e\n", sum);
        }

        loss += -prob;
        wordCount += wc;
        wordCountTotal += wc;
        sentCount += 1;
    }

    fclose(file);
    fclose(ofile);

    delete[] seqs;

    double elapsed = GetClockSec() - startT;

    XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
            elapsed, wordCountTotal, exp(loss / wordCount));
}

/* 
make a checkpoint
>> model - the model
>> validFN - validation data file
>> modelFN - model data file
>> label - label of the model
>> id - id of the checkpoint
*/
void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id)
{
    char * fn = new char[MAX_LINE_LENGTH];
    char * fn2 = new char[MAX_LINE_LENGTH];
    sprintf(fn, "%s.%s.%03d", modelFN, label, id);
    sprintf(fn2, "%s.%s.%03d.output", modelFN, label, id);

    model->Dump(fn);
    //if(validFN != NULL){
    //    T2TTrainer trainer;
    //    trainer.Init(argNum, argArray);
    //    trainer.Test(validFN, fn2, model);
    //}

    delete[] fn;
    delete[] fn2;
}

/* 
get word probabilities for a batch of sequences
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
    XTensor probs;
    InitTensor(&probs, output);

    _Multiply(output, gold, &probs);

    /* probability of each word */
    XTensor wprobs;
    InitTensor1DV2(&wprobs, output->unitNum/output->GetDim(-1), X_FLOAT, output->devID);

    int dims[2] = {output->unitNum/output->GetDim(-1), output->GetDim(-1)};
    probs.Reshape(2, dims);
    _ReduceSum(&probs, &wprobs, 1);

    if(wordProbs != NULL)
        _CopyValues(&wprobs, wordProbs);

    /* reshape the tensor to fit it into the reduce procedure
       TODO: XTensor supports scalars */
    dims[0] = 1;
    dims[1] = probs.unitNum;
    probs.Reshape(2, dims);

    /* probability for the batch */
    XTensor result;
    InitTensor1DV2(&result, 1, X_FLOAT, output->devID);
    _ReduceSum(&probs, &result, 1);

    return result.Get1D(0);
}
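/* The Adam branch of Update below implements, per parameter tensor p with gradient g:
       m = beta_1 * m + (1 - beta_1) * g
       v = beta_2 * v + (1 - beta_2) * g * g
       e = lr * sqrt(1 - adamBeta2T) / (1 - adamBeta1T)
       p = p - e * m / (sqrt(v) + adamDelta * sqrt(1 - adamBeta2T))
   Note that adamBeta1T and adamBeta2T are decayed inside the per-parameter loop,
   so the bias-correction terms advance once per parameter tensor rather than once
   per optimization step. */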
/* 
update the model by delta rule
\theta_new = \theta - \lrate * grad
where
\lrate = d^-0.5 * min(stepNum^-0.5, stepNum * warmupStepNum^-1.5)
>> model - the t2t model
>> lr - learning rate
*/
void T2TTrainer::Update(T2TModel * model, const float lr)
{
    TensorList ws(100);

    model->GetParams(ws);

    for(int i = 0; i < ws.count; i++){
        XTensor * para = (XTensor*)ws.Get(i);
        XTensor * paraGrad = para->grad;

        if (paraGrad == NULL)
            continue;

        CheckNTErrors(para != NULL, "NULL parameter tensor!");
        CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");

        if(useAdam){
            adamBeta1T *= adamBeta1;
            adamBeta2T *= adamBeta2;
            DTYPE e = lr * (DTYPE)sqrt(1 - adamBeta2T) / (1 - adamBeta1T);
            DTYPE d = adamDelta * (DTYPE)sqrt(1 - adamBeta2T);

            /* m = beta_1 * m + (1-beta_1) * grad */
            XTensor * m = (XTensor*)moments.Get(i);
            _ScaleAndShiftMe(m, adamBeta1, 0);
            _Sum(m, paraGrad, m, (1.0F - adamBeta1));

            /* v = beta_2 * v + (1-beta_2) * grad * grad */
            XTensor * v = (XTensor*)moments2nd.Get(i);
            _Multiply(paraGrad, paraGrad, v, adamBeta2/(1.0F - adamBeta2));
            _ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);

            /* v2 = m / (sqrt(v) + delta) */
            XTensor * v2 = NewTensorBufV2(v, v->devID);
            _Power(v, v2, 0.5F);
            _ScaleAndShiftMe(v2, 1.0F, d);
            _Div(m, v2, v2);

            /* the delta rule */
            _Sum(para, v2, para, -e);

            DelTensorBuf(v2);
        }
        else{
            /* the delta rule */
            _Sum(para, paraGrad, para, -lr);
        }

        /* clear gradient */
        paraGrad->SetZeroAll();
    }
}

/* 
prepare model for training
>> model - the model for training
*/
void T2TTrainer::PrepareModel(T2TModel * model)
{
    moments.Clear();
    moments2nd.Clear();

    TensorList ws(100);

    model->GetParams(ws);

    for(int i = 0; i < ws.count; i++){
        XTensor * para = (XTensor*)ws.Get(i);
        XNoder::MakeGrad(para);

        if(useAdam){
            XTensor * m = new XTensor(para);
            XTensor * m2 = new XTensor(para);
            m->SetZeroAll();
            m2->SetZeroAll();
            moments.Add(m);
            moments2nd.Add(m2);
        }
    }

    adamBeta1T = 1.0F;
    adamBeta2T = 1.0F;
}

/* 
do padding on the output
>> output - output tensor of the network
>> gold - gold standard
>> padding - padding of a batch of sentences
*/
void T2TTrainer::PadOutput(XTensor * output, XTensor * gold, XTensor * padding)
{
    if(output == NULL || padding == NULL)
        return;

    int on = output->order;
    int * dimso = new int[on];

    memcpy(dimso, output->dimSize, sizeof(int) * on);

    output->Reshape(output->unitNum/dimso[output->order - 1], dimso[output->order - 1]);

    XTensor * padding2 = NewTensorBufV2(1, &padding->unitNum, X_FLOAT, padding->devID);

    _CopyValues(padding, padding2);
    _MultiplyDim(output, padding2, output, 0);
    _ScaleAndShiftMe(padding2, 1e9F, -1e9F);
    _SumDim(output, padding2, output, 0);

    output->Reshape(on, dimso);

    if(gold != NULL){
        gold->Reshape(gold->unitNum/dimso[gold->order - 1], dimso[gold->order - 1]);
        _CopyValues(padding, padding2);
        _MultiplyDim(gold, padding2, gold, 0);
        gold->Reshape(on, dimso);
    }

    delete[] dimso;
    DelTensorBuf(padding2);
}

/* 
rescale the output and gold tensors for normalized loss
>> output - output tensor of the network
>> gold - gold standard
>> padding - padding of a batch of sentences
*/
void T2TTrainer::RescaleOutput(XTensor * output, XTensor * gold, XTensor * padding)
{
    CheckNTErrors(output->order == 3, "Wrong dimension number!");
    CheckNTErrors(gold->order == 3, "Wrong dimension number!");

    DTYPE count = _ReduceSumAll(padding);

    _ExpMe(output);
    _ScaleAndShiftMe(output, 1/count);
    _LogMe(output);

    _ScaleAndShiftMe(gold, 1/count);
}

/* 
perform label smoothing
>> gold - gold standard
>> smoothed - result of label smoothing
>> p - smoothing factor
*/
void T2TTrainer::LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE p)
{
    CheckNTErrors(p >= 0 && p <= 1.0F, "Smoothing factor must be in range [0,1]");

    int n = gold->GetDim(-1);
    DTYPE q = 1.0F - p;
    DTYPE gift = p / n;

    InitTensorV2(smoothed, gold);
    _CopyValues(gold, smoothed);

    if(p == 0)
        return;

    _ScaleAndShiftMe(smoothed, q, gift);
}

}