Commit b1a9adde by xiaotong

improve the dropout implementation by supporting broadcasting along different dimensions

parent 896e5231
......@@ -71,6 +71,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradMultiply(node, isEfficient);
else if(operID == MATH_MULTIPLYDIM)
GradMultiplyDim(node, isEfficient);
else if (operID == MATH_MULTIPLYBROADCAST)
GradMultiplyBroadcast(node, isEfficient);
else if(operID == MATH_NEGATE)
GradNegate(node, isEfficient);
else if(operID == MATH_NORMALIZE)
......
......@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 2);
x = Dropout(x, dropoutP, 0, 2);
for(int i = 0; i < nlayer; i++){
XTensor att;
......@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 2);
att = Dropout(att, dropoutP, 0, 2);
/* residual connection */
res = Sum(att, x);
......@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP, 2);
ende = Dropout(ende, dropoutP, 0, 2);
/* residual connection */
res = Sum(ende, x);
......@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 2);
fnn = Dropout(fnn, dropoutP, 0, 2);
/* residual connection */
res = Sum(fnn, x);
......
......@@ -103,11 +103,9 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
x = embedder.Make(input);
//x.Dump(tmpFILE, "embedding: ");
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 2);
x = Dropout(x, dropoutP, 0, 2);
for(int i = 0; i < nlayer; i++){
XTensor att;
......@@ -120,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 2);
att = Dropout(att, dropoutP, 0, 2);
/* residual connection */
res = Sum(att, x);
......@@ -133,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 2);
fnn = Dropout(fnn, dropoutP, 0, 2);
/* residual connection */
res = Sum(fnn, x);
......
......@@ -36,8 +36,6 @@ int TransformerMain(int argc, const char ** argv)
{
if(argc == 0)
return 1;
fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论