/* NiuTrans.NMT - an open-source neural machine translation system.
 * Copyright (C) 2020 NiuTrans Research. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 * $Modified by: HU Chi (huchinlp@gmail.com) 2020-04, 2020-06
 */

#include "Attention.h"
#include "Embedding.h"
#include "../Utility.h"
#include "../../../tensor/core/CHeader.h"

namespace nmt
{

/* constructor */
Attention::Attention()
{
    nhead = -1;
    dk = -1;
    dv = -1;
    d = -1;
}

/* destructor */
Attention::~Attention()
{
}

/*
initialize the model
>> config - the configurations of the network
*/
void Attention::InitModel(Config& config)
{
    devID = config.devID;
    useRPR = config.useRPR;

    nhead = config.nhead;
    d = config.modelSize;
    dk = config.modelSize;
    dv = config.modelSize;
    maxRP = config.maxRP;
    dropoutP = config.attDropout;

    /* initialize the parameters */
    InitTensor2D(&weightQ, d, d, X_FLOAT, devID);
    InitTensor1D(&biasQ, d, X_FLOAT, devID);
    InitTensor2D(&weightK, d, d, X_FLOAT, devID);
    InitTensor1D(&biasK, d, X_FLOAT, devID);
    InitTensor2D(&weightV, d, d, X_FLOAT, devID);
    InitTensor1D(&biasV, d, X_FLOAT, devID);

    if (useRPR)
        InitTensor2D(&RPEmbK, maxRP * 2 + 1, d / nhead, X_FLOAT, devID);

    InitTensor2D(&weightO, d, d, X_FLOAT, devID);
    InitTensor1D(&biasO, d, X_FLOAT, devID);

    float scale = 1.0F;
    _SetDataFanInOut(&weightK, scale);
    _SetDataFanInOut(&weightQ, scale);
    _SetDataFanInOut(&weightV, scale);
    _SetDataFanInOut(&weightO, scale);
    if (useRPR)
        _SetDataFanInOut(&RPEmbK, scale);

    biasQ.SetZeroAll();
    biasO.SetZeroAll();

    biasK.SetDataRand(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d));
    biasV.SetDataRand(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d));
}
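/*
A quick shape sketch for the parameters above (illustrative only; d = 512,
nhead = 8 and maxRP = 8 are assumed example settings, not values fixed by
this file):
    weightQ/K/V/O : 512 x 512        biasQ/K/V/O : 512
    RPEmbK        : (2 * 8 + 1) x (512 / 8) = 17 x 64
The projection weights use fan-in/fan-out (Xavier-style) initialization via
_SetDataFanInOut, biasQ/biasO start at zero, and biasK/biasV are drawn
uniformly from [-sqrt(6/d), sqrt(6/d)].
*/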
/*
make the network
>> k - keys, B * L * H for encoders, B * 1 * H for decoders,
       where B = batch size, L = sequence length,
       and H = vector size of each position
>> q - queries, B * L * H
>> v - values, B * L * H for encoders, B * 1 * H for decoders
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
>> cache - decoder cache
>> attType - type of attention, e.g., self-attention
<< return - the multi-head attention result
*/
XTensor Attention::Make(XTensor& k, XTensor& q, XTensor& v, XTensor* mask,
                        bool isTraining, Cache* cache, int attType)
{
    const bool isEnc = (!cache) ? true : false;

    /* linear transformation before self-attention */
    XTensor q2, k2, v2;

    q2 = MulAndShift(q, weightQ, biasQ);

    if (!cache || isTraining || !(cache->enable)) {
        /* self attention for encoder layers */
        k2 = MulAndShift(k, weightK, biasK);
        v2 = MulAndShift(v, weightV, biasV);

        if (useRPR && attType == SELF_ATT)
            return MakeRPRAttention(k2, q2, v2, mask, isTraining, isEnc);
        return MakeAttention(k2, q2, v2, mask, isTraining);
    }
    else {
        if (attType == SELF_ATT) {
            k2 = MulAndShift(k, weightK, biasK);
            v2 = MulAndShift(v, weightV, biasV);

            /* on a cache hit, concatenate the cached states with those of the new token only */
            if (!cache->miss) {
                k2 = Concatenate(cache->key, k2, 1);
                v2 = Concatenate(cache->value, v2, 1);
            }
            cache->key = k2;
            cache->value = v2;
            cache->miss = false;

            if (useRPR)
                return MakeRPRAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
            return MakeAttention(cache->key, q2, cache->value, mask, isTraining);
        }
        else if (attType == EN_DE_ATT) {
            if (cache->miss) {
                cache->key = MulAndShift(k, weightK, biasK);
                cache->value = MulAndShift(v, weightV, biasV);
                cache->miss = false;
            }

            return MakeAttention(cache->key, q2, cache->value, mask, isTraining);
        }
        CheckNTErrors(0, "invalid cache type");
    }
}
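/*
A minimal usage sketch for the cache logic above (illustrative only; the
names `att`, `x`, `mask` and `maxLen` are assumptions, not part of this file):

    Cache cache;                  // miss == true before the first step
    for (int t = 0; t < maxLen; t++) {
        // each call appends the new token's K/V to cache.key / cache.value
        XTensor y = att.Make(x, x, x, &mask, false, &cache, SELF_ATT);
    }

On the first decoding step the full K/V are projected and stored; on later
steps only the new token's K/V (B * 1 * H) are computed and concatenated
along the length dimension. Encoder-decoder attention (EN_DE_ATT) projects
the encoder output once, on the first miss, and reuses it afterwards.
*/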
/*
make the attention network given keys, queries and values (after linear transformation)
>> k - keys, B * L * H
>> q - queries, B * L * H
>> v - values, B * L * H
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
*/
XTensor Attention::MakeAttention(XTensor& k, XTensor& q, XTensor& v,
                                 XTensor* mask, bool isTraining)
{
    XTensor kheads;
    XTensor qheads;
    XTensor vheads;

    const auto dataType = k.dataType;

    /* multi head */
    kheads = Split(k, k.order - 1, nhead);
    qheads = Split(q, q.order - 1, nhead);
    vheads = Split(v, v.order - 1, nhead);

    XTensor att;
    XTensor dot;
    XTensor scalar;

    /* some operations may cause numerical overflow under FP16, including
       BMMul, Mask, Div and Softmax, so we cast the input to FP32 first */
    if (qheads.dataType == X_FLOAT16) {
        qheads = ConvertDataType(qheads, X_FLOAT);
        kheads = ConvertDataType(kheads, X_FLOAT);
    }

    /* att = softmax((Q * K^T + mask) / sqrt(dk/nhead)) * V */
    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);

    if (mask)
        dot = dot + *mask;

    dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));

    scalar = Softmax(dot, -1);

    if (isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);

    if (vheads.dataType != scalar.dataType)
        vheads = ConvertDataType(vheads, scalar.dataType);

    att = BMMul(scalar, vheads);

    if (dataType != att.dataType)
        att = ConvertDataType(att, dataType);

    /* concatenate the heads */
    return MulAndShift(Merge(att, att.order - 1), weightO, biasO);
}

/*
make the attention network by incorporating the relative position representation
with the given keys, queries and values (after linear transformation)
>> k - keys, B * L * H
>> q - queries, B * L * H
>> v - values, B * L * H
>> mask - the attention mask
>> isTraining - indicates whether the model is used for training
>> isEnc - indicates whether it is the encoder
*/
XTensor Attention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v,
                                    XTensor* mask, bool isTraining, bool isEnc)
{
    XTensor kheads;
    XTensor qheads;
    XTensor vheads;

    const int lenQ = q.GetDim(1);
    const int lenKV = k.GetDim(1);
    const auto dataType = k.dataType;

    /* multi head */
    kheads = Split(k, k.order - 1, nhead);
    qheads = Split(q, q.order - 1, nhead);
    vheads = Split(v, v.order - 1, nhead);

    XTensor att;
    XTensor dot;
    XTensor scalar;

    XTensor embMatrix, relativeKey;

    /* generate the relative embedding index (L_q, L_kv) */
    embMatrix = GetRPEmbedding(lenQ, lenKV, maxRP, isEnc || isTraining);

    /* generate the relative key from RPEmbK (L_q, L_kv, H/K) */
    relativeKey = Gather(RPEmbK, embMatrix);

    if (qheads.dataType == X_FLOAT16) {
        qheads = ConvertDataType(qheads, X_FLOAT);
        kheads = ConvertDataType(kheads, X_FLOAT);
        relativeKey = ConvertDataType(relativeKey, X_FLOAT);
    }

    float scaling = (float)sqrt(d / nhead);
    qheads = ScaleAndShift(qheads, 1.0F / scaling);

    dot = RPDotProduct(qheads, kheads, relativeKey, true);

    if (mask)
        dot = dot + *mask;

    /* softmax */
    scalar = Softmax(dot, -1);

    if (isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);

    if (vheads.dataType != scalar.dataType)
        vheads = ConvertDataType(vheads, scalar.dataType);

    /* generate the relative attention output (K, B, L_q, H/K) */
    att = BMMul(scalar, vheads);

    if (dataType != att.dataType)
        att = ConvertDataType(att, dataType);

    /* concatenate the heads */
    return MulAndShift(Merge(att, att.order - 1), weightO, biasO);
}

/*
generate relative position embeddings
>> lenQ - the length of the query
>> lenKV - the length of the key and value
>> maxRelativeLen - the maximum relative position distance
>> isEnc - indicates whether it is used in the encoder
*/
XTensor Attention::GetRPEmbedding(const int lenQ, const int lenKV,
                                  const int maxRelativeLen, const bool isEnc)
{
    XTensor range;
    XTensor embMatrix;

    InitTensor1D(&range, lenKV, X_INT, devID);
    int* index = new int[lenKV];

    if (isEnc) {
        for (int i = 0; i < lenKV; i++)
            index[i] = i;
        range.SetData(index, lenKV);
        XTensor range2D;
        XTensor range2DTrans;
        range2D = Unsqueeze(range, 0, lenQ);
        range2DTrans = Transpose(range2D, 0, 1);
        embMatrix = Sum(range2D, range2DTrans, false, -1);
    }
    else {
        for (int i = 0; i < lenKV; i++)
            index[i] = -lenKV + i + 1;
        range.SetData(index, lenKV);
        embMatrix = Unsqueeze(range, 0, lenQ);
    }

    //ClipMe(embMatrix, -float(maxRelativeLen), float(maxRelativeLen));
    embMatrix = Clip(embMatrix, -float(maxRelativeLen), float(maxRelativeLen));
    embMatrix = ScaleAndShift(embMatrix, 1.0F, float(maxRelativeLen));

    delete[] index;
    return embMatrix;
}
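/*
Worked example for GetRPEmbedding (illustrative; lenQ = lenKV = 4 and
maxRelativeLen = 2 are assumed example values, not defaults of this file).
Encoder case: embMatrix[i][j] = clip(j - i, -2, 2) + 2, i.e.

    2 3 4 4
    1 2 3 4
    0 1 2 3
    0 0 1 2

Decoder case (lenQ is typically 1 during decoding): the row holds the clipped
distance of every key position to the current position, shifted into [0, 4]:

    0 0 1 2

These integers index rows of RPEmbK, which has 2 * maxRP + 1 rows.
*/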
/*
relative position-aware dot-product attention inner calculation
(note that x and y are reshaped in place)
>> x - tensor with shape (nhead, batch_size, lenQ, depth)
>> y - tensor with shape (nhead, batch_size, lenKV, depth)
>> z - relative position embeddings with shape (lenQ, lenKV, depth)
>> isKey - whether y is the key
<< return - a tensor with shape (nhead, batch_size, lenQ, lenKV) if y is the key,
            or (nhead, batch_size, lenQ, depth) otherwise
*/
XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool isKey)
{
    const int headNum = nhead;
    const int batchSize = x.GetDim(1);
    const int lenQ = x.GetDim(2);
    const int lenKV = y.GetDim(2);
    const int depth = y.GetDim(3);

    const int lastDim = isKey ? lenKV : depth;
    auto transposeFlag = isKey ? X_TRANS : X_NOTRANS;

    int mergeDimsX[] = { headNum * batchSize, lenQ, x.GetDim(3) };
    int mergeDimsY[] = { headNum * batchSize, lenKV, y.GetDim(3) };

    x = Reshape(x, 3, mergeDimsX);
    y = Reshape(y, 3, mergeDimsY);

    if (isKey) {
        y = Transpose(y, 1, 2);
    }

    /* content term: batched product of x and y over the merged head-batch dimension */
    XTensor context;
    context = BMMul(x, y);
    int newDims[]{ headNum, batchSize, context.GetDim(1), context.GetDim(2) };
    context = Reshape(context, 4, newDims);

    /* relative term: multiply each query position with its relative embeddings */
    XTensor xTrans;
    xTrans = Transpose(x, 0, 1);

    XTensor relative;
    relative = MatrixMulBatched(xTrans, X_NOTRANS, z, transposeFlag);

    XTensor relativeTrans;
    relativeTrans = Transpose(relative, 0, 1);

    int splitDims[] = { headNum, batchSize, lenQ, lastDim };

    relativeTrans = Reshape(relativeTrans, 4, splitDims);

    return context + relativeTrans;
}

/* constructor */
Cache::Cache()
{
    miss = true;
    enable = true;
}

/* update the cached states */
void Cache::Update(XTensor&& k, XTensor&& v)
{
    key = k;
    value = v;
    miss = false;
}

/* keep alive states */
void Cache::KeepAlive(XTensor& aliveIdx)
{
    if (!miss) {
        key = AutoGather(key, aliveIdx);
        value = AutoGather(value, aliveIdx);
    }
}

/* reorder alive states */
void Cache::Reorder(XTensor& reorder)
{
    if (!miss) {
        key = AutoGather(key, reorder);
        value = AutoGather(value, reorder);
    }
}

}
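/*
Shape walk-through for RPDotProduct (illustrative; nhead = 8, batch = 2,
lenQ = lenKV = 5 and depth = 64 are assumed example values):
    x: (8, 2, 5, 64) -> reshape -> (16, 5, 64)
    y: (8, 2, 5, 64) -> reshape -> (16, 5, 64) -> transpose(1, 2) -> (16, 64, 5)
    context       = BMMul(x, y)                  : (16, 5, 5) -> reshape -> (8, 2, 5, 5)
    xTrans        = Transpose(x, 0, 1)           : (5, 16, 64)
    relative      = MatrixMulBatched(xTrans, z^T): (5, 16, 5), with z: (5, 5, 64)
    relativeTrans = Transpose(relative, 0, 1)    : (16, 5, 5) -> reshape -> (8, 2, 5, 5)
The returned sum context + relativeTrans is the unnormalized attention score
with relative position information; MakeRPRAttention then adds the mask and
applies Softmax.
*/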