/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
 */

#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"

namespace transformer
{

/* 
constructor
All buffer pointers are set to NULL here so that the de-constructor can
safely call delete[] on them even if Init() has never been called.
*/
T2TTrainer::T2TTrainer()
{
    buf = NULL;
    seqLen = NULL;
    seqOffset = NULL;
    nseqBuf = 0;
    nextSeq = -1;
}

/* 
de-constructor
It releases the data buffers and the tensors kept in the two moment lists.
*/
T2TTrainer::~T2TTrainer()
{
    /* buffers of words, sequence lengths and sequence offsets */
    delete[] buf;
    delete[] seqLen;
    delete[] seqOffset;

    /* tensors kept in "moments" */
    for(int k = 0; k < moments.count; k++)
        delete (XTensor*)moments.Get(k);

    /* tensors kept in "moments2nd" */
    for(int k = 0; k < moments2nd.count; k++)
        delete (XTensor*)moments2nd.Get(k);
}

/* 
initialization
It reads the options of the trainer from the arguments, allocates the
data buffers, and resets the accumulated Adam coefficients.
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, const char ** argv)
{
    /* value of the "mem" option (it is loaded here but not used in this function) */
    bool memOption = false;
    LoadParamBool(argc, argv, "mem", &memOption, memOption);

    /* learning rate and its bias */
    LoadParamFloat(argc, argv, "lrate", &lrate, 1.0F);
    LoadParamFloat(argc, argv, "lrbias", &lrbias, 0);

    /* batch size (in sequences and in words) */
    LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
    LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);

    /* when to stop */
    LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
    LoadParamInt(argc, argv, "nstep", &nstep, 1);

    /* "d" and "nwarmup" are used in computing the learning rate (see Train),
       "vsize" is passed to LoadBatch as the vocabulary size */
    LoadParamInt(argc, argv, "d", &d, 512);
    LoadParamInt(argc, argv, "nwarmup", &nwarmup, 4000);
    LoadParamInt(argc, argv, "vsize", &vSize, 1);

    /* data loading */
    LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
    LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);

    /* Adam */
    LoadParamBool(argc, argv, "adam", &useAdam, false);
    LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
    LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
    LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);

    /* accumulated products of the Adam betas (updated in Update) */
    adamBeta1T = 1.0F;
    adamBeta2T = 1.0F;

    /* the three buffers have the same size (bufSize) */
    buf = new int[bufSize];
    seqLen = new int[bufSize];
    seqOffset = new int[bufSize];
}

/* a counter for debugging; it is reset to 0 at the beginning of Train() and Test() */
int tc = 0;

/* 
train the model
>> fn - training data file
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
    int epoch = 0;
    int step = 0;
    int wc = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    bool isEnd = false;
    float loss = 0;
    float lr = 0;

    PrepareModel(model);

    int devID = model->devID;
    XMem * mem = model->mem;

    if(mem != NULL && mem->mode == UNI_FREE)
        mem->SetPin();

    XNet net;

    tf = fopen("tmp.xx.txt", "wb");
    tc = 0;
    
    double startT = GetClockSec();
    
    for(epoch = 1; epoch <= nepoch; epoch++){
        
        FILE * file = fopen(fn, "rb");
        CheckNTErrors(file, "cannot open training file!");
        
        wordCount = 0;
        loss = 0;

        if(mem != NULL)
            mem->BackToPin();
        
        /* batch of input sequences */
        XTensor batch;

        /* padding */
        XTensor padding;

        /* gold standard */
        XTensor gold;
        
        while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
            
            /* output probabilities */
            XTensor output;
            
            /* make the network */
            model->Make(batch, output, padding);

            /* make paddings for the output */
            if(output.GetDim(0) > 1)
                PadOutput(&output, &padding);

            /* back-propagation for obtaining gradients */
            net.Backward(output, gold, CROSSENTROPY);
            
            /* learning rate */
            lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
            
            /* update the parameters */
            Update(model, lr);
            
            /* get probabilities */
            float prob = GetProb(&output, &gold, NULL);

            MTYPE totalUsed = 0;
            MTYPE totalSize = 0;
            
            for (int i = 0; i <= mem->curBlockID; i++) {
                totalSize += mem->blocks[i].size;
                totalUsed += mem->blocks[i].used;
            }

            //fprintf(stderr, "%d(%ld,%ld,%f)\n", mem->curBlockID, totalUsed, totalSize, (float)totalUsed/totalSize);
            
            loss += -prob;
            wordCount += wc;
            wordCountTotal += wc;
            
            if(++step >= nstep){
                isEnd = true;
                break;
            }
            
            if (step % 1 == 0) {
                double elapsed = GetClockSec() - startT;
                XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f, sppl=%.3f\n",
                        lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
            }

            if(mem != NULL && mem->mode == UNI_FREE)
                mem->BackToPin();
        }
        
        fclose(file);

        if (isEnd)
            break;
    }

    if(mem != NULL && mem->mode == UNI_FREE)
        mem->BackToPin();
    
    double elapsed = GetClockSec() - startT;

    fclose(tf);
    
    XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
            lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
    XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
            elapsed, step, epoch);
}

