T2TAttention.cpp
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 */

#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"

namespace transformer
{

/* constructor */
T2TAttention::T2TAttention()
{
    nhead = -1;
    dk = -1;
    dv = -1;
    d  = -1;
    isMasked = false;
    ignored = 0;
}

/* destructor */
T2TAttention::~T2TAttention()
{
}

/* 
initialize the model 
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIgnored - number of positions ignored in attention (from the beginning)
>> myIsMasked - indicates whether the attention uses a mask
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, char ** argv, 
                             bool myIsMasked, int myIgnored, 
                             int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
    isMasked = myIsMasked;
    ignored = myIgnored;
    
    float minmax = 0;

    LoadParamInt(argc, argv, "nhead", &nhead, 8);
    LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
    LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
    LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);

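    /* parameter matrices: wk, wq and wv project the input (of size d) to keys,
       queries and values; wa is the output transformation applied after the
       heads are merged; wbig packs the three projections into one matrix for MakeBig() */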
    InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
    InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
    InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
    InitTensor2D(&wa, d, d, X_FLOAT, devID, mem);
    InitTensor2D(&wbig, d, 3 * d, X_FLOAT, devID, mem);

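    /* Xavier-style (Glorot uniform) initialization: each bound is sqrt(6 / (fan_in + fan_out)) */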
    float scale = 1.0F;
    float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
    float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
    float finfouta = (float)sqrt(6.0F * scale / (d + d));
    float finfoutbig = (float)sqrt(6.0F * scale / (d + 3*d));

    wk.SetDataRand(-finfoutk, finfoutk);
    wq.SetDataRand(-finfoutk, finfoutk);
    wv.SetDataRand(-finfoutv, finfoutv);
    wa.SetDataRand(-finfouta, finfouta);
    wbig.SetDataRand(-finfoutbig, finfoutbig);
}

/* 
make the network 
>> k - keys. It might be of size B * L * H
       where B = batch size, L = sequence length, 
       and H = vector size of each position
>> q - queries
>> v - values
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
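   Note: for self-attention the same tensor is usually passed as k, q and v,
   e.g., attention.Make(x, x, x, mask, isTraining) (names are illustrative).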
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;
    
    /* linear transformation before self-attention */
    k2 = MMul(k, wk);
    q2 = MMul(q, wq);
    v2 = MMul(v, wv);
    
    return MakeAttention(k2, q2, v2, mask, isTraining);
}
    
/*
make the network given a big tensor that keeps keys, queries and values
>> kqv - the big tensor
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
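<< return - multi-head attention result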
*/
XTensor T2TAttention::MakeBig(XTensor &kqv, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;
    XTensor kqv2;
    TensorList split;
    
    kqv2 = MMul(kqv, wbig);
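    /* kqv2 has 3 * d units in its last dimension; it is split below into
       queries, keys and values of d units each */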
    
    int d1 = kqv2.GetDim(0);
    int d2 = kqv2.GetDim(1);
    int d3 = kqv2.GetDim(2) / 3;
    
    InitTensor3D(&k2, d1, d2, d3, X_FLOAT, devID, mem);
    InitTensor3D(&q2, d1, d2, d3, X_FLOAT, devID, mem);
    InitTensor3D(&v2, d1, d2, d3, X_FLOAT, devID, mem);
    
    split.Add(&q2);
    split.Add(&k2);
    split.Add(&v2);
    
    Split(kqv2, split, 2, 3);
    
    return MakeAttention(k2, q2, v2, mask, isTraining);
}
    
/*
make the attention network given keys, queries and values (after linear transformation)
>> k - keys. It might be of size B * L * H
       where B = batch size, L = sequence length,
       and H = vector size of each position
>> q - queries
>> v - values
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
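<< return - multi-head attention result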
*/
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor kheads;
    XTensor qheads;
    XTensor vheads;
    
    /* multi-head: split the last dimension into nhead heads */
    kheads = Split(k, k.order - 1, nhead);
    qheads = Split(q, q.order - 1, nhead);
    vheads = Split(v, v.order - 1, nhead);
    
    XTensor att;
    XTensor dot;
    XTensor scalar;
    
    /* attention: softmax(Q * K^T / sqrt(dk/nhead)) * V, computed per head */
    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
    
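    /* the mask is added to the attention scores; positions to be ignored typically
       carry large negative values, so softmax assigns them near-zero weight */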
    if(isMasked)
        dot = dot + mask;
    
    dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));
    
    scalar = Softmax(dot, -1);

    if(isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);
    
    att = BMMul(scalar, vheads);
    
    /* concatenate the heads and apply the output transformation */
    return MMul(Merge(att, att.order - 1), wa);
}

}