Commit 14ec9fad by liyinqiao

Bug fixed.

Fix the bug in SumMe function which cannot handle the scalar tensor and broadcast case.
parent b0f2bbbf
...@@ -183,7 +183,25 @@ keep the result in the tensor a and return nothing ...@@ -183,7 +183,25 @@ keep the result in the tensor a and return nothing
*/ */
void SumMe(XTensor& a, const XTensor& b, DTYPE beta) void SumMe(XTensor& a, const XTensor& b, DTYPE beta)
{ {
if (b.order == 0){
DTYPE shift = b.Get0D() * beta;
_ScaleAndShift(&a, &a, 1.0F, shift);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
/* call _Sum function */
_Sum(&a, &b, &a, beta); _Sum(&a, &b, &a, beta);
}
else if (n >= 0 && n < a.order) {
/* call _SumDim function */
_SumDim(&a, &b, &a, n, beta);
}
else {
ShowNTErrors("Something is wrong!");
}
}
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论