Commit dd0ad421 by xiaotong

Bug fixes for attention and layer normalization

parent fe37006f
@@ -74,14 +74,17 @@ void T2TAttention::InitModel(int argc, char ** argv,
     InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
     InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
     InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
+    InitTensor2D(&wa, d, d, X_FLOAT, devID, mem);

     float scale = 1.0F;
     float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
     float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
+    float finfouta = (float)sqrt(6.0F * scale / (d + d));

     wk.SetDataRand(-finfoutk, finfoutk);
     wq.SetDataRand(-finfoutk, finfoutk);
     wv.SetDataRand(-finfoutv, finfoutv);
+    wa.SetDataRand(-finfouta, finfouta);
 }

 /*
@@ -135,7 +138,7 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bo
     att = BMMul(scalar, vheads);

     /* concatenate the heads */
-    return Merge(att, att.order - 1);
+    return MMul(Merge(att, att.order - 1), wa);
 }
 }
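The change to Make adds the output projection that completes multi-head attention: after the per-head results are merged back along the feature dimension, they are multiplied by the new matrix wa, i.e. MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_a. Below is a minimal stand-alone sketch of this step with plain vectors; the helper ProjectHeads and its shapes are illustrative assumptions, not the XTensor API.

    #include <cstdio>
    #include <vector>

    /* Illustrative output projection after dot-product attention:
       concatenate the heads, then multiply by wa ((nHeads*dv) x d).
       This is a sketch, not the repository's XTensor-based code. */
    std::vector<float> ProjectHeads(const std::vector<std::vector<float>> & heads,
                                    const std::vector<std::vector<float>> & wa)
    {
        /* concatenate the heads into one vector of size nHeads * dv */
        std::vector<float> concat;
        for (const auto & h : heads)
            concat.insert(concat.end(), h.begin(), h.end());

        /* result[j] = sum_i concat[i] * wa[i][j] */
        size_t d = wa[0].size();
        std::vector<float> result(d, 0.0F);
        for (size_t i = 0; i < concat.size(); i++)
            for (size_t j = 0; j < d; j++)
                result[j] += concat[i] * wa[i][j];
        return result;
    }

    int main()
    {
        /* two heads of size 2, projected back to d = 4 */
        std::vector<std::vector<float>> heads = {{1.0F, 2.0F}, {3.0F, 4.0F}};
        std::vector<std::vector<float>> wa(4, std::vector<float>(4, 0.5F));
        for (float v : ProjectHeads(heads, wa))
            printf("%.2f ", v);
        printf("\n");
        return 0;
    }

Without this projection the merged heads were passed on unmixed, which is the bug the commit fixes.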
@@ -57,6 +57,9 @@ public:
     /* transformation matrix for V */
     XTensor wv;

+    /* transformation after dot-product attention */
+    XTensor wa;
+
     /* size of transformed Q and K */
     int dk;
......
@@ -33,6 +33,7 @@ T2TLN::T2TLN()
 {
     devID = -1;
     mem = NULL;
+    d = 0;
 }

 /* de-constructor */
@@ -52,14 +53,14 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
     devID = myDevID;
     mem = myMem;
-    int d = 0;
+    d = 0;
     LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);

-    InitTensor2D(&w, d, d, X_FLOAT, devID, mem);
+    InitTensor1D(&w, d, X_FLOAT, devID, mem);
     InitTensor1D(&b, d, X_FLOAT, devID, mem);

     float scale = 1.0F;
-    float finfout = (float)sqrt(6.0F * scale / (d + d));
+    float finfout = (float)sqrt(6.0F * scale / d);

     w.SetDataRand(-finfout, finfout);
     b.SetZeroAll();
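Since w is now a d-sized gain vector rather than a d x d matrix, the Xavier-style uniform bound changes from sqrt(6 * scale / (d + d)) to sqrt(6 * scale / d). A small sketch of drawing such a vector follows; InitGain is an illustrative stand-in for w.SetDataRand(-finfout, finfout), not part of the repository.

    #include <cmath>
    #include <cstdlib>
    #include <vector>

    /* Fill a d-sized gain vector uniformly in [-sqrt(6/d), sqrt(6/d)];
       an illustrative stand-in, not the repository's initializer. */
    std::vector<float> InitGain(int d)
    {
        float bound = (float)sqrt(6.0F / d);
        std::vector<float> w(d);
        for (int i = 0; i < d; i++) {
            float r = (float)rand() / (float)RAND_MAX;   /* uniform in [0, 1] */
            w[i] = -bound + 2.0F * bound * r;            /* map to [-bound, bound] */
        }
        return w;
    }

    int main()
    {
        std::vector<float> w = InitGain(512);
        return w.size() == 512 ? 0 : 1;
    }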
@@ -97,10 +98,10 @@ XTensor T2TLN::Make(XTensor &input)
     standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));

     /* x' = (x - \mu)/standard */
-    xn = (x - meanFilled)/standardFilled;
+    xn = (x - meanFilled) / standardFilled;

     /* result = x' * w + b */
-    return MMul(xn, w) + b;
+    return xn * w + b;
 }
 }
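With the 1-D w, layer normalization now scales each feature element-wise instead of mixing features through a matrix multiply: LN(x) = (x - mu) / standard * w + b, with w and b both d-dimensional. Below is a self-contained sketch of the same computation for a single vector (plain C++, not the XTensor operators; like the formula above it adds no epsilon term).

    #include <cmath>
    #include <cstdio>
    #include <vector>

    /* Layer normalization of one d-dimensional vector with an
       element-wise gain w and bias b; a sketch, not the XTensor version. */
    std::vector<float> LayerNorm(const std::vector<float> & x,
                                 const std::vector<float> & w,
                                 const std::vector<float> & b)
    {
        size_t d = x.size();

        /* mean over the feature dimension */
        float mean = 0.0F;
        for (float v : x)
            mean += v;
        mean /= d;

        /* standard deviation over the feature dimension */
        float var = 0.0F;
        for (float v : x)
            var += (v - mean) * (v - mean);
        float standard = (float)sqrt(var / d);

        /* x' = (x - \mu)/standard, then scale and shift element-wise */
        std::vector<float> y(d);
        for (size_t i = 0; i < d; i++)
            y[i] = (x[i] - mean) / standard * w[i] + b[i];
        return y;
    }

    int main()
    {
        std::vector<float> x = {1.0F, 2.0F, 3.0F, 4.0F};
        std::vector<float> w(4, 1.0F), b(4, 0.0F);
        for (float v : LayerNorm(x, w, b))
            printf("%.3f ", v);
        printf("\n");
        return 0;
    }

The earlier MMul(xn, w) with a d x d weight turned the per-dimension gain into a full linear transform, which is the second bug this commit corrects.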
@@ -46,6 +46,9 @@ public:
     /* the bias term b */
     XTensor b;

+    /* dimension size of the model */
+    int d;
+
 public:
     /* constructor */
     T2TLN();
......
@@ -174,6 +174,7 @@ void T2TModel::GetParams(XList &list)
     list.Add(&encoder.attentions[i].wk);
     list.Add(&encoder.attentions[i].wq);
     list.Add(&encoder.attentions[i].wv);
+    list.Add(&encoder.attentions[i].wa);
     list.Add(&encoder.fnnLayerNorms[i].w);
     list.Add(&encoder.fnnLayerNorms[i].b);
     list.Add(&encoder.attLayerNorms[i].w);
......