Commit b1a9adde by xiaotong

Improve the dropout implementation by broadcasting along different dimensions

parent 896e5231
...@@ -71,6 +71,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient) ...@@ -71,6 +71,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradMultiply(node, isEfficient); GradMultiply(node, isEfficient);
else if(operID == MATH_MULTIPLYDIM) else if(operID == MATH_MULTIPLYDIM)
GradMultiplyDim(node, isEfficient); GradMultiplyDim(node, isEfficient);
else if (operID == MATH_MULTIPLYBROADCAST)
GradMultiplyBroadcast(node, isEfficient);
else if(operID == MATH_NEGATE) else if(operID == MATH_NEGATE)
GradNegate(node, isEfficient); GradNegate(node, isEfficient);
else if(operID == MATH_NORMALIZE) else if(operID == MATH_NORMALIZE)
......
...@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -82,7 +82,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 2); x = Dropout(x, dropoutP, 0, 2);
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
XTensor att; XTensor att;
...@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -97,7 +97,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 2); att = Dropout(att, dropoutP, 0, 2);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x);
...@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -111,7 +111,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP, 2); ende = Dropout(ende, dropoutP, 0, 2);
/* residual connection */ /* residual connection */
res = Sum(ende, x); res = Sum(ende, x);
...@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X ...@@ -125,7 +125,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 2); fnn = Dropout(fnn, dropoutP, 0, 2);
/* residual connection */ /* residual connection */
res = Sum(fnn, x); res = Sum(fnn, x);
......
...@@ -103,11 +103,9 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo ...@@ -103,11 +103,9 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
x = embedder.Make(input); x = embedder.Make(input);
//x.Dump(tmpFILE, "embedding: ");
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP, 2); x = Dropout(x, dropoutP, 0, 2);
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
XTensor att; XTensor att;
...@@ -120,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo ...@@ -120,7 +118,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP, 2); att = Dropout(att, dropoutP, 0, 2);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x);
...@@ -133,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo ...@@ -133,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP, 2); fnn = Dropout(fnn, dropoutP, 0, 2);
/* residual connection */ /* residual connection */
res = Sum(fnn, x); res = Sum(fnn, x);
......
...@@ -37,8 +37,6 @@ int TransformerMain(int argc, const char ** argv) ...@@ -37,8 +37,6 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0) if(argc == 0)
return 1; return 1;
fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc]; char ** args = new char*[argc];
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1]; args[i] = new char[strlen(argv[i]) + 1];
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论