/* T2TEmbedding.cpp */
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northestern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
 */

#include <math.h>
#include "T2TEmbedding.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"

namespace transformer
{

/* constructor */
T2TEmbedder::T2TEmbedder()
{
    devID = -1;
    mem = NULL;
    vSize = -1;
    maxLength = -1;
    /* also reset the size members written later by InitModel(), so the
       object is fully defined even before initialization */
    eSize = -1;
    d = -1;
}

/* deconstructor */
T2TEmbedder::~T2TEmbedder()
{
    /* nothing to release explicitly here: the tensor members (e.g. w and
       posEmbeddingBase) free their own storage in their destructors */
}

/* 
initialize the model 
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
51
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem, bool isEnc)
52 53 54 55
{
    devID = myDevID;
    mem = myMem;
    
56 57 58 59 60 61 62
    if(isEnc){
        LoadParamInt(argc, argv, "vsize", &vSize, -1);
    }
    else{
        LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
    }
    //LoadParamInt(argc, argv, "vsize", &vSize, -1);
63 64 65
    LoadParamInt(argc, argv, "maxlen", &maxLength, 512);
    LoadParamInt(argc, argv, "d", &eSize, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
66 67 68

    InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);

69
    DTYPE v = 1.0F/(float)sqrt((float)eSize);
70
    w.SetDataRandn(0, v);
71 72 73 74 75 76

    /* create the positional embedding matrix */
    MakePosEmbedding(eSize, d, maxLength);
}

/* 
77 78 79 80
make positional embeddings (of size eSize * length)
>> eSize - embedding size
>> d - dimension size of the hidden layers
>> length - length of the sequence
81 82 83
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
84
    InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID, mem);
85

86
    float * data = new float[posEmbeddingBase.unitNum];
87 88 89

    for(int pos = 0; pos < length; pos++){
        float * dp = data + pos * eSize;
90 91 92 93 94 95 96 97 98 99 100
        
        int channelSize = eSize / 2;
        int offset = 0;
        for(int i = 0; i < channelSize; i++){
            dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
        }
        for(int i = 0; i < channelSize; i++){
            dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
        }

        /*
101 102 103
        for(int k = 0; k < eSize; k++){
            if(k % 2 == 0){
                int i = k/2;
104
                dp[k] = (float)sin(pos/pow(10000.0F, 2.0F*i/d));
105 106 107
            }
            else{
                int i = (k - 1)/2;
108
                dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
109 110
            }
        }
111
        */
112 113
    }

114
    posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
115 116 117 118 119 120 121

    delete[] data;
}

/* 
make the network 
*/
122
XTensor T2TEmbedder::Make(XTensor &input)
123
{
124
    //CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
125
    CheckNTErrors(input.order > 1, "Wrong input tensor size!");
126
    CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
127 128
    CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
    CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
129 130

    int dims[MAX_TENSOR_DIM_NUM];
131
    memcpy(dims, input.dimSize, input.order * sizeof(int));
132
    dims[input.order] = eSize;
133

134 135 136
    XTensor wordEmbedding;
    XTensor posEmbedding;

137
    bool match = (posEmbedding.order == input.order);
138
    if(match){
139
        for(int i = 0; i < input.order; i++){
140 141 142 143 144 145
            if(dims[i] != posEmbedding.GetDim(i))
                match = false;
        }
    }

    /* we make positional embeddings first */
146 147
    //if(!match){
    if(true){
148 149
        InitTensor(&posEmbedding, input.order + 1, dims, X_FLOAT, 1.0F, devID, mem);

150
        XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
151

152 153
        _CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
        _Unsqueeze(posTMP, &posEmbedding, 0, dims[0]);
154 155 156 157 158

        DelTensorBuf(posTMP);
    }

    /* then we make word embeddings */
159 160
    wordEmbedding = Gather(w, input);
    wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
161 162

    /* we sum over the two embeddings */
xuchen committed
163
    return wordEmbedding + posEmbedding;
164 165 166
}

}