/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 */

#ifndef __T2TOUTPUT_H__
#define __T2TOUTPUT_H__

#include "../../tensor/function/FHeader.h"

using namespace nts;

namespace transformer
{
xiaotong committed
31 32
    
#define OUTPUT_NAME "output"
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63

/* output layer */
class T2TOutput
{
public:
    /* device id */
    int devID;

    /* memory pool */
    XMem * mem;

    /* vocabulary size */
    int vSize;

    /* input vector size */
    int inSize;

    /* vector size of the linear transformation */
    int hSize;

    /* transformation matrix */
    XTensor w;

public:
    /* constructor */
    T2TOutput();

    /* de-constructor */
    ~T2TOutput();

    /* initialize the model */
64
    void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
65 66

    /* make the network */
xiaotong committed
67
    XTensor Make(XTensor &input);
68 69

    /* make the network (redefined output tensor) */
xiaotong committed
70
    void Make(XTensor &input, XTensor &output);
71 72 73 74 75
};


}

#endif