/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
 * This is the first source file I create in 2019 - new start!
 */

#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__

#include "T2TModel.h"
#include "T2TLengthPenalty.h"

namespace transformer
{

/* problem id that marks an empty hypothesis (see T2TState::pid) */
#define T2T_PID_EMPTY -1

/* State for search. It keeps the path (back-pointer), the prediction
   distribution, etc. It can be regarded as a hypothesis in translation.

   Every member carries a default value, so a default-constructed state is
   a well-defined empty hypothesis (pid == T2T_PID_EMPTY, no predecessor)
   instead of a set of indeterminate values. */
class T2TState
{
public:
    /* we assume that the prediction is an integer;
       -1 is the default, used before any prediction has been stored */
    int prediction = -1;

    /* id of the problem. One can regard it as the sentence id when we
       translate a number of sentences in the batched manner. It is
       an empty hypothesis if id = -1 (T2T_PID_EMPTY) */
    int pid = T2T_PID_EMPTY;

    /* indicates whether the state is an end (0 by default) */
    int isEnd = 0;

    /* probability of every prediction (last state of the path) */
    float prob = 0.0F;

    /* probability of every path */
    float probPath = 0.0F;

    /* model score of every path */
    float modelScore = 0.0F;

    /* number of steps we go over so far */
    int nstep = 0;

    /* pointer to the previous state (null when there is no predecessor) */
    T2TState * last = nullptr;
};

/* a bundle of states */
class T2TStateBundle
{
xiaotong committed
69 70 71
public:
    /* predictions */
    XTensor prediction;
72 73 74
    
    /* id of the previous state that generates the current one  */
    XTensor preID;
xiaotong committed
75

xiaotong committed
76 77 78
    /* mark that indicates whether each hypothesis is completed */
    XTensor endMark;

79 80 81 82 83 84 85 86
    /* probability of every prediction (last state of the path) */
    XTensor prob;

    /* probability of every path */
    XTensor probPath;

    /* model score of every path */
    XTensor modelScore;
87

88 89
    /* step number of each hypothesis */
    XTensor nstep;
90

xiaotong committed
91 92
    /* layers on the encoder side. We actually use the encoder output instead
       of all hidden layers. */
93
    XList layersEnc;
xiaotong committed
94 95

    /* layers on the decoder side */
96
    XList layersDec;
xiaotong committed
97 98 99 100

    /* list of states */
    T2TState * states;

101 102 103
    /* number of states */
    int stateNum;

xiaotong committed
104 105 106 107 108 109 110 111 112
public:
    /* constructor */
    T2TStateBundle();

    /* de-constructor */
    ~T2TStateBundle();

    /* create states */
    void MakeStates(int num);
xiaotong committed
113 114 115 116 117 118
};

/* The predictor reads the current state and then predicts the next. 
   It is exactly the same procedure of MT inference -
   we get the state of previous words and then generate the next word.
   Here, a state can be regared as the representation of words (word 
119
   indices, hidden states, embeddings and etc.).  */
xiaotong committed
120 121
class T2TPredictor
{
122
private:
xiaotong committed
123
    /* pointer to the transformer model */
xiaotong committed
124 125 126
    T2TModel * m;

    /* current state */
xiaotong committed
127
    T2TStateBundle * s;
xiaotong committed
128 129 130 131 132 133 134 135

public:
    /* constructor */
    T2TPredictor();

    /* de-constructor */
    ~T2TPredictor();

xiaotong committed
136
    /* create an initial state */
137
    void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
xiaotong committed
138

xiaotong committed
139
    /* read a state */
xiaotong committed
140
    void Read(T2TModel * model, T2TStateBundle * state);
xiaotong committed
141 142

    /* predict the next state */
143
    void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
144 145 146

    /* generate paths up to the states of the current step */
    XTensor GeneratePaths(T2TStateBundle * state);
xiaotong committed
147 148 149 150 151
};

}

#endif