/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northestern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

 /*
  * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
  * This is the first source file I create in 2019 - new start!
  */

#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__

#include "T2TModel.h"
#include "T2TLengthPenalty.h"

namespace transformer
{

/* pid value that marks an empty hypothesis. The replacement list is
   parenthesized so the macro always expands to a single operand wherever
   it is used (CERT PRE02-C). */
#define T2T_PID_EMPTY (-1)

/* state for search. It keeps the path (back-pointer), prediction distribution,
   etc. It can be regarded as a hypothesis in translation.
   Every member has a default initializer, so a default-constructed state
   (including each element of `new T2TState[n]`) is a well-defined empty
   hypothesis instead of holding indeterminate values. */
class T2TState
{
public:
    /* we assume that the prediction is an integer; -1 until one is assigned */
    int prediction = -1;

    /* id of the problem. One can regard it as the sentence id when we
       translate a number of sentences in the batched manner. The hypothesis
       is empty if id = -1 (T2T_PID_EMPTY), which is also the default */
    int pid = T2T_PID_EMPTY;

    /* indicates whether the state is an end */
    bool isEnd = false;

    /* indicates whether the state is the start */
    bool isStart = false;

    /* indicates whether the state is completed */
    bool isCompleted = false;

    /* probability of every prediction (last state of the path) */
    float prob = 0.0F;

    /* probability of every path */
    float probPath = 0.0F;

    /* model score of every path. A model score = path probability + some other stuff */
    float modelScore = 0.0F;

    /* number of steps we go over so far */
    int nstep = 0;

    /* pointer to the previous state (back-pointer of the path);
       null when there is no previous state */
    T2TState* last = nullptr;
};

/* a bundle of states: the tensors describe the hypotheses of one search step
   in batched form, and `states` holds the per-hypothesis records.
   The bundle is non-copyable (see the deleted copy operations below). */
class T2TStateBundle
{
public:
    /* predictions */
    XTensor prediction;

    /* id of the previous state that generates the current one */
    XTensor preID;

    /* mark that indicates whether each hypothesis is completed */
    XTensor endMark;

    /* probability of every prediction (last state of the path) */
    XTensor prob;

    /* probability of every path */
    XTensor probPath;

    /* model score of every path */
    XTensor modelScore;

    /* step number of each hypothesis */
    XTensor nstep;

    /* list of states (an array of `stateNum` elements created by MakeStates;
       NOTE(review): presumably owned by the bundle and released in the
       destructor - confirm against the implementation) */
    T2TState* states;

    /* number of states */
    int stateNum;

    /* indicates whether it is the first state */
    bool isStart;

public:
    /* constructor */
    T2TStateBundle();

    /* de-constructor */
    ~T2TStateBundle();

    /* Copying is disabled. `states` is a raw array pointer, and the class
       declares its own destructor, so the implicitly generated member-wise
       copy would leave two bundles sharing one array. That risks a double
       release, and `T2TState::last` back-pointers would refer into the wrong
       bundle. Pass bundles around by pointer instead. Because a copy
       constructor is now user-declared, no implicit move operations are
       generated either. */
    T2TStateBundle(const T2TStateBundle&) = delete;
    T2TStateBundle& operator=(const T2TStateBundle&) = delete;

    /* create states: allocates `num` entries for `states` */
    void MakeStates(int num);
};

/* The predictor reads the current state and then predicts the next.
   It is exactly the same procedure as MT inference -
   we get the state of previous words and then generate the next word.
   Here, a state can be regarded as the representation of words (word
   indices, hidden states, embeddings, etc.). */
class T2TPredictor
{
private:
    /* pointer to the transformer model
       (NOTE(review): looks like a non-owning pointer - confirm the destructor
       does not free it) */
    T2TModel* m;

    /* current state bundle */
    T2TStateBundle* s;

    /* start symbol (the value given to SetStartSymbol) */
    int startSymbol;

public:
    /* constructor */
    T2TPredictor();

    /* de-constructor */
    ~T2TPredictor();

    /* create an initial state
       model    - the transformer model used for prediction
       top      - presumably the encoder output (TODO confirm)
       input    - the source input tensor; read-only here (const)
       beamSize - beam size of the search
       state    - the bundle that receives the initial state */
    void Create(T2TModel* model, XTensor* top, const XTensor* input, int beamSize, T2TStateBundle* state);

    /* set the start symbol used by the predictor */
    void SetStartSymbol(int symbol);

    /* read a state: presumably records `model` and `state` as the current
       model/state for the next prediction (TODO confirm) */
    void Read(T2TModel* model, T2TStateBundle* state);

    /* predict the next state
       next       - the bundle that receives the predicted state
       encoding   - encoder output (assumed; TODO confirm)
       inputEnc   - encoder-side input (assumed; TODO confirm)
       paddingEnc - padding of the encoder-side input (assumed from the name)
       isStart    - whether this is the first decoding step */
    void Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inputEnc, XTensor* paddingEnc, bool isStart);

    /* generate paths up to the states of the current step
       (follows the back-pointers of `state`; the exact layout of the returned
       tensor is not visible here - TODO confirm) */
    XTensor GeneratePaths(T2TStateBundle* state);

    /* get the predictions of the previous step */
    XTensor GetLastPrediction(T2TStateBundle* state);
};

}

#endif