/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
 */

#ifndef __T2TSEARCH_H__
#define __T2TSEARCH_H__

#include "T2TModel.h"
#include "T2TPredictor.h"

namespace transformer
{

/* The class orgnizes the search process. It calls "predictors" to generate
xiaotong committed
32
   distributions of the predictions and prunes the search space by beam pruning.
xiaotong committed
33
   This makes a graph where each path respresents a translation hypothsis.
xiaotong committed
34 35 36
   The output can be the path with the highest model score. */
class T2TSearch
{
37
private:
38 39 40
    /* the alpha parameter controls the length preference */
    float alpha;

41 42 43
    /* predictor */
    T2TPredictor predictor;
    
xiaotong committed
44 45
    /* max length of the generated sequence */
    int maxLength;
46 47 48
    
    /* beam size */
    int beamSize;
xiaotong committed
49

xiaotong committed
50 51 52
    /* batch size */
    int batchSize;

53
    /* we keep the final hypotheses in a heap for each sentence in the batch. */
xiaotong committed
54
    XHeap<MIN_HEAP, float> * fullHypos;
55

xiaotong committed
56 57 58 59 60 61
    /* array of the end symbols */
    int * endSymbols;

    /* number of the end symbols */
    int endSymbolNum;

xiaotong committed
62
public:
xiaotong committed
63
    /* constructor */
xiaotong committed
64
    T2TSearch();
xiaotong committed
65 66

    /* de-constructor */
xiaotong committed
67
    ~T2TSearch();
68 69
    
    /* initialize the model */
xiaotong committed
70
    void Init(int argc, char ** argv);
xiaotong committed
71

xiaotong committed
72
    /* search for the most promising states */
73
    void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
xiaotong committed
74

75
    /* preparation */
xiaotong committed
76
    void Prepare(int myBatchSize,int myBeamSize);
77

78 79 80
    /* compute the model score for each hypothesis */
    void Score(T2TStateBundle * prev, T2TStateBundle * beam);

xiaotong committed
81 82
    /* generate token indices via beam pruning */
    void Generate(T2TStateBundle * beam);
xiaotong committed
83

xiaotong committed
84 85 86
    /* expand the search graph */
    void Expand(T2TStateBundle * prev, T2TStateBundle * beam);

87 88 89
    /* collect hypotheses with ending symbol */
    void Collect(T2TStateBundle * beam);

xiaotong committed
90
    /* save the output sequences in a tensor */
xiaotong committed
91
    void Dump(XTensor * output);
xiaotong committed
92 93 94 95 96 97

    /* check if the token is an end symbol */
    bool IsEnd(int token);

    /* set end symbols for search */
    void SetEnd(const int * tokens, const int tokenNum);
xiaotong committed
98 99
};

}

#endif