Commit a4b98ac6 by liyinqiao

Bug fixed and clean the codes.

1. Fix the bug in SubMe function which cannot handle the scalar tensor and broadcast case.
2. Clean the codes.
parent 14ec9fad
......@@ -63,9 +63,24 @@ keep the result in the tensor a and return nothing
>> b - another tensor
>> beta - the scaling factor
*/
void SubMe(XTensor& a, const XTensor& b, DTYPE beta)
void SubMe(XTensor & a, const XTensor & b, DTYPE beta)
{
_Sub(&a, &b, &a, beta);
if (b.order == 0){
DTYPE shift = -(b.Get0D() * beta);
_ScaleAndShift(&a, &a, 1.0F, shift);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1)
/* call _Sub function */
_Sub(&a, &b, &a, beta);
else if (n >= 0 && n < a.order)
/* call _SumDim function to do the SubDim operation */
_SumDim(&a, &b, &a, n, -beta);
else
ShowNTErrors("Something is wrong!");
}
}
/*
......@@ -77,7 +92,7 @@ make a new tensor c to keep the result and return it
>> beta - the scaling factor
<< return - the result of tensor subtraction
*/
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor Sub(const XTensor & a, const XTensor & b, DTYPE beta)
{
XTensor c(&a);
c.SetTMPFlag();
......@@ -125,7 +140,7 @@ tensor subtraction c = a - b * \beta
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
*/
void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
void Sub(const XTensor & a, const XTensor & b, XTensor & c, DTYPE beta)
{
if (!c.isInit || !IsSameShaped(a, c)) {
InitTensorV2(&c, &a);
......
......@@ -181,7 +181,7 @@ keep the result in the tensor a and return nothing
>> b - another tensor
>> beta - the scaling factor
*/
void SumMe(XTensor& a, const XTensor& b, DTYPE beta)
void SumMe(XTensor & a, const XTensor & b, DTYPE beta)
{
if (b.order == 0){
DTYPE shift = b.Get0D() * beta;
......@@ -190,17 +190,14 @@ void SumMe(XTensor& a, const XTensor& b, DTYPE beta)
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
if (n == -1)
/* call _Sum function */
_Sum(&a, &b, &a, beta);
}
else if (n >= 0 && n < a.order) {
else if (n >= 0 && n < a.order)
/* call _SumDim function */
_SumDim(&a, &b, &a, n, beta);
}
else {
else
ShowNTErrors("Something is wrong!");
}
}
}
......@@ -247,7 +244,7 @@ make a new tensor c to keep the result and return it
>> beta - the scaling factor
<< return - the result of tensor summation
*/
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor Sum(const XTensor & a, const XTensor & b, DTYPE beta)
{
XTensor c(&a);
c.SetTMPFlag();
......@@ -294,7 +291,7 @@ tensor summation c = a + b * \beta
>> b - another tensor
>> beta - the scaling factor
*/
void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
void Sum(const XTensor & a, const XTensor & b, XTensor & c, DTYPE beta)
{
if (!c.isInit || !IsSameShaped(a, c)) {
InitTensorV2(&c, &a);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论