Commit 1226160d by xiaotong

bug fix in XShapeGrad

parent ad3025f0
......@@ -48,8 +48,8 @@ int main( int argc, const char ** argv )
XTensor c;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
InitTensor2D(&c, 2, 2);
InitTensor2D(&b, 2, 4);
InitTensor2D(&c, 2, 4);
a.SetZeroAll();
b.SetZeroAll();
......@@ -59,7 +59,7 @@ int main( int argc, const char ** argv )
a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1);
b = a + a;
b = Merge(a, a, 1);
c = HTanH(MMul(a, b));
a.Dump(stderr, "a:");
......
......@@ -205,6 +205,7 @@ void XShapeGrad::GradMergeList(XTensor * node)
_Sum(smallGrad, &gradSmall, smallGrad);
}
gradSmall.data = NULL;
delete[] dims;
}
}
......
......@@ -24,6 +24,7 @@
#include "XBackwardLoss.h"
#include "XBackwardMath.h"
#include "XBackwardFunc.h"
#include "XBackwardShape.h"
#include "../tensor/XName.h"
namespace nts{
......@@ -179,6 +180,8 @@ void XNet::BackwardNode(XTensor * node)
XMathGrad::MakeGrad(node);
else if(XFuncGrad::IsFunc(node))
XFuncGrad::MakeGrad(node);
else if(XShapeGrad::IsShapeOP(node))
XShapeGrad::MakeGrad(node);
else{
ShowNTErrors("Wrong node type!");
}
......
......@@ -59,7 +59,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define SHAPE REDUCE_REDUCEVARIANCE + 1
#define SHAPE MOVEMENT_COPYVALUES + 1
#define SHAPE_CONCATENATE SHAPE + 1
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论