/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 */

#include <math.h>
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"

namespace transformer
{
/* constructor: every field gets a placeholder value; the real values
   are assigned later in InitModel() */
T2TOutput::T2TOutput()
{
    /* the three sizes are not known yet */
    vSize = inSize = hSize = -1;

    /* no memory pool and no device are attached yet */
    mem = NULL;
    devID = -1;
}

/* destructor */
T2TOutput::~T2TOutput()
{
    /* empty body: no explicit cleanup is done here */
}

/*
initialize the model 
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
52
void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
53 54 55 56
{
    devID = myDevID;
    mem = myMem;

57 58
    float minmax = 0;

59
    LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
60 61
    LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
62
    LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
63

64
    InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
65 66 67 68
    
    float scale = 1.0F;
    float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
    w.SetDataRand(-finfout, finfout);
69 70 71

    DTYPE v = 1.0F/(float)sqrt((float)hSize);
    w.SetDataRandn(0, v);
72
}
73 74 75 76 77 78 79

/* 
make the network 
y = softmax(x * w)
>> input - input tensor
<< return - output tensor 
*/
80
XTensor T2TOutput::Make(XTensor &input)
81
{
82
    XTensor &x = input;
83

84
    return LogSoftmax(MMul(x, w), -1);
85 86 87 88 89 90 91
}

/* 
make the network (redefined output tensor) 
>> input - input tensor
>> output - output tensor 
*/
92
void T2TOutput::Make(XTensor &input, XTensor &output)
93
{
94
    XTensor &x = input;
95

96 97
    //output = LogSoftmax(MMul(x, w), -1);
    output = Softmax(MMul(x, w), -1);
xiaotong committed
98
    output.SetName(OUTPUT_NAME);
99 100
}

101
}