Commit 2f4da0fa by liyinqiao

Support scalar.

Support scalar tensor for MulAndShift operation.
parent 51db4cfe
......@@ -27,36 +27,6 @@
#include "Sum.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
if (a.order < b.order)
return -1;
if (IsSameShaped(a, b))
return -1;
int hitCount = 0;
int hitDim = -1;
for (int i = 0; i < b.order; i++) {
if (b.dimSize[b.order - 1 - i] == 1)
continue;
else if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]) {
hitCount++;
hitDim = a.order - b.order + i;
}
}
if (hitCount == 1)
return hitDim;
else
return -1;
}
/*
operation c = x * w + b MulAndShift
>> x - tensor x
......@@ -99,31 +69,33 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (b.order == 0)
ScaleAndShift(*tmp, c, 1.0F, b.Get0D());
else {
int n = GetBroadcastDimIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
}
}
/* destroy variables */
......@@ -174,7 +146,7 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
int n = GetBroadcastDimIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论