/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
 */

#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"

namespace transformer
{

/* 
constructor
Every buffer pointer starts out as NULL: the de-constructor calls delete[]
on buf, seqLen and seqOffset, which is only well-defined if they are either
NULL or were allocated by Init().
*/
T2TTrainer::T2TTrainer()
{
    devID = -1;
    mem = NULL;
    buf = NULL;
    seqLen = NULL;
    seqOffset = NULL;
    nseqBuf = 0;
    nextSeq = -1;
}

/* 
de-constructor
releases the three arrays that Init() allocates: the per-sequence
offset and length tables and the word buffer
*/
T2TTrainer::~T2TTrainer()
{
    delete[] seqOffset;
    delete[] seqLen;
    delete[] buf;
}

/* 
initialization 
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, const char ** argv)
{
    LoadParamInt(argc, argv, "dev", &devID, -1);
    LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
    LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
    LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
    LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
    LoadParamInt(argc, argv, "nstep", &nstep, 1);
    LoadParamInt(argc, argv, "d", &d, 512);
    LoadParamInt(argc, argv, "nwarmup", &nwarmup, 4000);
    LoadParamInt(argc, argv, "vsize", &vSize, 1);
    LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
    LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);

    buf = new int[bufSize];
    seqLen = new int[bufSize];
    seqOffset = new int[bufSize];
}

/* 
train the model
The training file is re-opened for every epoch and read batch by batch.
For each batch the network is built, gradients are obtained by
back-propagation, and the parameters are updated with a learning rate
    lr = d^-0.5 * min(step^-0.5, step * nwarmup^-1.5)
where step counts from 1. Training stops after nepoch epochs or as soon as
nstep updates have been made, whichever comes first.
>> fn - training data file
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
    int epoch = 0;
    int step = 0;
    int wc = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    bool isEnd = false;
    float loss = 0;
    float lr = 0;

    XNet net;

    double startT = GetClockSec();

    for(epoch = 0; epoch < nepoch; epoch++){

        FILE * file = fopen(fn, "rb");
        CheckNTErrors(file, "cannot open training file!");

        /* the perplexity is exp(loss / wordCount), so both statistics have
           to cover the same span of data: reset them together per epoch */
        wordCount = 0;
        loss = 0;

        /* batch of input sequences */
        XTensor batch;

        while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){

            /* output probabilities */
            XTensor output;

            /* make the network */
            model->Make(batch, output);

            /* back-propagation for obtaining gradients */
            net.Backward(output, batch, CROSSENTROPY);

            /* learning rate */
            lr = (1 / (float)sqrt((float)d)) * (float)MIN(pow(step + 1, -0.5), (step + 1) * pow(nwarmup, -1.5));

            /* update the parameters */
            Update(model, lr);

            /* get probabilities */
            float prob = GetProb(&output, &batch, NULL);

            loss += -prob;
            wordCount += wc;
            wordCountTotal += wc;

            if(++step >= nstep){
                isEnd = true;
                break;
            }

            /* report after every step (raise the modulus to report less often) */
            if (step % 1 == 0) {
                double elapsed = GetClockSec() - startT;
                XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
                        lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
            }
        }

        fclose(file);

        if (isEnd)
            break;
    }

    double elapsed = GetClockSec() - startT;

    /* when training stops inside an epoch the loop counter has not been
       advanced yet, so the current epoch is counted as well; this keeps the
       final report consistent with the "epoch + 1" of the in-loop report */
    int epochReported = isEnd ? epoch + 1 : epoch;

    XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
            lr, elapsed, step, epochReported, wordCountTotal, exp(loss / wordCount));
    XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
            elapsed, step, epochReported);
}

char line[MAX_SEQUENCE_LENGTH];

/* 
load data to buffer 
>> file - where to load data
*/
int T2TTrainer::LoadBuf(FILE * file)
{
    int lineCount = 0;
    int seqCount = 0;
    int wordCount = 0;
    while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
        int len = (int)strlen(line);

        while(line[len - 1] == '\r' || line[len - 1] == '\n'){
            line[len - 1] = 0;
            len--;
        }

        len = (int)strlen(line);
        if(len == 0)
            continue;
        
        /* how many characters are in a word */
        int wSize = 0;
        
        /* how many words are in the sentence */
        int wNum = 0;
        int wNumLocal = 0;
        int i = 0;

        for(i = 0; i < len; i++){
            /* load word (id) seperated by space or tab */
            if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
                line[i] = 0;

                if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
                    seqLen[seqCount] = wNumLocal;
                    seqOffset[seqCount] = wordCount + wNum - wNumLocal;
                    seqCount++;
                    wNumLocal = 0;
                }
                else{
                    buf[wordCount + wNum++] = atoi(line + i - wSize);
                    wNumLocal++;
                }

                wSize = 0;
            }
            else
                wSize++;
        }

        if(wSize > 0){
            buf[wordCount + wNum++] = atoi(line + i - wSize);
            wNumLocal++;
        }

        seqLen[seqCount] = wNumLocal;
        seqOffset[seqCount] = wordCount + wNum - wNumLocal;
        seqCount++;

        wordCount += wNum;
        lineCount++;

        if(wordCount >= bufSize - MAX_SEQUENCE_LENGTH)
            break;
    }

    nseqBuf = seqCount;
    nextSeq = 0;

    return lineCount;
}

