/* NiuTrans.NMT - an open-source neural machine translation system.
 * Copyright (C) 2020 NiuTrans Research. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-10-09
 * $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
 */

#include "Decoder.h"
#include "Utility.h"
#include "submodel/LayerNorm.h"
#include "submodel/CommonModules.h"
#include "../../tensor/core/CHeader.h"

namespace nmt
{

/* constructor */
AttDecoder::AttDecoder()
{
    selfAtt = NULL;
    fnns = NULL;
    selfAttLayerNorms = NULL;
    fnnLayerNorms = NULL;
    enDeAtt = NULL;
    enDeAttLayerNorms = NULL;
    decoderLayerNorm = NULL;
    selfAttCache = NULL;
    enDeAttCache = NULL;
}

/* de-constructor */
AttDecoder::~AttDecoder()
{
    delete[] selfAttCache;
    delete[] enDeAttCache;
    delete[] selfAtt;
    delete[] fnns;
    delete[] selfAttLayerNorms;
    delete[] fnnLayerNorms;
    delete[] enDeAtt;
    delete[] enDeAttLayerNorms;
    if (preNorm)
        delete decoderLayerNorm;
}

/*
initialize the model
>> config - configurations of the model
*/
void AttDecoder::InitModel(Config& config)
{
    devID = config.devID;
    nlayer = config.nDecLayer;
    hSize = config.modelSize;
    eSize = config.embSize;
    vSize = config.tgtVocabSize;
    dropoutP = config.dropout;
    preNorm = config.preNorm;

    CheckNTErrors(nlayer >= 1, "We have one decoding layer at least!");
    CheckNTErrors(vSize > 1, "Set the target vocabulary size by \"-vsizetgt\"");

    /* embedding model */
    embedder.InitModel(config, false);

    selfAtt = new Attention[nlayer];
    fnns = new FNN[nlayer];
    selfAttLayerNorms = new LN[nlayer];
    enDeAtt = new Attention[nlayer];
    enDeAttLayerNorms = new LN[nlayer];
    fnnLayerNorms = new LN[nlayer];
    selfAttCache = new Cache[nlayer];
    enDeAttCache = new Cache[nlayer];

    if (preNorm)
        decoderLayerNorm = new LN;

    /* initialize the stacked layers */
    for (int i = 0; i < nlayer; i++) {
        selfAtt[i].InitModel(config);
        fnns[i].InitModel(config);
        selfAttLayerNorms[i].InitModel(config);
        fnnLayerNorms[i].InitModel(config);
        enDeAtt[i].InitModel(config);
        enDeAttLayerNorms[i].InitModel(config);
        selfAttCache[i].enable = true;
        enDeAttCache[i].enable = true;
    }

    if (preNorm)
        decoderLayerNorm->InitModel(config);
}

/*
make the decoding network
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - mask that indicates which position is valid
>> maskEncDec - mask for the encoder-decoder attention
>> nstep - the current length of the decoder input
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
                         XTensor* maskEncDec, int nstep, bool isTraining)
{
    XTensor x;

    x = embedder.Make(inputDec, true, isTraining, nstep);

    /* dropout */
    if (isTraining && dropoutP > 0)
        x = Dropout(x, dropoutP);

    for (int i = 0; i < nlayer; i++) {
        XTensor att;
        XTensor ende;
        XTensor fnn;
        XTensor res;
        XTensor selfAttnBefore;
        XTensor selfAttnAfter;
        XTensor endeAttnBefore;
        XTensor endeAttnAfter;
        XTensor fnnBefore;

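        /* Note (an assumption inferred from the call sites below, not stated in
           this file): the free-standing LayerNorm(input, ln, prenorm, before, after)
           helper from submodel/CommonModules.h appears to apply `ln` only when the
           requested position matches the architecture, i.e. when (prenorm && before)
           or (!prenorm && after), and to pass the input through unchanged otherwise.
           That is why each sub-layer below calls it twice, once with (true, false)
           and once with (false, true), so one code path serves both the pre-norm
           and post-norm Transformer variants. */
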
        /* layer normalization with pre-norm for self-attention */
        selfAttnBefore = LayerNorm(x, selfAttLayerNorms[i], preNorm, true, false);

        /******************/
        /* self attention */
        att = selfAtt[i].Make(selfAttnBefore, selfAttnBefore, selfAttnBefore,
                              mask, isTraining, &selfAttCache[i], SELF_ATT);

        /* dropout */
        if (isTraining && dropoutP > 0)
            att = Dropout(att, dropoutP);

        /* residual connection */
        res = Sum(att, x);

        /* layer normalization with post-norm for self-attention */
        selfAttnAfter = LayerNorm(res, selfAttLayerNorms[i], preNorm, false, true);

        /* layer normalization with pre-norm for encoder-decoder attention */
        endeAttnBefore = LayerNorm(selfAttnAfter, enDeAttLayerNorms[i], preNorm, true, false);

        /* encoder-decoder attention */
        ende = enDeAtt[i].Make(outputEnc, endeAttnBefore, outputEnc, maskEncDec,
                               isTraining, &enDeAttCache[i], EN_DE_ATT);

        /* dropout */
        if (isTraining && dropoutP > 0)
            ende = Dropout(ende, dropoutP);

        /* residual connection */
        res = Sum(ende, selfAttnAfter);

        /* layer normalization with post-norm for encoder-decoder attention */
        endeAttnAfter = LayerNorm(res, enDeAttLayerNorms[i], preNorm, false, true);

        /* layer normalization with pre-norm for fnn */
        fnnBefore = LayerNorm(endeAttnAfter, fnnLayerNorms[i], preNorm, true, false);

        /* fnn */
        fnn = fnns[i].Make(fnnBefore, isTraining);

        /* dropout */
        if (isTraining && dropoutP > 0)
            fnn = Dropout(fnn, dropoutP);

        /* residual connection */
        res = Sum(fnn, endeAttnAfter);

        /* layer normalization with post-norm for fnn */
        x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true);
    }

    if (preNorm)
        return decoderLayerNorm->Make(x);

    return x;
}

/*
make the decoding network (pre-norm only)
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - mask that indicates which position is valid
>> maskEncDec - mask for the encoder-decoder attention
>> nstep - the current length of the decoder input
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::MakeFast(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
                             XTensor* maskEncDec, int nstep, bool isTraining)
{
    XTensor x;

    x = embedder.Make(inputDec, true, isTraining, nstep);

    /* dropout */
    if (isTraining && dropoutP > 0)
        x = Dropout(x, dropoutP);

    for (int i = 0; i < nlayer; i++) {
        XTensor res;

        res = x;

        /* layer normalization with pre-norm for self-attention */
        x = selfAttLayerNorms[i].Make(x);

        /******************/
        /* self attention */
        x = selfAtt[i].Make(x, x, x, mask, isTraining, &selfAttCache[i], SELF_ATT);

        /* dropout */
        if (isTraining && dropoutP > 0)
            x = Dropout(x, dropoutP);

        /* residual connection */
        x = Sum(res, x);

        res = x;

        /* layer normalization with pre-norm for encoder-decoder attention */
        x = enDeAttLayerNorms[i].Make(x);

        /* encoder-decoder attention */
        x = enDeAtt[i].Make(outputEnc, x, outputEnc, maskEncDec,
                            isTraining, &enDeAttCache[i], EN_DE_ATT);

        /* dropout */
        if (isTraining && dropoutP > 0)
            x = Dropout(x, dropoutP);

        /* residual connection */
        x = Sum(res, x);

        res = x;

        /* layer normalization with pre-norm for fnn */
        x = fnnLayerNorms[i].Make(x);

        /* fnn */
        x = fnns[i].Make(x, isTraining);

        /* dropout */
        if (isTraining && dropoutP > 0)
            x = Dropout(x, dropoutP);

        /* residual connection */
        x = Sum(res, x);
    }

    /* final layer normalization; MakeFast assumes the pre-norm configuration,
       in which decoderLayerNorm has been allocated by InitModel */
    x = decoderLayerNorm->Make(x);

    return x;
}

}
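
// Illustrative usage sketch (every name other than AttDecoder, XTensor and the
// Make/InitModel methods above is hypothetical): during incremental decoding the
// per-layer caches accumulate keys and values, so a driver would advance nstep
// one token at a time while reusing the same decoder instance:
//
//     AttDecoder decoder;
//     decoder.InitModel(config);
//     for (int step = 0; step < maxLen; step++) {
//         // inputDec holds the target tokens generated so far, encOutput is
//         // the encoder output, and the masks are built by the caller
//         XTensor out = decoder.Make(inputDec, encOutput, NULL, &maskEncDec,
//                                    step, false);
//         // pick the next token from `out` and append it to inputDec
//     }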