Commit d061d183 by xiaotong

improve the code of the attention model

parent a7223650
......@@ -101,22 +101,39 @@ make the network
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining, bool selfatt)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
XTensor v2;
if (selfatt){
/* linear transformation before self-attention */
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
XTensor con;
return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the network given a big tensor that keeps keys, queries and values
>> kqv - the big tensor
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeBig(XTensor &kqv, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
XTensor v2;
XTensor kqv2;
XList split;
con = MMul(k, wbig);
kqv2 = MMul(kqv, wbig);
int d1 = con.GetDim(0);
int d2 = con.GetDim(1);
int d3 = con.GetDim(2) / 3;
int d1 = kqv2.GetDim(0);
int d2 = kqv2.GetDim(1);
int d3 = kqv2.GetDim(2) / 3;
InitTensor3D(&k2, d1, d2, d3, X_FLOAT, devID, mem);
InitTensor3D(&q2, d1, d2, d3, X_FLOAT, devID, mem);
......@@ -126,24 +143,31 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bo
split.Add(&k2);
split.Add(&v2);
Split(con, split, 2, 3);
}
Split(kqv2, split, 2, 3);
else{
/* linear transofmration before self-attention */
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
}
return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the attention network given keys, queries and values (after linear transformation)
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor kheads;
XTensor qheads;
XTensor vheads;
/* multi head */
kheads = Split(k2, k2.order - 1, nhead);
qheads = Split(q2, q2.order - 1, nhead);
vheads = Split(v2, v2.order - 1, nhead);
kheads = Split(k, k.order - 1, nhead);
qheads = Split(q, q.order - 1, nhead);
vheads = Split(v, v.order - 1, nhead);
XTensor att;
XTensor dot;
......
......@@ -97,7 +97,13 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining, bool selfatt);
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
/* make the network given a big tensor that keeps keys, queries and values */
XTensor MakeBig(XTensor &kqv, XTensor &mask, bool isTraining);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
};
}
......
......@@ -119,7 +119,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/******************/
/* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining, true);
att = attentions[i].MakeBig(x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -133,7 +133,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining, false);
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......
......@@ -114,7 +114,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining, true);
att = attentions[i].MakeBig(x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论