/* 
test the model
The test file is read sequence by sequence. For each sequence we run the
network and write a line to the output file in the form of
"word ids ||| score of each word ||| sum of the scores",
where the scores are those returned by GetProb. The perplexity over the
whole file is reported at the end.
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
{
    int wc = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    float loss = 0;

    /* data files */
    FILE * file = fopen(fn, "rb");
    CheckNTErrors(file, "Cannot read the test file");
    FILE * ofile = fopen(ofn, "wb");
    CheckNTErrors(ofile, "Cannot open the output file");

    int devID = model->devID;
    XMem * mem = model->mem;

    XNet net;

    /* file for debugging output (tf is NULL if the file cannot be opened) */
    tf = fopen("tmp.xx.txt", "wb");
    tc = 0;
    
    double startT = GetClockSec();

    if(mem != NULL && mem->mode == UNI_FREE)
        mem->BackToPin();
        
    /* batch of input sequences */
    XTensor batch;

    /* padding */
    XTensor padding;

    /* gold standard */
    XTensor gold;

    /* an array that keeps the sequences */
    int * seqs = new int[MILLION];
    
    /* drop what is left in the buffer so that we start from the test file */
    ClearBuf();

    while(LoadBatch(file, true, &batch, &padding, &gold, seqs, 1, vSize, 1, 1, isLenSorted, wc, devID, mem)){

        CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
            
        /* output probabilities */
        XTensor output;
            
        /* make the network */
        model->Make(batch, output, padding);

        int bSize = batch.GetDim(0);
        int length = batch.GetDim(1);

        /* prediction probabilities */
        XTensor probs;
        InitTensor1D(&probs, bSize * length);

        /* get probabilities */
        float prob = GetProb(&output, &gold, &probs);

        /* dump the test result. Each sequence takes "length" items in seqs
           and the positions after its last word are filled with -1 */
        for(int s = 0; s < bSize; s++){
            DTYPE sum = 0;
            int * seq = seqs + s * length;
            for(int i = 0; i < length; i++){
                if(seq[i] >= 0){
                    fprintf(ofile, "%d ", seq[i]);
                }
                else
                    break;
            }
            fprintf(ofile, "||| ");
            for(int i = 0; i < length; i++){
                if(seq[i] >= 0){
                    DTYPE p = probs.Get1D(s * length + i);
                    fprintf(ofile, "%.3e ", p);
                    sum += p;
                }
                else
                    break;
            }
            fprintf(ofile, "||| %e\n", sum);
        }
            
        loss += -prob;
        wordCount += wc;
        wordCountTotal += wc;
            
        if(mem != NULL && mem->mode == UNI_FREE)
            mem->BackToPin();
    }

    if(mem != NULL && mem->mode == UNI_FREE)
        mem->BackToPin();
        
    fclose(file);
    fclose(ofile);

    delete[] seqs;
    
    double elapsed = GetClockSec() - startT;

    /* fclose() must not be called on a NULL pointer */
    if(tf != NULL){
        fclose(tf);
        tf = NULL;
    }
    
    XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
            elapsed,wordCountTotal, exp(loss / wordCount));
}

/* buffer that keeps a line of the data file (filled by fgets in LoadBuf) */
char line[MAX_SEQUENCE_LENGTH];

