Commit a4b98ac6 by liyinqiao

Bug fixed and clean the codes.

1. Fix the bug in SubMe function which cannot handle the scalar tensor and broadcast case.
2. Clean the codes.
parent 14ec9fad
...@@ -63,9 +63,24 @@ keep the result in the tensor a and return nothing ...@@ -63,9 +63,24 @@ keep the result in the tensor a and return nothing
>> b - another tensor >> b - another tensor
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void SubMe(XTensor& a, const XTensor& b, DTYPE beta) void SubMe(XTensor & a, const XTensor & b, DTYPE beta)
{ {
_Sub(&a, &b, &a, beta); if (b.order == 0){
DTYPE shift = -(b.Get0D() * beta);
_ScaleAndShift(&a, &a, 1.0F, shift);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1)
/* call _Sub function */
_Sub(&a, &b, &a, beta);
else if (n >= 0 && n < a.order)
/* call _SumDim function to do the SubDim operation */
_SumDim(&a, &b, &a, n, -beta);
else
ShowNTErrors("Something is wrong!");
}
} }
/* /*
...@@ -77,7 +92,7 @@ make a new tensor c to keep the result and return it ...@@ -77,7 +92,7 @@ make a new tensor c to keep the result and return it
>> beta - the scaling factor >> beta - the scaling factor
<< return - the result of tensor subtraction << return - the result of tensor subtraction
*/ */
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta) XTensor Sub(const XTensor & a, const XTensor & b, DTYPE beta)
{ {
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
...@@ -125,7 +140,7 @@ tensor subtraction c = a - b * \beta ...@@ -125,7 +140,7 @@ tensor subtraction c = a - b * \beta
>> c - where we put a-b*\beta. we save it in a if c is NULL >> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) void Sub(const XTensor & a, const XTensor & b, XTensor & c, DTYPE beta)
{ {
if (!c.isInit || !IsSameShaped(a, c)) { if (!c.isInit || !IsSameShaped(a, c)) {
InitTensorV2(&c, &a); InitTensorV2(&c, &a);
......
...@@ -181,7 +181,7 @@ keep the result in the tensor a and return nothing ...@@ -181,7 +181,7 @@ keep the result in the tensor a and return nothing
>> b - another tensor >> b - another tensor
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void SumMe(XTensor& a, const XTensor& b, DTYPE beta) void SumMe(XTensor & a, const XTensor & b, DTYPE beta)
{ {
if (b.order == 0){ if (b.order == 0){
DTYPE shift = b.Get0D() * beta; DTYPE shift = b.Get0D() * beta;
...@@ -190,17 +190,14 @@ void SumMe(XTensor& a, const XTensor& b, DTYPE beta) ...@@ -190,17 +190,14 @@ void SumMe(XTensor& a, const XTensor& b, DTYPE beta)
else { else {
int n = GetBroadcastDimIndex(a, b); int n = GetBroadcastDimIndex(a, b);
if (n == -1) { if (n == -1)
/* call _Sum function */ /* call _Sum function */
_Sum(&a, &b, &a, beta); _Sum(&a, &b, &a, beta);
} else if (n >= 0 && n < a.order)
else if (n >= 0 && n < a.order) {
/* call _SumDim function */ /* call _SumDim function */
_SumDim(&a, &b, &a, n, beta); _SumDim(&a, &b, &a, n, beta);
} else
else {
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
}
} }
} }
...@@ -247,7 +244,7 @@ make a new tensor c to keep the result and return it ...@@ -247,7 +244,7 @@ make a new tensor c to keep the result and return it
>> beta - the scaling factor >> beta - the scaling factor
<< return - the result of tensor summation << return - the result of tensor summation
*/ */
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta) XTensor Sum(const XTensor & a, const XTensor & b, DTYPE beta)
{ {
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
...@@ -294,7 +291,7 @@ tensor summation c = a + b * \beta ...@@ -294,7 +291,7 @@ tensor summation c = a + b * \beta
>> b - another tensor >> b - another tensor
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) void Sum(const XTensor & a, const XTensor & b, XTensor & c, DTYPE beta)
{ {
if (!c.isInit || !IsSameShaped(a, c)) { if (!c.isInit || !IsSameShaped(a, c)) {
InitTensorV2(&c, &a); InitTensorV2(&c, &a);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论