/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
 * This is the first source file I created in 2019 - a new start!
 */

#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__

#include <cstddef>

#include "T2TModel.h"

namespace transformer
{

/* State for search. It keeps the path (back-pointer), the prediction
   distribution, etc. It can be regarded as a hypothesis in translation.
   The default constructor gives every member a defined value, so a freshly
   created state (including each element of new T2TState[n]) never carries
   indeterminate data or a dangling back-pointer. */
class T2TState
{
public:
    /* we assume that the prediction is an integer; -1 until a prediction
       is assigned */
    int prediction;

    /* probability of the prediction; 0 until assigned */
    float prob;

    /* probability of the path; 0 until assigned.
       NOTE(review): whether this holds a plain or a log probability is not
       visible here - confirm against the search implementation. */
    float pathProb;

    /* pointer to the previous state; NULL until linked to a predecessor.
       T2TState never frees it (non-owning back-pointer) - presumably the
       states are owned by whoever allocates them, TODO confirm. */
    T2TState * last;

public:
    /* constructor: initialize all members (in declaration order) */
    T2TState()
        : prediction(-1), prob(0.0F), pathProb(0.0F), last(NULL)
    {
    }
};

/* a bundle of states: the tensors and layer lists describing a group of
   search states together (presumably one entry per hypothesis in the
   bundle - TODO confirm the tensor shapes against the search code) */
class T2TStateBundle
{
public:
    /* predictions (presumably word indices, one per state) */
    XTensor prediction;

    /* distribution of every prediction (last state of the path) */
    XTensor probs;

    /* distribution of every path */
    XTensor pathProbs;

    /* layers on the encoder side. We actually use the encoder output instead
       of all hidden layers. */
    XList encoderLayers;

    /* layers on the decoder side */
    XList decoderLayers;
};

/* The predictor reads the current state and then predicts the next one.
   This mirrors MT inference: given the state of the words produced so far,
   generate the next word. A state can be regarded as the representation of
   words (word indices, hidden states, embeddings, etc.).
   NOTE(review): the class declares a destructor and holds raw pointers but
   keeps the implicit copy operations; whether copying is safe depends on the
   destructor's definition (not visible here) - confirm. */
class T2TPredictor
{
public:
    /* constructor */
    T2TPredictor();

    /* destructor */
    ~T2TPredictor();

    /* read a state: takes the model and the current state bundle */
    void Read(T2TModel * model, T2TStateBundle * current);

    /* predict the next state into the given bundle */
    void Predict(T2TStateBundle * next);

private:
    /* pointer to the transformer model */
    T2TModel * m;

    /* current state */
    T2TStateBundle * cur;
};

}

#endif