/* 
load data to buffer 
Each line of the file is a list of word ids separated by spaces or tabs.
A line can keep several sequences that are separated by the token "|||".
The word ids are put in "buf". The length and the offset (in "buf") of
every sequence are put in "seqLen" and "seqOffset". Lines that are empty
after removing the line break are skipped. Loading stops at the end of the
file or when the buffers are nearly full.
>> file - where to load data
<< return - number of the lines that are loaded
*/
int T2TTrainer::LoadBuf(FILE * file)
{
    int lineCount = 0;
    int seqCount = 0;
    int wordCount = 0;
    while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
        int len = (int)strlen(line);

        /* remove the line break. Note that len > 0 is checked first. Otherwise
           line[-1] is read when the line consists of line breaks only */
        while(len > 0 && (line[len - 1] == '\r' || line[len - 1] == '\n')){
            line[len - 1] = 0;
            len--;
        }

        if(len == 0)
            continue;
        
        /* how many characters are in a word */
        int wSize = 0;
        
        /* how many words are in the line */
        int wNum = 0;

        /* how many words are in the current sequence */
        int wNumLocal = 0;
        int i = 0;

        for(i = 0; i < len; i++){
            /* load word (id) separated by space or tab */
            if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
                line[i] = 0;

                /* "|||" ends the current sequence */
                if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
                    seqLen[seqCount] = wNumLocal;
                    seqOffset[seqCount] = wordCount + wNum - wNumLocal;
                    seqCount++;
                    wNumLocal = 0;
                }
                else{
                    buf[wordCount + wNum++] = atoi(line + i - wSize);
                    wNumLocal++;
                }

                wSize = 0;
            }
            else
                wSize++;
        }

        /* the last word of the line */
        if(wSize > 0){
            buf[wordCount + wNum++] = atoi(line + i - wSize);
            wNumLocal++;
        }

        /* the last sequence of the line */
        seqLen[seqCount] = wNumLocal;
        seqOffset[seqCount] = wordCount + wNum - wNumLocal;
        seqCount++;

        wordCount += wNum;
        lineCount++;

        /* stop when there might be no room for another line. seqLen and
           seqOffset have bufSize items too, so the number of sequences is
           checked in the same way (it can exceed the number of words when
           there are sequences of no words) */
        if(wordCount >= bufSize - MAX_SEQUENCE_LENGTH ||
           seqCount >= bufSize - MAX_SEQUENCE_LENGTH)
            break;
    }

    nseqBuf = seqCount;
    nextSeq = 0;

    return lineCount;
}

/* 
clear the data buffer
A negative nextSeq makes the next call of LoadBatch load the buffer again.
*/
void T2TTrainer::ClearBuf()
{
    nextSeq = -1;
    nseqBuf = 0;
}

/* 
load a batch of sequences 
The sequences are taken from the buffer (which is reloaded from the file when
it is used up) until we have at least sBatch sequences and at least wBatch
words, or no sequences are left in the buffer. The tensors are built only if
isLM is true.
>> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms
>> batch - the batch of the input sequences
>> padding - padding of the input sequences
>> output - the batch of the output sequences
>> seqs - keep the sequences in an array
>> step - the step we go over when move to the next sequence
          (not used in this function)
>> vs - vocabulary size
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
              (not used in this function)
>> wCount - word count
>> devID - device id
>> mem - memory pool
<< return - number of the sequences in the batch (0 means no more data)
*/
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
                          XTensor * batch, XTensor * padding, XTensor * output, 
                          int * seqs,
                          int step, int vs, int sBatch, int wBatch, 
                          bool isSorted, int &wCount,
                          int devID, XMem * mem)
{
    /* reload the buffer if nothing is loaded or all sequences are used */
    bool isBufUsedUp = (nextSeq < 0) || (nextSeq >= nseqBuf);
    if(isBufUsedUp)
        LoadBuf(file);

    /* index (in the buffer) of the first sequence of the batch */
    const int first = MAX(nextSeq, 0);

    /* number of sequences, number of words and maximum sequence length */
    int nSeq = 0;
    int nWord = 0;
    int maxLen = 0;

    for(int s = first; s < nseqBuf; s++){
        const int len = seqLen[s];

        nWord += len;
        nSeq++;

        if(len > maxLen)
            maxLen = len;

        if(nSeq >= sBatch && nWord >= wBatch)
            break;
    }

    wCount = 0;
    nextSeq = first + nSeq;

    if(nSeq <= 0)
        return 0;

    /* nothing is built for the non-lm case */
    if(!isLM)
        return nSeq;

    int dims[MAX_TENSOR_DIM_NUM];
    dims[0] = nSeq;
    dims[1] = maxLen;
    dims[2] = vs;

    InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
    InitTensor2D(padding, nSeq, maxLen, X_FLOAT, devID, mem);
    InitTensor(output, 3, dims, X_FLOAT, 1.0F, devID, mem);

    /* gradient tensors: re-initialize them if they exist and make them if not */
    if(batch->grad != NULL)
        InitTensor(batch->grad, 3, dims, X_FLOAT, 1.0F, devID, mem);
    else
        XNoder::MakeGrad(batch);

    if(padding->grad != NULL)
        InitTensor2D(padding->grad, nSeq, maxLen, X_FLOAT, devID, mem);
    else
        XNoder::MakeGrad(padding);

    if(output->grad != NULL)
        InitTensor(output->grad, 3, dims, X_FLOAT, 1.0F, devID, mem);
    else
        XNoder::MakeGrad(output);

    batch->SetZeroAll();
    padding->SetZeroAll();
    output->SetZeroAll();
    batch->grad->SetZeroAll();
    padding->grad->SetZeroAll();
    output->grad->SetZeroAll();

    /* number of the items written to seqs */
    int seqSize = 0;

    //fprintf(tf, "batch %d(%d)\n", tc++, nSeq);

    /* this might be slow on GPUs :( */
    for(int r = 0; r < nSeq; r++){
        const int len = seqLen[first + r];
        const int * words = buf + seqOffset[first + r];

        for(int w = 0; w < len; w++){
            const int id = words[w];

            /* input: word w at position w */
            batch->Set3D(1.0F, r, w, id);
            padding->Set2D(1.0F, r, w);

            /* output: word w at position w - 1. The last word is put at
               the last position as well */
            if(w > 0)
                output->Set3D(1.0F, r, w - 1, id);
            if(w == len - 1)
                output->Set3D(1.0F, r, w, id);

            wCount++;

            /*fprintf(tf, "%d", id);
            if(w < len - 1)
                fprintf(tf, " ");
            else
                fprintf(tf, "\n");*/

            if(seqs != NULL)
                seqs[seqSize++] = id;
        }

        /* every sequence takes maxLen items in seqs. -1 is used for padding */
        if(seqs != NULL){
            for(int w = len; w < maxLen; w++)
                seqs[seqSize++] = -1;
        }
    }

    fflush(tf);

    return nSeq;
}
    
