Commit d061d183 by xiaotong

improve the code of the attention model

parent a7223650
...@@ -101,69 +101,93 @@ make the network ...@@ -101,69 +101,93 @@ make the network
>> isTraining - indicates whether the model is used for training >> isTraining - indicates whether the model is used for training
<< return - multi-attention result << return - multi-attention result
*/ */
/*
make the attention network given separate keys, queries and values
>> k - keys, e.g. of size B * L * H (B = batch size, L = sequence length, H = vector size)
>> q - queries
>> v - values
>> mask - the attention mask (as it is)
>> isTraining - indicates whether the model is used for training
<< return - multi-head attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;

    /* linear transformation before attention: project each of k, q, v
       with its own weight matrix */
    k2 = MMul(k, wk);
    q2 = MMul(q, wq);
    v2 = MMul(v, wv);

    /* the actual scaled dot-product multi-head attention */
    return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the network given a single big tensor that keeps keys, queries and values
(used for self-attention, where k = q = v = the layer input)
>> kqv - the big input tensor
>> mask - the attention mask (as it is)
>> isTraining - indicates whether the model is used for training
<< return - multi-head attention result
*/
XTensor T2TAttention::MakeBig(XTensor &kqv, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;
    XTensor kqv2;
    XList split;

    /* one fused projection: wbig stacks wq, wk and wv so a single
       matrix product replaces three separate ones */
    kqv2 = MMul(kqv, wbig);

    /* NOTE(review): assumes kqv2 is an order-3 tensor (batch * length * 3H);
       the last dimension is split into the three equal q/k/v parts */
    int d1 = kqv2.GetDim(0);
    int d2 = kqv2.GetDim(1);
    int d3 = kqv2.GetDim(2) / 3;

    InitTensor3D(&k2, d1, d2, d3, X_FLOAT, devID, mem);
    InitTensor3D(&q2, d1, d2, d3, X_FLOAT, devID, mem);
    InitTensor3D(&v2, d1, d2, d3, X_FLOAT, devID, mem);

    /* the order here must match the layout of wbig: q first, then k, then v */
    split.Add(&q2);
    split.Add(&k2);
    split.Add(&v2);

    Split(kqv2, split, 2, 3);

    return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the attention network given keys, queries and values
(after the linear transformation)
>> k - keys. It might be of size B * L * H
        where B = batch size, L = sequence length,
        and H = vector size of each position
>> q - queries
>> v - values
>> mask - the attention mask (as it is)
>> isTraining - indicates whether the model is used for training
<< return - multi-head attention result
*/
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor kheads;
    XTensor qheads;
    XTensor vheads;

    /* multi head: split the last (hidden) dimension into nhead parts */
    kheads = Split(k, k.order - 1, nhead);
    qheads = Split(q, q.order - 1, nhead);
    vheads = Split(v, v.order - 1, nhead);

    XTensor att;
    XTensor dot;
    XTensor scalar;

    /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);

    /* NOTE(review): the mask is added before the 1/sqrt(dk) scaling, so the
       mask values are scaled too — presumably large negatives, where this
       still masks effectively; confirm against how the mask is built */
    if(isMasked)
        dot = dot + mask;

    /* scale by 1/sqrt(dk per head) to keep the softmax input in range */
    dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));

    scalar = Softmax(dot, -1);

    /* dropout on the attention weights (training only) */
    if(isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);

    att = BMMul(scalar, vheads);

    /* concatenate the heads and apply the output projection */
    return MMul(Merge(att, att.order - 1), wa);
}
......
...@@ -97,7 +97,13 @@ public: ...@@ -97,7 +97,13 @@ public:
int myDevID = -1, XMem * myMem = NULL); int myDevID = -1, XMem * myMem = NULL);
/* make the network */ /* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining, bool selfatt); XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
/* make the network given a big tensor that keeps keys, queries and values */
XTensor MakeBig(XTensor &kqv, XTensor &mask, bool isTraining);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
}; };
} }
......
...@@ -119,7 +119,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -119,7 +119,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/******************/ /******************/
/* self attention */ /* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining, true); att = attentions[i].MakeBig(x, mask, isTraining);
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
...@@ -133,7 +133,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -133,7 +133,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/*****************************/ /*****************************/
/* encoder-decoder attention */ /* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining, false); ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
......
...@@ -114,7 +114,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo ...@@ -114,7 +114,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
XTensor res; XTensor res;
/* self attention */ /* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining, true); att = attentions[i].MakeBig(x, mask, isTraining);
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论