Commit dd0ad421 by xiaotong

bug fixes of attention and layer normalization

parent fe37006f
@@ -74,14 +74,17 @@ void T2TAttention::InitModel(int argc, char ** argv,
     InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
     InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
     InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
+    InitTensor2D(&wa, d, d, X_FLOAT, devID, mem);
 
     float scale = 1.0F;
     float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
     float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
+    float finfouta = (float)sqrt(6.0F * scale / (d + d));
 
     wk.SetDataRand(-finfoutk, finfoutk);
     wq.SetDataRand(-finfoutk, finfoutk);
     wv.SetDataRand(-finfoutv, finfoutv);
+    wa.SetDataRand(-finfouta, finfouta);
 }
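The three `+` lines initialize the new output-projection matrix `wa` with the same fan-based uniform scheme already used for `wk`, `wq`, and `wv`: weights are drawn from U(-b, b) with b = sqrt(6 * scale / (fan_in + fan_out)), i.e. Xavier/Glorot-style uniform initialization. A minimal stand-alone sketch of that scheme (the helper name `xavierUniform` is illustrative, not part of this codebase):

    #include <cmath>
    #include <random>
    #include <vector>

    /* Xavier/Glorot-style uniform initialization: draw each weight from
       U(-b, b) with b = sqrt(6 / (fanIn + fanOut)). For the square d-by-d
       matrix wa, fanIn == fanOut == d, so b reduces to sqrt(3.0 / d). */
    std::vector<float> xavierUniform(int fanIn, int fanOut, std::mt19937 &rng)
    {
        float bound = std::sqrt(6.0f / (float)(fanIn + fanOut));
        std::uniform_real_distribution<float> dist(-bound, bound);
        std::vector<float> w((size_t)fanIn * (size_t)fanOut);
        for (float &v : w)
            v = dist(rng);
        return w;
    }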
@@ -135,7 +138,7 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bo
     att = BMMul(scalar, vheads);
 
     /* concatenate the heads */
-    return Merge(att, att.order - 1);
+    return MMul(Merge(att, att.order - 1), wa);
 }
 
 }
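This hunk is the substance of the attention fix: the original code returned the concatenated heads directly, dropping the output projection (the W^O of standard multi-head attention) that should follow `Merge`. With the fix, the merged result of width d is multiplied by the d-by-d matrix `wa`, which lets information mix across heads. A hedged sketch of what that final step computes on a single row-major [len x d] matrix (all names here are illustrative):

    #include <vector>

    /* Apply the output projection: out = merged * wa, where merged holds
       the concatenated heads ([len x d], row-major) and wa is [d x d]. */
    std::vector<float> projectHeads(const std::vector<float> &merged,
                                    const std::vector<float> &wa,
                                    int len, int d)
    {
        std::vector<float> out((size_t)len * d, 0.0f);
        for (int i = 0; i < len; i++)
            for (int k = 0; k < d; k++) {
                float a = merged[(size_t)i * d + k];
                for (int j = 0; j < d; j++)
                    out[(size_t)i * d + j] += a * wa[(size_t)k * d + j];
            }
        return out;
    }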
@@ -57,6 +57,9 @@ public:
     /* transformation matrix for V */
     XTensor wv;
 
+    /* transformation after dot-product attention */
+    XTensor wa;
+
     /* size of transformed Q and K */
     int dk;
...
@@ -32,7 +32,8 @@ namespace transformer
 T2TLN::T2TLN()
 {
     devID = -1;
     mem = NULL;
+    d = 0;
 }
 
 /* de-constructor */
@@ -52,23 +53,23 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
     devID = myDevID;
     mem = myMem;
-    int d = 0;
+    d = 0;
 
     LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
 
-    InitTensor2D(&w, d, d, X_FLOAT, devID, mem);
+    InitTensor1D(&w, d, X_FLOAT, devID, mem);
     InitTensor1D(&b, d, X_FLOAT, devID, mem);
 
     float scale = 1.0F;
-    float finfout = (float)sqrt(6.0F * scale / (d + d));
+    float finfout = (float)sqrt(6.0F * scale / d);
 
     w.SetDataRand(-finfout, finfout);
     b.SetZeroAll();
 }
 
 /*
 make the network
 for each layer representation x, we have
 y = w * (x - \mu)/standard + b
 >> input - the input tensor
 >> return - layer normalization output
 */
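Two things change in `T2TLN::InitModel`: the local `int d` that shadowed the new member variable is removed, and `w` becomes a 1D vector of size d instead of a d-by-d matrix, so the initialization bound changes from sqrt(6 * scale / (d + d)) to sqrt(6 * scale / d). For illustration, assuming an embedding size of d = 512, the old bound would be sqrt(6/1024) ≈ 0.077, while the new one is sqrt(6/512) ≈ 0.108.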
@@ -90,17 +91,17 @@ XTensor T2TLN::Make(XTensor &input)
     /* standard = sqrt(variance) */
     standard = Power(variance, 0.5F);
 
     /* unsqueeze mean and standard deviation to fit them into
        the same shape of x */
     meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
     standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
 
     /* x' = (x - \mu)/standard */
-    xn = (x - meanFilled)/standardFilled;
+    xn = (x - meanFilled) / standardFilled;
 
     /* result = x' * w + b */
-    return MMul(xn, w) + b;
+    return xn * w + b;
 }
 
 }
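This is the layer-normalization fix itself: `w` is now a per-dimension gain vector, so the affine step is the elementwise `xn * w + b` rather than the matrix product `MMul(xn, w)` that the buggy code computed. A self-contained sketch of the whole computation for one d-dimensional row, mirroring the fixed code (which takes sqrt(variance) directly, with no epsilon term); the function name is illustrative:

    #include <cmath>

    /* Layer normalization of one d-dimensional row:
       x' = (x - mean) / sqrt(var), y = x' * w + b,
       with the gain w and bias b applied elementwise. */
    void layerNormRow(const float *x, const float *w, const float *b,
                      float *y, int d)
    {
        float mean = 0.0f;
        for (int i = 0; i < d; i++)
            mean += x[i];
        mean /= (float)d;

        float var = 0.0f;
        for (int i = 0; i < d; i++)
            var += (x[i] - mean) * (x[i] - mean);
        var /= (float)d;

        float standard = std::sqrt(var);
        for (int i = 0; i < d; i++)
            y[i] = (x[i] - mean) / standard * w[i] + b[i];
    }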
@@ -45,6 +45,9 @@ public:
     /* the bias term b */
     XTensor b;
 
+    /* dimension size of the model */
+    int d;
+
 public:
 
     /* constructor */
...
@@ -174,6 +174,7 @@ void T2TModel::GetParams(XList &list)
     list.Add(&encoder.attentions[i].wk);
     list.Add(&encoder.attentions[i].wq);
     list.Add(&encoder.attentions[i].wv);
+    list.Add(&encoder.attentions[i].wa);
     list.Add(&encoder.fnnLayerNorms[i].w);
     list.Add(&encoder.fnnLayerNorms[i].b);
     list.Add(&encoder.attLayerNorms[i].w);
...
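The last hunk registers the new `wa` tensor in `GetParams`. This one-liner is easy to miss but essential: the list collected here is what the rest of the training code sees, so a parameter left out of it would presumably never receive gradient updates and would keep its random initialization forever. A minimal sketch of the pattern (the `Tensor`/`Attention` stand-ins are illustrative, not the XTensor API):

    #include <vector>

    struct Tensor { /* weights, gradients, ... */ };
    struct Attention { Tensor wk, wq, wv, wa; };

    /* Collect every trainable tensor; anything missing from the
       list is invisible to the optimizer. */
    void getParams(Attention &att, std::vector<Tensor*> &list)
    {
        list.push_back(&att.wk);
        list.push_back(&att.wq);
        list.push_back(&att.wv);
        list.push_back(&att.wa); /* the fix: without this, wa is never trained */
    }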