/*
get word probabilities for a batch of sequences
We multiply the output with the gold standard (element-wise), sum the result
over the last dimension for each position, and then sum over all positions.
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
               (it is not filled if it is NULL)
<< return - sum of all items of output * gold
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
    /* size of the last dimension and number of the remaining items */
    const int lastDim = output->GetDim(-1);
    const int rowNum = output->unitNum / lastDim;

    /* prod[i,j] = output[i,j] * gold[i,j] */
    XTensor prod;
    InitTensor(&prod, output);
    _Multiply(output, gold, &prod);

    /* sum over the last dimension, i.e., one value for each position */
    XTensor rowSum;
    InitTensor1D(&rowSum, rowNum, X_FLOAT, output->devID, output->mem);

    int shape[2] = {rowNum, lastDim};
    prod.Reshape(2, shape);
    _ReduceSum(&prod, &rowSum, 1);

    if(wordProbs != NULL)
        _CopyValues(&rowSum, wordProbs);

    /* make the tensor a row of all items to fit it into the reduce procedure
       TODO: XTensor supports scalars */
    shape[0] = 1;
    shape[1] = prod.unitNum;
    prod.Reshape(2, shape);

    /* sum over all items */
    XTensor total;
    InitTensor1D(&total, 1, X_FLOAT, output->devID, output->mem);
    _ReduceSum(&prod, &total, 1);

    return total.Get1D(0);
}

