/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
 * This is the first source file I create in 2019 - new start!
 */
xiaotong committed
22 23 24 25 26

#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__

#include "T2TModel.h"
27
#include "T2TLengthPenalty.h"
xiaotong committed
28 29 30 31

namespace transformer
{

32 33
#define T2T_PID_EMPTY -1

/* state for search. It keeps the path (back-pointer), the prediction
   distribution, etc. It can be regarded as a hypothesis in translation.

   Every member has a default initializer, so a newly constructed state is
   an empty, unlinked hypothesis (pid == T2T_PID_EMPTY, last == nullptr)
   instead of holding indeterminate values. */
class T2TState
{
public:
    /* we assume that the prediction is an integer.
       -1 is used here as the "nothing predicted yet" default. */
    int prediction = -1;

    /* id of the problem. One can regard it as the sentence id when we
       translate a number of sentences in the batched manner. The hypothesis
       is empty if id = -1 (i.e., T2T_PID_EMPTY) */
    int pid = T2T_PID_EMPTY;

    /* indicates whether the state is an end */
    bool isEnd = false;

    /* indicates whether the state is the start */
    bool isStart = false;

    /* indicates whether the state is completed */
    bool isCompleted = false;

    /* probability of every prediction (last state of the path) */
    float prob = 0.0F;

    /* probability of every path */
    float probPath = 0.0F;

    /* model score of every path. A model score = path probability + some other stuff */
    float modelScore = 0.0F;

    /* number of steps we go over so far */
    int nstep = 0;

    /* pointer to the previous state (back-pointer); null when there is none */
    T2TState* last = nullptr;
};

/* a bundle of states */
class T2TStateBundle
{
xiaotong committed
75 76 77
public:
    /* predictions */
    XTensor prediction;
huchi committed
78

79 80
    /* id of the previous state that generates the current one  */
    XTensor preID;
xiaotong committed
81

xiaotong committed
82 83 84
    /* mark that indicates whether each hypothesis is completed */
    XTensor endMark;

85 86 87 88 89 90 91 92
    /* probability of every prediction (last state of the path) */
    XTensor prob;

    /* probability of every path */
    XTensor probPath;

    /* model score of every path */
    XTensor modelScore;
93

94 95
    /* step number of each hypothesis */
    XTensor nstep;
96

xiaotong committed
97
    /* list of states */
huchi committed
98
    T2TState* states;
xiaotong committed
99

100 101 102
    /* number of states */
    int stateNum;

xiaotong committed
103 104 105
    /* indicates whether it is the first state */
    bool isStart;

xiaotong committed
106 107 108 109 110 111 112 113 114
public:
    /* constructor */
    T2TStateBundle();

    /* de-constructor */
    ~T2TStateBundle();

    /* create states */
    void MakeStates(int num);
xiaotong committed
115 116
};

huchi committed
117
/* The predictor reads the current state and then predicts the next.
xiaotong committed
118 119
   It is exactly the same procedure of MT inference -
   we get the state of previous words and then generate the next word.
huchi committed
120
   Here, a state can be regared as the representation of words (word
121
   indices, hidden states, embeddings and etc.).  */
xiaotong committed
122 123
class T2TPredictor
{
124
private:
xiaotong committed
125
    /* pointer to the transformer model */
huchi committed
126
    T2TModel* m;
xiaotong committed
127 128

    /* current state */
huchi committed
129
    T2TStateBundle* s;
xiaotong committed
130

131 132 133
    /* start symbol */
    int startSymbol;

xiaotong committed
134 135 136 137 138 139 140
public:
    /* constructor */
    T2TPredictor();

    /* de-constructor */
    ~T2TPredictor();

xiaotong committed
141
    /* create an initial state */
huchi committed
142
    void Create(T2TModel* model, XTensor* top, const XTensor* input, int beamSize, T2TStateBundle* state);
xiaotong committed
143

144 145 146
    /* set the start symbol */
    void SetStartSymbol(int symbol);

xiaotong committed
147
    /* read a state */
huchi committed
148
    void Read(T2TModel* model, T2TStateBundle* state);
xiaotong committed
149 150

    /* predict the next state */
huchi committed
151
    void Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inputEnc, XTensor* paddingEnc, bool isStart);
152 153

    /* generate paths up to the states of the current step */
huchi committed
154
    XTensor GeneratePaths(T2TStateBundle* state);
155 156 157

    /* get the predictions of the previous step */
    XTensor GetLastPrediction(T2TStateBundle* state);
xiaotong committed
158 159 160 161 162
};

}

#endif