Commit cae88113 by xiaotong

fix the bug of missing leaddim parameter in XLink of softmax

parent 484f3694
...@@ -50,6 +50,7 @@ void XFuncGrad::MakeGrad(XTensor * node) ...@@ -50,6 +50,7 @@ void XFuncGrad::MakeGrad(XTensor * node)
_IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS); _IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_LOGSOFTMAX){ else if(operID == FUNC_LOGSOFTMAX){
int leadDim = income.GetParamInt(0); int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
_LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS); _LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
} }
else if(operID == FUNC_RECTIFY) else if(operID == FUNC_RECTIFY)
...@@ -58,6 +59,7 @@ void XFuncGrad::MakeGrad(XTensor * node) ...@@ -58,6 +59,7 @@ void XFuncGrad::MakeGrad(XTensor * node)
_SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS); _SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_SOFTMAX){ else if(operID == FUNC_SOFTMAX){
int leadDim = income.GetParamInt(0); int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
_SoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS); _SoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
} }
else{ else{
......
...@@ -176,15 +176,19 @@ make a new tensor to keep the result and return it ...@@ -176,15 +176,19 @@ make a new tensor to keep the result and return it
*/ */
XTensor LogSoftmax(const XTensor &x, int leadDim) XTensor LogSoftmax(const XTensor &x, int leadDim)
{ {
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
XTensor y(&x); XTensor y(&x);
y.SetTMP(); y.SetTMP();
/* call _LogSoftmax function */ /* call _LogSoftmax function */
_LogSoftmax(&x, &y, leadDim); _LogSoftmax(&x, &y, ld);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX); XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, leadDim); XLink::AddParamToHeadInt(&y, ld);
return y; return y;
} }
......
...@@ -143,14 +143,19 @@ make a new tensor to keep the result and return it ...@@ -143,14 +143,19 @@ make a new tensor to keep the result and return it
*/ */
XTensor Softmax(const XTensor &x, int leadDim) XTensor Softmax(const XTensor &x, int leadDim)
{ {
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
XTensor y(&x); XTensor y(&x);
y.SetTMP(); y.SetTMP();
/* call _Softmax function */ /* call _Softmax function */
_Softmax(&x, &y, leadDim); _Softmax(&x, &y, ld);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX); XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
return y; return y;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论