/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northestern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
 */

#include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h"

using namespace nts;

namespace transformer
{

/*
constructor
The model and the current state bundle are only supplied later through
Read(), so start from null pointers rather than leaving them uninitialized.
*/
T2TPredictor::T2TPredictor()
{
    m = NULL;
    cur = NULL;
}

/*
destructor
Nothing is released here. NOTE(review): the model and state bundle given
to Read() are only stored (never freed) in this file - confirm that their
owners are responsible for releasing them.
*/
T2TPredictor::~T2TPredictor()
{
    /* intentionally empty */
}

/*
read a state
>> model - the t2t model holding the network built so far
>> current - the current set of states, which carries
             1) the hypotheses (states)
             2) the probabilities of the hypotheses
             3) the parts of the network needed to expand to the next state
Only the two pointers are remembered; no data is copied.
*/
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * current)
{
    cur = current;
    m = model;
}

/*
predict the next state
>> next - the bundle that receives the next states (the current state is
          assumed to have been read via Read() beforehand)
On return, *next holds:
  decoderLayers[0] - word indices of all positions up to the next state
  decoderLayers[1] - a pointer to next->prediction
  encoderLayers    - the items of the current bundle's encoderLayers
  prediction       - the output of the decoder network
*/
void T2TPredictor::Predict(T2TStateBundle * next)
{
    next->decoderLayers.Clear();
    next->encoderLayers.Clear();

    AttDecoder &decoder = *m->decoder;

    /* word indices of previous positions */
    XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0);

    /* word indices of positions up to the next state. The tensor lives on
       the heap because a pointer to it is kept in next->decoderLayers after
       this function returns.
       NOTE(review): nothing in this file releases it - confirm ownership */
    XTensor &input = *NewTensor();

    /* append the current prediction along the last dimension. Concatenate()
       takes a dimension INDEX, so pass order - 1 (GetDim(-1) would be the
       size of that dimension, not its index) */
    input = Concatenate(inputLast, cur->prediction, inputLast.order - 1);

    /* decoder output, stored as the prediction of the next state.
       NOTE(review): this is the raw output of decoder.Make(); whether it
       already holds probabilities depends on AttDecoder - confirm */
    XTensor &output = next->prediction;

    /* encoder output taken from the current bundle's encoder layers
       (GetItem(-1) - presumably the last layer; confirm XList semantics) */
    XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1);

    /* empty tensor passed for both masks - presumably meaning "no mask";
       confirm against AttDecoder::Make */
    XTensor nullMask;

    /* build the decoding network over the whole prefix assembled above
       (the concatenated indices, not just cur->prediction) */
    output = decoder.Make(input, outputEnc, nullMask, nullMask, false);

    next->encoderLayers.AddList(&cur->encoderLayers);
    next->decoderLayers.Add(&input);
    next->decoderLayers.Add(&output);
}

}