/* 
update the model by delta rule
\theta_new = \theta - \lrate * grad
where
\lrate = d^-0.5 * min(stepNum^-0.5, stepNum * warmupStepNum^-1.5)
If Adam is used, the gradient is replaced by m / (sqrt(v) + delta) where m and v
are the running averages of the gradient and the squared gradient, and the
learning rate is re-scaled by sqrt(1 - beta_2^t) / (1 - beta_1^t).
Parameters that have no gradient are skipped. The gradient is cleared after
the update.
>> model - the t2t model
>> lr - learning rate
*/
void T2TTrainer::Update(T2TModel * model, const float lr)
{
    XList ws(100);

    model->GetParams(ws);

    /* step size and delta of Adam. beta_1^t and beta_2^t are advanced once for
       each call (i.e., each training step) but not once for each parameter, so
       that t is the number of the steps and all parameters share the same
       coefficients in a step */
    DTYPE adamStep = 0;
    DTYPE adamEps = 0;
    if(useAdam){
        adamBeta1T *= adamBeta1;
        adamBeta2T *= adamBeta2;
        adamStep = lr * (DTYPE)sqrt(1 - adamBeta2T) / (1 - adamBeta1T);
        adamEps = adamDelta * (DTYPE)sqrt(1 - adamBeta2T);
    }

    for(int i = 0; i < ws.count; i++){
        XTensor * para = (XTensor*)ws.Get(i);

        /* para must be checked before we access its gradient */
        if(para == NULL)
            continue;

        XTensor * paraGrad = para->grad;

        if(paraGrad == NULL)
            continue;

        if(useAdam){
            /* m = beta_1 * m + (1-beta_1) * grad */
            XTensor * m = (XTensor*)moments.Get(i);
            _ScaleAndShiftMe(m, adamBeta1, 0);
            _Sum(m, paraGrad, m, (1.0F - adamBeta1));
            
            /* v = beta_2 * v + (1-beta_2) * grad * grad 
               It is done in two steps: 
               v = grad * grad + beta_2/(1-beta_2) * v and then v = (1-beta_2) * v */
            XTensor * v = (XTensor*)moments2nd.Get(i);
            _Multiply(paraGrad, paraGrad, v, adamBeta2/(1.0F - adamBeta2));
            _ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);

            /* v2 = m / (sqrt(v) + delta) */
            XTensor * v2 = NewTensorBuf(v, v->devID, v->mem);
            _Power(v, v2, 0.5F);
            _ScaleAndShiftMe(v2, 1.0F, adamEps);
            _Div(m, v2, v2);

            /* the delta rule */
            _Sum(para, v2, para, -adamStep);

            DelTensorBuf(v2);
        }
        else{
            /* the delta rule */
            _Sum(para, paraGrad, para, -lr);
        }

        /* clear gradient */
        paraGrad->SetZeroAll();
    }
}

/* 
prepare model for training 
It makes the gradient tensor for every parameter of the model. If Adam is
used, it also creates two zero tensors (the 1st and 2nd moments) for every
parameter. The moments created by a previous call are released first.
>> model - the model for training
*/
void T2TTrainer::PrepareModel(T2TModel * model)
{
    /* the tensors in the lists are owned by the trainer (see the de-constructor).
       We delete them before clearing the lists. Otherwise they are leaked
       when this function is called more than once */
    for(int i = 0; i < moments.count; i++)
        delete (XTensor*)moments.Get(i);

    for(int i = 0; i < moments2nd.count; i++)
        delete (XTensor*)moments2nd.Get(i);

    moments.Clear();
    moments2nd.Clear();

    XList ws(100);

    model->GetParams(ws);

    for(int i = 0; i < ws.count; i++){
        XTensor * para = (XTensor*)ws.Get(i);
        XNoder::MakeGrad(para);

        if(useAdam){
            /* the i-th item of each list is for the i-th parameter */
            XTensor * m = new XTensor(para);
            XTensor * m2 = new XTensor(para);
            m->SetZeroAll();
            m2->SetZeroAll();
            moments.Add(m);
            moments2nd.Add(m2);
        }
    }

    /* restart the accumulation of beta_1^t and beta_2^t */
    adamBeta1T = 1.0F;
    adamBeta2T = 1.0F;
}

/* 
do padding on the output 
The padding is scaled and shifted by padding * 1e9 - 1e9 and then added to
the output along dimension 0 of the output that is reshaped to a matrix.
Nothing is done if either tensor is NULL.
>> output - output tensor of the network
>> padding - padding of a batch of sentences
*/
void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
{
    if(output == NULL || padding == NULL)
        return;
    
    /* keep the shape of the output so that we can restore it */
    const int rank = output->order;
    int * shape = new int[rank];

    for(int k = 0; k < rank; k++)
        shape[k] = output->dimSize[k];

    /* view the output as a matrix whose column number is the size of the last dimension */
    const int lastDim = shape[rank - 1];
    output->Reshape(output->unitNum / lastDim, lastDim);

    /* a vector that keeps all items of the padding */
    XTensor * mask = NewTensorBuf(1, &padding->unitNum, X_FLOAT, 1.0F, padding->devID, padding->mem);

    _CopyValues(padding, mask);

    /* mask = mask * 1e9 - 1e9 */
    _ScaleAndShiftMe(mask, 1e9F, -1e9F);

    _SumDim(output, mask, output, 0);

    /* go back to the original shape */
    output->Reshape(rank, shape);

    delete[] shape;
    DelTensorBuf(mask);
}

}
