Commit 275a812a by liyinqiao

Bug fixed.

Fix backward bugs in the MulAndShift function: record the scale factor alpha on the tensor link in the forward pass and use it in GradMulAndShift instead of the hard-coded 1.0F.
parent 257585e3
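
What the fix does: MulAndShift computes c = alpha * (x × w) + b, so by the chain rule both matmul gradients carry the same factor, e.g. dE/dx = alpha * (dE/dc × wᵀ) when neither input is transposed. The old backward pass handed a hard-coded 1.0F to GradMatrixMul, so gradients were off by a factor of 1/alpha whenever alpha != 1. A minimal scalar sketch of this (illustrative only, not the NiuTensor API; mulAndShift below is a hypothetical stand-in), checked against a central finite difference:

// Minimal scalar sketch (not the NiuTensor API): why the backward pass of
// c = alpha * x * w + b must scale the matmul gradients by alpha.
#include <cmath>
#include <cstdio>

// Hypothetical stand-in for the forward pass.
double mulAndShift(double x, double w, double b, double alpha) {
    return alpha * x * w + b;
}

int main() {
    const double x = 1.5, w = -2.0, b = 0.3, alpha = 0.5;
    const double dedc = 1.0;   // upstream gradient dE/dc
    const double eps = 1e-6;

    // Analytic gradients with the alpha factor (the fixed behavior).
    const double dedx = dedc * w * alpha;
    const double dedw = dedc * x * alpha;

    // Central finite differences as a reference.
    const double dedxNum = (mulAndShift(x + eps, w, b, alpha) -
                            mulAndShift(x - eps, w, b, alpha)) / (2 * eps);
    const double dedwNum = (mulAndShift(x, w + eps, b, alpha) -
                            mulAndShift(x, w - eps, b, alpha)) / (2 * eps);

    // With the old hard-coded 1.0F, dedx would be dedc * w: off by 1/alpha.
    printf("dE/dx analytic=%g numeric=%g\n", dedx, dedxNum);
    printf("dE/dw analytic=%g numeric=%g\n", dedw, dedwNum);
    return (std::fabs(dedx - dedxNum) > 1e-4 || std::fabs(dedw - dedwNum) > 1e-4);
}
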
@@ -1553,6 +1553,7 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
     int n = income.GetParamInt(0);
     MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
     MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);
+    DTYPE alpha = income.GetParam(3);
 
     if (!isEfficient || w->isGrad)
         XNoder::MakeGrad(w);
@@ -1614,7 +1615,6 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
     }
 
-
     /* compute dE/dx, dE/dw */
     XTensor * c = node;
     XTensor * dedc = node->grad;
 
@@ -1622,7 +1622,7 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
     XTensor * dedx = x->grad;
 
     if (x->order == 2 && w->order == 2)
-        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
+        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, alpha, isEfficient);
     else if (transX == X_NOTRANS && x->order > 2 && w->order == 2){
         int orderBackupX = x->order;
         int orderBackupC = c->order;
@@ -1637,14 +1637,13 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
         dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
         dedc->Reshape(dedc->unitNum / dedc->GetDim(-1), dedc->GetDim(-1));
 
-        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
+        GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, alpha, isEfficient);
 
         x->Reshape(orderBackupX, dimsBackupX);
         c->Reshape(orderBackupC, dimsBackupC);
         if (!isEfficient || x->isGrad)
             dedx->Reshape(orderBackupX, dimsBackupX);
         dedc->Reshape(orderBackupC, dimsBackupC);
-
     }
 
     node->visitMark = NODE_FINISHED;
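
The hunk above sits in the order > 2 branch: x, dedx, and dedc are temporarily flattened to two dimensions so the same 2-D GradMatrixMul can be reused, then the backed-up order/dims are restored. A shape-only sketch of that flattening arithmetic (plain ints, not XTensor; the dims are made up):

// Shape-only sketch of the flatten/restore pattern in the order > 2 branch.
// A [d0, d1, k] tensor times a [k, m] matrix is the same matmul as
// [d0*d1, k] x [k, m], so the 2-D gradient routine can be reused.
#include <cstdio>

int main() {
    const int order = 3;
    const int dims[order] = {8, 16, 32};   // hypothetical x of shape [8, 16, 32]

    // Mirrors dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1)).
    int unitNum = 1;
    for (int i = 0; i < order; i++)
        unitNum *= dims[i];
    const int rows = unitNum / dims[order - 1];   // 8 * 16 = 128
    const int cols = dims[order - 1];             // 32

    printf("flatten [8, 16, 32] -> [%d, %d] for the 2-D gradient,\n", rows, cols);
    printf("then Reshape back with the backed-up order/dims.\n");
    return 0;
}
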
...
@@ -80,7 +80,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
 
         // TODO!!
         ShowNTErrors("TODO!");
-
     }
     else if (n >= 0 && n < tmp->order) {
         /* call _SumDim function */
@@ -95,6 +94,7 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
         XLink::AddParamToHeadInt(&c, n);
         XLink::AddParamToHeadTrans(&c, X_NOTRANS);
         XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHead(&c, alpha);
     }
 
 }
...
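
Note how the two files have to agree on parameter positions: the forward pass pushes n, transW, transX, and now alpha onto the node's incoming link, and GradMulAndShift reads them back by index, so income.GetParam(3) only finds alpha because XLink::AddParamToHead(&c, alpha) is the fourth registration. A mock round-trip of that index contract (MockLink is a hypothetical stand-in; the real XLink is far richer, but the positional rule is the same):

// Mock of the index contract between MulAndShift (writer) and
// GradMulAndShift (reader). MockLink is a hypothetical stand-in for XLink.
#include <cassert>
#include <vector>

enum MATRIX_TRANS_TYPE { X_NOTRANS, X_TRANS };
typedef float DTYPE;

struct MockLink {
    std::vector<DTYPE> params;   // flat, position-addressed parameter store

    void AddParamToHeadInt(int v)                 { params.push_back((DTYPE)v); }
    void AddParamToHeadTrans(MATRIX_TRANS_TYPE v) { params.push_back((DTYPE)v); }
    void AddParamToHead(DTYPE v)                  { params.push_back(v); }

    int GetParamInt(int i) const                  { return (int)params[i]; }
    MATRIX_TRANS_TYPE GetParamTrans(int i) const  { return (MATRIX_TRANS_TYPE)(int)params[i]; }
    DTYPE GetParam(int i) const                   { return params[i]; }
};

int main() {
    MockLink c;

    // Forward side, in the order the patched MulAndShift registers them.
    c.AddParamToHeadInt(1);              // n
    c.AddParamToHeadTrans(X_NOTRANS);    // transW
    c.AddParamToHeadTrans(X_NOTRANS);    // transX
    c.AddParamToHead(0.5f);              // alpha, the newly added fourth param

    // Backward side, as the patched GradMulAndShift reads them back.
    assert(c.GetParamInt(0) == 1);
    assert(c.GetParamTrans(1) == X_NOTRANS);
    assert(c.GetParamTrans(2) == X_NOTRANS);
    assert(c.GetParam(3) == 0.5f);       // was simply missing before the fix
    return 0;
}
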