/* 
load a batch of sequences 
>> file - the handle to the data file
>> batch - the batch
>> step - the step we go over when move to the next sequence
>> vs - vocabulary size
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
{
    if(nextSeq < 0 || nextSeq >= nseqBuf)
        LoadBuf(file);

    int seq = MAX(nextSeq, 0);
    int wc = 0;
    int wn = 0;
    int sc = 0;
    int max = 0;
    while(seq + sc < nseqBuf){
        wn = seqLen[seq + sc];
        wc += wn;
        sc += 1;

        if(max < wn)
            max = wn;

        if(sc >= sBatch && wc >= wBatch)
            break;
    }

    wCount = 0;
    nextSeq = seq + sc;

    if(sc > 0){
        int dims[MAX_TENSOR_DIM_NUM];
        dims[0] = sc;
        dims[1] = max;
        dims[2] = vs;

        if(batch->order != 3 || batch->GetDim(0) != dims[0] || 
           batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
               InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
        }

        batch->SetZeroAll();

        /* this might be slow on GPUs :( */
        for(int s = seq; s < seq + sc; s++){
            for(int w = 0; w < seqLen[s]; w++){
                batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
                wCount++;
            }
        }
    }

    return sc;
}
    
/*
get word probabilities for a batch of sequences
The output is multiplied element-wise by the gold standard and everything
is summed up. In Train() the gold standard is the one-hot batch built by
LoadBatch(), so only the scores of the gold words survive the product.
NOTE: Train() takes exp(-sum / word count) as the perplexity, so the output
is presumably a log-scale distribution (TODO confirm against the model).
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
               (can be NULL if it is not needed)
<< return - sum over all entries of output * gold
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
    XTensor probs;
    InitTensor(&probs, output);

    /* probs[i,j] = output[i,j] * gold[i,j] */
    _Multiply(output, gold, &probs);

    int dims[2];

    /* probability of each word. This costs an extra tensor and an extra
       reduction, so it is computed on request only */
    if(wordProbs != NULL){
        int vocabSize = output->GetDim(-1);
        int wordNum = output->unitNum / vocabSize;

        XTensor wprobs;
        InitTensor1D(&wprobs, wordNum, X_FLOAT, output->devID, output->mem);

        /* one row for each word so that the sum runs over the vocabulary */
        dims[0] = wordNum;
        dims[1] = vocabSize;
        probs.Reshape(2, dims);
        _ReduceSum(&probs, &wprobs, 1);

        _CopyValues(&wprobs, wordProbs);
    }

    /* reshape the tensor to fit it into the reduce procedure
     TODO: XTensor supports scalars */
    dims[0] = 1;
    dims[1] = probs.unitNum;
    probs.Reshape(2, dims);

    /* probability for the batch */
    XTensor result;
    InitTensor1D(&result, 1, X_FLOAT, output->devID, output->mem);
    _ReduceSum(&probs, &result, 1);

    return result.Get1D(0);
}

/* 
update the model by delta rule
\theta_new = \theta - \lrate * grad
The learning rate is used as given. In Train() it is computed as
\lrate = d^-0.5 * min(stepNum^-0.5, stepNum * warmupStepNum^-1.5)
The parameters updated here are the output layer weight, the embedding
weight and, for every encoder layer, the feed-forward weights and biases,
the attention key/query/value weights and the two layer normalizations.
A parameter that has no gradient tensor is left untouched. The gradient of
every updated parameter is reset to zero afterwards.
>> model - the t2t model
>> lr - learning rate
*/
void T2TTrainer::Update(T2TModel * model, const float lr)
{
    XList ws(100);

    ws.Add(&model->outputLayer.w);

    for(int i = 0; i < model->encoder.nlayer; i++){
        ws.Add(&model->encoder.fnns[i].w1);
        ws.Add(&model->encoder.fnns[i].b1);
        ws.Add(&model->encoder.fnns[i].w2);
        ws.Add(&model->encoder.fnns[i].b2);
        ws.Add(&model->encoder.attentions[i].wk);
        ws.Add(&model->encoder.attentions[i].wq);
        ws.Add(&model->encoder.attentions[i].wv);
        ws.Add(&model->encoder.fnnLayerNorms[i].w);
        ws.Add(&model->encoder.fnnLayerNorms[i].b);
        ws.Add(&model->encoder.attLayerNorms[i].w);
        ws.Add(&model->encoder.attLayerNorms[i].b);
    }

    ws.Add(&model->encoder.embedder.w);

    for(int i = 0; i < ws.count; i++){
        XTensor * para = (XTensor*)ws.Get(i);

        /* the parameter has to be checked before its gradient is read */
        if(para == NULL || para->grad == NULL)
            continue;

        XTensor * paraGrad = para->grad;

        /* the delta rule */
        _Sum(para, paraGrad, para, -lr);

        /* clear gradient */
        paraGrad->SetZeroAll();
    }
}

}
