Commit eb325d83 by liyinqiao

Support scalar.

Support scalar tensor for Sub operation.
parent 3d6f1230
...@@ -16,16 +16,15 @@ ...@@ -16,16 +16,15 @@
*/ */
/* /*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01 * $Created by: Li Yinqiao (email: li.yin.qiao.2012@hotmail.com) 2020-02-11
*/ */
#include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h"
#include "../shape/IsSameShaped.h" #include "../shape/IsSameShaped.h"
#include "Sum.h"
#include "SumDim.h"
#include "../math/ScaleAndShift.h"
#include "Sub.h" #include "Sub.h"
#include "Sub.cuh"
#include "SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -39,80 +38,7 @@ tensor subtraction c = a - b * \beta ...@@ -39,80 +38,7 @@ tensor subtraction c = a - b * \beta
*/ */
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{ {
CheckNTErrors(a && b && c, "Empty tensor input!"); _Sum(a, b, c, -beta);
CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
CheckDev(a->devID, b->devID);
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
#ifdef CUDA_UVA
cudaDeviceCanAccessPeer(&P2PAccesible, a->devID, b->devID);
#endif
if ((a->devID < 0 && b->devID >= 0) ||
(a->devID >= 0 && b->devID < 0) ||
(a->devID >= 0 && b->devID >= 0 && a->devID != b->devID && !P2PAccesible))
{
ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
}
else
_CudaSub(a, b, c, beta);
}
else
_CudaSub(a, b, c, beta);
#endif
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
/* unrolling */
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] - bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] - bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] - bp[i] * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
} }
/* /*
...@@ -140,35 +66,6 @@ void SubMe(XTensor& a, const XTensor& b, DTYPE beta) ...@@ -140,35 +66,6 @@ void SubMe(XTensor& a, const XTensor& b, DTYPE beta)
{ {
_Sub(&a, &b, &a, beta); _Sub(&a, &b, &a, beta);
} }
/*
return a dimension if the subtraction is performed as SubDim (in more details in SubDim.h)
>> a - a tensor
>> b - another tensor for subtraction
*/
int GetSubDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
if(IsSameShaped(a, b))
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/* /*
tensor subtraction c = a - b * \beta (return an XTensor structure) tensor subtraction c = a - b * \beta (return an XTensor structure)
...@@ -184,33 +81,38 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -184,33 +81,38 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
int n = GetSubDimIndex(a, b); if (b.order == 0){
DTYPE shift = -(b.Get0D() * beta);
ScaleAndShift(a, c, 1.0F, shift);
}
else {
int n = GetBroadcastDimIndex(a, b);
if(n == -1){
/* call _Sub function */
_Sub(&a, &b, &c, beta);
if(n == -1){ /* tensor connections */
/* call _Sub function */ if (a.enableGrad && b.enableGrad) {
_Sub(&a, &b, &c, beta); XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
/* tensor connections */ }
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
} }
} else if(n >= 0 && n < a.order){
else if(n >= 0 && n < a.order){ /* call _SumDim function to do the SubDim operation */
/* call _SubDim function */ _SumDim(&a, &b, &c, n, -beta);
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
/* tensor connections */ if (a.enableGrad && b.enableGrad) {
if (a.enableGrad && b.enableGrad) { XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHead(&c, beta);
XLink::AddParamToHead(&c, beta); }
}
else{
ShowNTErrors("Something is wrong!");
} }
} }
else{
ShowNTErrors("Something is wrong!");
}
return c; return c;
} }
...@@ -228,31 +130,37 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -228,31 +130,37 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
InitTensorV2(&c, &a); InitTensorV2(&c, &a);
} }
int n = GetSubDimIndex(a, b); if (b.order == 0){
DTYPE shift = -(b.Get0D() * beta);
ScaleAndShift(a, c, 1.0F, shift);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) { if (n == -1) {
/* call _Sub function */ /* call _Sub function */
_Sub(&a, &b, &c, beta); _Sub(&a, &b, &c, beta);
if (a.enableGrad && b.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB); XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
}
} }
} else if (n >= 0 && n < a.order) {
else if (n >= 0 && n < a.order) { /* call _SumDim function to do the SubDim operation */
/* call _SubDim function */ _SumDim(&a, &b, &c, n, -beta);
_SubDim(&a, &b, &c, n, beta);
if (a.enableGrad && b.enableGrad) {
if (a.enableGrad && b.enableGrad) { /* tensor connections */
/* tensor connections */ XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHead(&c, beta);
XLink::AddParamToHead(&c, beta); }
}
else {
ShowNTErrors("Something is wrong!");
} }
}
else {
ShowNTErrors("Something is wrong!");
} }
} }
......
...@@ -177,11 +177,11 @@ void SumMe(XTensor& a, const XTensor& b, DTYPE beta) ...@@ -177,11 +177,11 @@ void SumMe(XTensor& a, const XTensor& b, DTYPE beta)
} }
/* /*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h) return a dimension if the operation is performed as broadcast(e.g. SumDim function)
>> a - a tensor >> a - a tensor
>> b - another tensor for sum >> b - another tensor for operation
*/ */
int GetSumDimIndex(const XTensor &a, const XTensor &b) int GetBroadcastDimIndex(const XTensor & a, const XTensor & b)
{ {
if(a.order < b.order) if(a.order < b.order)
return -1; return -1;
...@@ -229,7 +229,7 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -229,7 +229,7 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
ScaleAndShift(a, c, 1.0F, shift); ScaleAndShift(a, c, 1.0F, shift);
} }
else { else {
int n = GetSumDimIndex(a, b); int n = GetBroadcastDimIndex(a, b);
if(n == -1){ if(n == -1){
/* call _Sum function */ /* call _Sum function */
...@@ -277,7 +277,7 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -277,7 +277,7 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
ScaleAndShift(a, c, 1.0F, shift); ScaleAndShift(a, c, 1.0F, shift);
} }
else { else {
int n = GetSumDimIndex(a, b); int n = GetBroadcastDimIndex(a, b);
if (n == -1) { if (n == -1) {
/* call _Sum function */ /* call _Sum function */
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* return a dimension if the operation is performed as broadcast(e.g. SumDim function) */
int GetBroadcastDimIndex(const XTensor & a, const XTensor & b);
/* tensor summation c = a + b * \beta */ /* tensor summation c = a + b * \beta */
void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0); void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论