Commit 1226160d by xiaotong

bug fix in XShapeGrad

parent ad3025f0
...@@ -48,8 +48,8 @@ int main( int argc, const char ** argv ) ...@@ -48,8 +48,8 @@ int main( int argc, const char ** argv )
XTensor c; XTensor c;
InitTensor2D(&a, 2, 2); InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2); InitTensor2D(&b, 2, 4);
InitTensor2D(&c, 2, 2); InitTensor2D(&c, 2, 4);
a.SetZeroAll(); a.SetZeroAll();
b.SetZeroAll(); b.SetZeroAll();
...@@ -59,7 +59,7 @@ int main( int argc, const char ** argv ) ...@@ -59,7 +59,7 @@ int main( int argc, const char ** argv )
a.Set2D(0.3F, 1, 0); a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1); a.Set2D(0.4F, 1, 1);
b = a + a; b = Merge(a, a, 1);
c = HTanH(MMul(a, b)); c = HTanH(MMul(a, b));
a.Dump(stderr, "a:"); a.Dump(stderr, "a:");
......
...@@ -205,6 +205,7 @@ void XShapeGrad::GradMergeList(XTensor * node) ...@@ -205,6 +205,7 @@ void XShapeGrad::GradMergeList(XTensor * node)
_Sum(smallGrad, &gradSmall, smallGrad); _Sum(smallGrad, &gradSmall, smallGrad);
} }
gradSmall.data = NULL;
delete[] dims; delete[] dims;
} }
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "XBackwardLoss.h" #include "XBackwardLoss.h"
#include "XBackwardMath.h" #include "XBackwardMath.h"
#include "XBackwardFunc.h" #include "XBackwardFunc.h"
#include "XBackwardShape.h"
#include "../tensor/XName.h" #include "../tensor/XName.h"
namespace nts{ namespace nts{
...@@ -179,6 +180,8 @@ void XNet::BackwardNode(XTensor * node) ...@@ -179,6 +180,8 @@ void XNet::BackwardNode(XTensor * node)
XMathGrad::MakeGrad(node); XMathGrad::MakeGrad(node);
else if(XFuncGrad::IsFunc(node)) else if(XFuncGrad::IsFunc(node))
XFuncGrad::MakeGrad(node); XFuncGrad::MakeGrad(node);
else if(XShapeGrad::IsShapeOP(node))
XShapeGrad::MakeGrad(node);
else{ else{
ShowNTErrors("Wrong node type!"); ShowNTErrors("Wrong node type!");
} }
......
...@@ -59,7 +59,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -59,7 +59,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MOVEMENT_COPYINDEXED MOVEMENT + 1 #define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1 #define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define SHAPE REDUCE_REDUCEVARIANCE + 1 #define SHAPE MOVEMENT_COPYVALUES + 1
#define SHAPE_CONCATENATE SHAPE + 1 #define SHAPE_CONCATENATE SHAPE + 1
#define SHAPE_MERGE SHAPE_CONCATENATE + 1 #define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1 #define SHAPE_MERGE_LIST SHAPE_MERGE + 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论