/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 */

#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"

namespace transformer
{

/* constructor */
T2TAttention::T2TAttention()
{
    nhead = -1;
    dk = -1;
    dv = -1;
    d = -1;
    isMasked = false;
    ignored = 0;
}

/* destructor */
T2TAttention::~T2TAttention()
{
}

/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the attention is used with a mask
>> myIgnored - number of positions ignored in attention (from the beginning)
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, char ** argv,
                             bool myIsMasked, int myIgnored,
                             int myDevID, XMem * myMem)
{
    devID = myDevID;
    mem = myMem;
    isMasked = myIsMasked;
    ignored = myIgnored;

    float minmax = 0;

    LoadParamInt(argc, argv, "nhead", &nhead, 8);
    LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
    LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
    LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
    LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);

    InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
    InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
    InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
    InitTensor2D(&wa, d, d, X_FLOAT, devID, mem);

    /* Xavier (Glorot) uniform initialization: bound = sqrt(6 / (fan_in + fan_out)) */
    float scale = 1.0F;
    float finfoutk = (float)sqrt(6.0F * scale / (d + dk));
    float finfoutv = (float)sqrt(6.0F * scale / (d + dv));
    float finfouta = (float)sqrt(6.0F * scale / (d + d));

    wk.SetDataRand(-finfoutk, finfoutk);
    wq.SetDataRand(-finfoutk, finfoutk);
    wv.SetDataRand(-finfoutv, finfoutv);
    wa.SetDataRand(-finfouta, finfouta);
}
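/*
A hedged note on the initialization above (added for clarity, not part of the original
training recipe): the bounds follow the Xavier/Glorot uniform rule
bound = sqrt(6 / (fan_in + fan_out)). Assuming DEFAULT_EMBEDDING_SIZE is 512, so that
d = dk = 512, this gives

    finfoutk = sqrt(6 / (512 + 512)) = sqrt(0.00586) ~= 0.0765

so wk and wq are filled uniformly in [-0.0765, 0.0765]. A minimal usage sketch, assuming
a command line such as "-nhead 8 -d 512 -dropoutatt 0.1" and that a negative device id
selects the CPU (both are assumptions, not taken from this file):

    T2TAttention attention;
    attention.InitModel(argc, argv, true, 0, -1, NULL);
*/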
/*
make the network
>> k - keys. It might be of size B * L * H where
       B = batch size, L = sequence length, and H = vector size of each position
>> q - queries
>> v - values
>> mask - the attention mask (added to the attention scores)
>> isTraining - indicates whether the model is used for training
<< return - multi-head attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;

    /* linear transformation before self-attention */
    k2 = MMul(k, wk);
    q2 = MMul(q, wq);
    v2 = MMul(v, wv);

    XTensor kheads;
    XTensor qheads;
    XTensor vheads;

    /* multi head */
    kheads = Split(k2, k2.order - 1, nhead);
    qheads = Split(q2, q2.order - 1, nhead);
    vheads = Split(v2, v2.order - 1, nhead);

    XTensor att;
    XTensor dot;
    XTensor scalar;

    /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);

    if(isMasked)
        dot = dot + mask;

    /* scale by 1 / sqrt(dk / nhead), the per-head key dimension */
    dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));

    scalar = Softmax(dot, -1);

    if(isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);

    att = BMMul(scalar, vheads);

    /* concatenate the heads */
    return MMul(Merge(att, att.order - 1), wa);
}

}
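/*
A hedged usage sketch for Make (illustrative only; the variable names and shapes below
are assumptions, not taken from the original file). For self-attention the same tensor
serves as keys, queries and values:

    XTensor x;      // input of size B * L * d, e.g. embeddings plus positional encoding
    XTensor mask;   // additive mask; large negative values block forbidden positions
    XTensor y;
    y = attention.Make(x, x, x, mask, true);

Because the mask is added to Q * K^T before the softmax, masked positions end up with
(near-)zero attention weight.
*/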