Commit 2058ca57 by xiaotong

create the attention model

parent c3eea060
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TAttention.h"
#include "T2TUtility.h"
namespace transformer
{
/* constructor */
T2TAttention::T2TAttention()
{
    nhead = -1;
    dk = -1;
    dv = -1;
    d = -1;
}

/* destructor */
T2TAttention::~T2TAttention()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> devID - device id
>> mem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem)
{
    float minmax = 0;

    LoadParamInt(argc, argv, "nhead", &nhead, 8);
    LoadParamInt(argc, argv, "dk", &dk, 512);
    LoadParamInt(argc, argv, "dv", &dv, 512);
    LoadParamInt(argc, argv, "d", &d, 512);
    LoadParamFloat(argc, argv, "attminmax", &minmax, 0.08F);

    /* a single packed transformation matrix: the 3 * d rows stack the
       inputs for Q, K and V, and the 2 * dk + dv columns concatenate
       the transformed Q (dk), K (dk) and V (dv) */
    InitTensor2D(&w, 3 * d, 2 * dk + dv, X_FLOAT, devID, mem);
    w.SetDataRand(-minmax, minmax);
}
}
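A note on the shape of w (not part of the commit): the sketch below spells out the packed layout that InitModel appears to allocate, assuming the Q, K and V projections are concatenated along the columns. The column offsets are hypothetical; this diff does not yet show how w is consumed.

#include <cstdio>

/* shape bookkeeping for the packed attention matrix, using the
   default hyper-parameters loaded in InitModel above */
int main()
{
    const int d = 512, dk = 512, dv = 512;

    /* rows stack the three d-dimensional input streams;
       columns concatenate the transformed Q (dk), K (dk) and V (dv) */
    int rows = 3 * d;          /* 1536 */
    int cols = 2 * dk + dv;    /* 1536 */

    /* hypothetical column offsets of the three projections inside w */
    int qOffset = 0;
    int kOffset = dk;
    int vOffset = dk + dk;

    printf("w: %d x %d, Q cols [%d, %d), K cols [%d, %d), V cols [%d, %d)\n",
           rows, cols, qOffset, qOffset + dk, kOffset, kOffset + dk,
           vOffset, vOffset + dv);
    return 0;
}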
\ No newline at end of file
@@ -22,12 +22,47 @@
#ifndef __T2TATTENTION_H__
#define __T2TATTENTION_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
      Attention(Q, K, V) = softmax(Q * K^T / d_k^0.5) * V
      d_k = dimension size of K
*/
class T2TAttention
{
public:
    /* head number */
    int nhead;

    /* transformation matrix */
    XTensor w;

    /* size of transformed Q and K */
    int dk;

    /* size of transformed V */
    int dv;

    /* size of input Q, K and V */
    int d;

public:
    /* constructor */
    T2TAttention();

    /* destructor */
    ~T2TAttention();

    /* initialize the model */
    void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL);
};
}
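For reference (not part of this commit): a minimal, self-contained sketch of the formula Attention(Q, K, V) = softmax(Q * K^T / d_k^0.5) * V from the comment above, written with plain row-major arrays instead of XTensor. The function name and signature are illustrative only.

#include <cmath>
#include <vector>

/* scaled dot-product attention over row-major matrices:
   q is (m x dk), k is (n x dk), v is (n x dv), out is (m x dv) */
void Attention(const std::vector<float> & q, int m,
               const std::vector<float> & k, int n,
               const std::vector<float> & v,
               int dk, int dv, std::vector<float> & out)
{
    float scale = 1.0F / std::sqrt((float)dk);
    out.assign((size_t)m * dv, 0.0F);
    std::vector<float> score(n);

    for (int i = 0; i < m; i++) {
        /* scores: q_i * k_j^T / sqrt(dk), tracking the max for stability */
        float maxScore = -1.0e30F;
        for (int j = 0; j < n; j++) {
            float s = 0;
            for (int t = 0; t < dk; t++)
                s += q[i * dk + t] * k[j * dk + t];
            score[j] = s * scale;
            if (score[j] > maxScore)
                maxScore = score[j];
        }

        /* softmax over the scores */
        float sum = 0;
        for (int j = 0; j < n; j++) {
            score[j] = std::exp(score[j] - maxScore);
            sum += score[j];
        }

        /* weighted sum of the value rows */
        for (int j = 0; j < n; j++) {
            float a = score[j] / sum;
            for (int t = 0; t < dv; t++)
                out[i * dv + t] += a * v[j * dv + t];
        }
    }
}

Each head_i in the comment above would amount to one call of this routine on the transformed slices of Q, K and V, with the outputs concatenated afterwards.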
@@ -49,8 +49,8 @@ void T2TFNN::InitModel(int argc, const char ** argv, int devID, XMem * mem)
 {
     float minmax = 0;
-    LoadParamInt(argc, argv, "fnnin", &inSize, 512);
-    LoadParamInt(argc, argv, "fnnout", &outSize, 512);
+    LoadParamInt(argc, argv, "d", &inSize, 512);
+    LoadParamInt(argc, argv, "d", &outSize, 512);
     LoadParamInt(argc, argv, "fnnh", &hSize, 512);
     LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.08F);
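T2TUtility's implementation is not part of this diff, so the following is only a guess at how the LoadParamInt calls above behave: scan argv for a "-name" flag and parse the token after it, falling back to the default. The "-name value" convention is an assumption, not confirmed by this commit.

#include <cstdlib>
#include <cstring>

/* hypothetical sketch of the parameter loader: find "-name" in argv,
   read the next token as the value, otherwise keep the default */
void LoadParamInt(int argc, const char ** argv, const char * name,
                  int * p, int defaultP)
{
    *p = defaultP;
    for (int i = 0; i < argc - 1; i++) {
        if (argv[i][0] == '-' && !strcmp(argv[i] + 1, name)) {
            *p = atoi(argv[i + 1]);
            return;
        }
    }
}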