Commit 2f4da0fa by liyinqiao

Support scalar.

Support scalar tensor for MulAndShift operation.
parent 51db4cfe
...@@ -27,36 +27,6 @@ ...@@ -27,36 +27,6 @@
#include "Sum.h" #include "Sum.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
if (a.order < b.order)
return -1;
if (IsSameShaped(a, b))
return -1;
int hitCount = 0;
int hitDim = -1;
for (int i = 0; i < b.order; i++) {
if (b.dimSize[b.order - 1 - i] == 1)
continue;
else if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]) {
hitCount++;
hitDim = a.order - b.order + i;
}
}
if (hitCount == 1)
return hitDim;
else
return -1;
}
/* /*
operation c = x * w + b MulAndShift operation c = x * w + b MulAndShift
>> x - tensor x >> x - tensor x
...@@ -99,7 +69,10 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, ...@@ -99,7 +69,10 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
XTensor c(tmp); XTensor c(tmp);
c.SetTMPFlag(); c.SetTMPFlag();
int n = GetSumIndex(tmp, b); if (b.order == 0)
ScaleAndShift(*tmp, c, 1.0F, b.Get0D());
else {
int n = GetBroadcastDimIndex(tmp, b);
if (n == -1) { if (n == -1) {
/* call _Sum function */ /* call _Sum function */
...@@ -112,12 +85,10 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, ...@@ -112,12 +85,10 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
else if (n >= 0 && n < tmp->order) { else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */ /* call _SumDim function */
_SumDim(tmp, &b, &c, n); _SumDim(tmp, &b, &c, n);
} }
else { else {
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
} }
/* tensor connections */ /* tensor connections */
if (w.enableGrad && b.enableGrad) { if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT); XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
...@@ -125,6 +96,7 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, ...@@ -125,6 +96,7 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
} }
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -174,7 +146,7 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA, ...@@ -174,7 +146,7 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
XTensor c(tmp); XTensor c(tmp);
c.SetTMPFlag(); c.SetTMPFlag();
int n = GetSumIndex(tmp, b); int n = GetBroadcastDimIndex(tmp, b);
if (n == -1) { if (n == -1) {
/* call _Sum function */ /* call _Sum function */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论