Commit 275a812a by liyinqiao

Bug fixed.

Fix backward bugs in MulAndShift function.
parent 257585e3
......@@ -1553,6 +1553,7 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
int n = income.GetParamInt(0);
MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);
DTYPE alpha = income.GetParam(3);
if (!isEfficient || w->isGrad)
XNoder::MakeGrad(w);
......@@ -1614,7 +1615,6 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
}
/* compute dE/dx, dE/dw */
XTensor * c = node;
XTensor * dedc = node->grad;
......@@ -1622,7 +1622,7 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
XTensor * dedx = x->grad;
if (x->order == 2 && w->order == 2)
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, alpha, isEfficient);
else if (transX == X_NOTRANS && x->order > 2 && w->order == 2){
int orderBackupX = x->order;
int orderBackupC = c->order;
......@@ -1637,14 +1637,13 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
dedc->Reshape(dedc->unitNum / dedc->GetDim(-1), dedc->GetDim(-1));
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, alpha, isEfficient);
x->Reshape(orderBackupX, dimsBackupX);
c->Reshape(orderBackupC, dimsBackupC);
if (!isEfficient || x->isGrad)
dedx->Reshape(orderBackupX, dimsBackupX);
dedc->Reshape(orderBackupC, dimsBackupC);
}
node->visitMark = NODE_FINISHED;
......
......@@ -80,7 +80,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
......@@ -95,6 +94,7 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论