Commit b8415485 by xiaotong

fix the bug of unaccumulated gradients in XMathGrad

parent 11a57e99
@@ -73,11 +73,8 @@ void XMathGrad::GradSum(XTensor * node)
     XNoder::MakeGrad(a);
     XNoder::MakeGrad(b);
-    _CopyValues(node->grad, a->grad);
-    if(beta != 1.0F)
-        _ScaleAndShift(node->grad, a->grad, beta);
-    else
-        _CopyValues(node->grad, b->grad);
+    _Sum(a->grad, node->grad, a->grad);
+    _Sum(b->grad, node->grad, b->grad, beta);
 }
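Why replacing _CopyValues with _Sum matters: when a tensor feeds more than one node in the graph, backward propagation reaches it once per consumer, and each visit contributes one term to its gradient. Overwriting the gradient keeps only the last term; summing accumulates them all. A minimal standalone sketch of the difference, using plain arrays and made-up values rather than NiuTensor code:

    #include <cstdio>

    /* hypothetical illustration: a tensor x feeds two consumers, so
       backprop delivers two gradient contributions; dE/dx must be
       their sum */
    int main()
    {
        float gradFromC1[2] = {1.0F, 2.0F}; /* contribution via consumer 1 */
        float gradFromC2[2] = {3.0F, 4.0F}; /* contribution via consumer 2 */
        float gradX[2]      = {0.0F, 0.0F};

        for(int i = 0; i < 2; i++){
            /* buggy (overwrite, like _CopyValues): gradX[i] = gradFromC2[i];
               that keeps only the last contribution */
            gradX[i] += gradFromC1[i]; /* accumulate, like _Sum */
            gradX[i] += gradFromC2[i];
        }
        printf("dE/dx = {%g, %g}\n", gradX[0], gradX[1]); /* {4, 6}, not {3, 4} */
        return 0;
    }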
@@ -100,8 +97,8 @@ void XMathGrad::GradMultiply(XTensor * node)
     XNoder::MakeGrad(b);
     CheckNTErrors(XTensor::IsIdentical(a, b), "Wrong sized input tensors!");
-    _Multiply(node->grad, b, a->grad);
-    _Multiply(node->grad, a, b->grad);
+    _Multiply(node->grad, b, a->grad, 1.0F);
+    _Multiply(node->grad, a, b->grad, 1.0F);
 }
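The new trailing 1.0F makes the element-wise case accumulate as well: judging from this diff, the extra coefficient scales the existing content of the output tensor before the fresh product is added, so a->grad and b->grad grow by node->grad * b and node->grad * a instead of being replaced. A standalone sketch of that semantics (the function name and exact formula are assumptions for illustration, not the NiuTensor signature):

    #include <cstdio>

    /* assumed semantics: c = a (element-wise *) b + alpha * c;
       alpha = 1.0F preserves what c already holds, 0.0F overwrites it */
    void multiplyAcc(const float * a, const float * b, float * c,
                     int n, float alpha)
    {
        for(int i = 0; i < n; i++)
            c[i] = a[i] * b[i] + alpha * c[i];
    }

    int main()
    {
        float a[2] = {2.0F, 3.0F};
        float b[2] = {5.0F, 7.0F};
        float c[2] = {1.0F, 1.0F}; /* c already holds an earlier gradient */

        multiplyAcc(a, b, c, 2, 1.0F); /* accumulate */
        printf("%g %g\n", c[0], c[1]); /* 11 22: old content preserved */
        return 0;
    }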
@@ -135,40 +132,40 @@ void XMathGrad::GradMatrixMul(XTensor * node)
     if(transA == X_NOTRANS && transB == X_NOTRANS){
         /* dE/da = dE/dc * b^T * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
+        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
         /* dE/db = a^T * dE/dc * \alpha */
-        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
+        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
     }
     /* c = a^T * b * \alpha */
     else if(transA == X_TRANS && transB == X_NOTRANS){
         /* dE/da = dE/dc * b^T * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
+        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
         /* dE/db = a * dE/dc * \alpha */
-        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
+        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
     }
     /* c = a * b^T * \alpha */
     else if(transA == X_NOTRANS && transB == X_TRANS){
         /* dE/da = dE/dc * b * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
+        _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
         /* dE/db = a^T * dE/dc * \alpha */
-        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
+        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
     }
     /* c = a^T * b^T * \alpha */
     else if(transA == X_TRANS && transB == X_TRANS){
         /* dE/da = dE/dc * b * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
+        _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
         /* dE/db = a * dE/dc * \alpha */
-        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
+        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
     }
 }
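The four branches are the standard matrix-product gradients. For the first (no-transpose) branch, c = a * b * \alpha, the two products in the code follow directly from the chain rule; in the notation of the comments, with dE/dc held in node->grad:

    % c_{ij} = \alpha \sum_k a_{ik} b_{kj}
    \frac{\partial E}{\partial a_{ik}}
      = \sum_j \frac{\partial E}{\partial c_{ij}}\,\alpha\,b_{kj}
      = \alpha \Big(\frac{\partial E}{\partial c}\, b^{\top}\Big)_{ik}
    \qquad
    \frac{\partial E}{\partial b_{kj}}
      = \sum_i \frac{\partial E}{\partial c_{ij}}\,\alpha\,a_{ik}
      = \alpha \Big(a^{\top}\,\frac{\partial E}{\partial c}\Big)_{kj}

The trailing 1.0F appended to every _MatrixMul call appears to play the role of the \beta in a GEMM-style product C = \alpha op(A) op(B) + \beta C: with \beta = 1.0F each call adds its product into deda/dedb, so gradient terms already accumulated there from other consumers of a and b are preserved instead of overwritten.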
@@ -260,7 +260,7 @@ void XNet::Dump(FILE * file)
 {
     for(int i = 0; i < nodes.count; i++){
         XTensor * node = (XTensor*)nodes.Get(i);
-        fprintf(file, "node %d\n", i);
+        fprintf(file, "node %d: %d\n", i, node->id);
         node->Dump(file, "tensor: ");
         if(node->grad != NULL)
             node->grad->Dump(file, "grad: ");