Commit ca1f1843 by liyinqiao

Support scalar tensors & bug fixes.

1. Support scalar tensors for the Multiply and Div operations;
2. Fix backward bugs in the Multiply and Div functions;
3. Minor bug fixes.
parent 275a812a
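
What the change enables, as a minimal usage sketch (illustrative only: the initializer and setter names InitTensor2DV2, InitTensor0DV2 and Set0D are assumptions, while Div, Multiply and Get0D are confirmed by the diff below):

    XTensor a, s, c;
    InitTensor2DV2(&a, 2, 2, X_FLOAT);   /* ordinary 2x2 operand (hypothetical setup) */
    InitTensor0DV2(&s, X_FLOAT);         /* a 0-order (scalar) tensor (hypothetical init) */
    s.Set0D(4.0F);                       /* hypothetical setter, mirroring Get0D() */

    c = Div(a, s);                       /* now supported: c(i) = a(i) / 4.0 */
    c = Multiply(a, s);                  /* now supported: c(i) = a(i) * 4.0 */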
@@ -23,6 +23,8 @@
 #include "../../XName.h"
 #include "../../XUtility.h"
 #include "../shape/IsSameShaped.h"
+#include "Sum.h"
+#include "../math/ScaleAndShift.h"
 #include "Div.h"
 #include "Div.cuh"
 #include "DivDim.h"
@@ -127,7 +129,7 @@ void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
 element-wise division of two tensors (do it on site)
 keep the result in the input tensor a and return nothing
-a(i) = a(i)*b(i) + \alpha * a(i)
+a(i) = a(i)/b(i) + \alpha * a(i)
 where i is the index of the item
 >> a - tensor a (where keep the result)
@@ -157,40 +159,11 @@ void DivMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
     _Div(&a, &b, &a, alpha, leadingDim);
 }

-/*
-return a dimension if the division is performed as DivDim (in more details in DivDim.h)
->> a - a tensor
->> b - another tensor for division
-*/
-int GetDivDimIndex(const XTensor &a, const XTensor &b)
-{
-    if(a.order < b.order)
-        return -1;
-    if(IsSameShaped(a, b))
-        return -1;
-
-    int hitCount = 0;
-    int hitDim = -1;
-    for(int i = 0; i < b.order; i++){
-        if(b.dimSize[b.order - 1 - i] == 1)
-            continue;
-        else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
-            hitCount++;
-            hitDim = a.order - b.order + i;
-        }
-    }
-
-    if(hitCount == 1)
-        return hitDim;
-    else
-        return -1;
-}
-
 /*
 element-wise division of two tensors (return an XTensor structure)
 make a new tensor c to keep the result and return it
-c(i) = a(i)*b(i)
+c(i) = a(i)/b(i)
 where i is the index of the item
 >> a - tensor a
@@ -199,39 +172,43 @@ where i is the index of the item
 >> leadingDim - the dimension along which we perform broadcasting
 << return - the product of the tensors
 */
-XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
+XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
 {
     XTensor c(&a);
     c.SetTMPFlag();

-    int n = GetDivDimIndex(a, b);
+    if (b.order == 0){
+        DTYPE scale = 1.0F / b.Get0D();
+        ScaleAndShift(a, c, scale, 0.0F);
+    }
+    else {
+        DTYPE alpha = 0.0F;
+        int n = GetBroadcastDimIndex(a, b);

-    if(n == -1){
-        CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
+        if(n == -1){
+            CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");

-        /* call _Div function */
-        _Div(&a, &b, &c, alpha, leadingDim);
+            /* call _Div function */
+            _Div(&a, &b, &c, alpha, leadingDim);

-        /* tensor connections */
-        if (a.enableGrad && b.enableGrad) {
-            XLink::MakeLink(&a, &b, &c, MATH_DIV);
-            XLink::AddParamToHead(&c, alpha);
-            XLink::AddParamToHeadInt(&c, leadingDim);
-        }
-    }
-    else if(n >= 0 && n < a.order){
-        /* call _DivDim function */
-        _DivDim(&a, &b, &c, n, alpha);
+            /* tensor connections */
+            if (a.enableGrad && b.enableGrad) {
+                XLink::MakeLink(&a, &b, &c, MATH_DIV);
+            }
+        }
+        else if(n >= 0 && n < a.order){
+            /* call _DivDim function */
+            _DivDim(&a, &b, &c, n, alpha);

-        /* tensor connections */
-        if (a.enableGrad && b.enableGrad) {
-            XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
-            XLink::AddParamToHeadInt(&c, n);
-            XLink::AddParamToHead(&c, alpha);
-        }
-    }
-    else{
-        ShowNTErrors("Something is wrong!");
+            /* tensor connections */
+            if (a.enableGrad && b.enableGrad) {
+                XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
+                XLink::AddParamToHeadInt(&c, n);
+            }
+        }
+        else{
+            ShowNTErrors("Something is wrong!");
+        }
     }

     return c;
@@ -255,36 +232,46 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
         InitTensorV2(&c, &a);
     }

-    int n = GetDivDimIndex(a, b);
+    if (b.order == 0){
+        DTYPE scale = 1.0F / b.Get0D();
+        XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
+        XTensor * tmp2 = NewTensorBufV2(&c, c.devID, c.mem);

-    if (n == -1) {
-        CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
+        ScaleAndShift(a, *tmp1, scale, 0.0F);
+        ScaleAndShift(c, *tmp2, alpha, 0.0F);
+        Sum(*tmp2, *tmp1, c);

-        /* call _Div function */
-        _Div(&a, &b, &c, 0, leadingDim);
+        DelTensorBuf(tmp1);
+        DelTensorBuf(tmp2);
+    }
+    else {
+        int n = GetBroadcastDimIndex(a, b);

-        if (a.enableGrad && b.enableGrad) {
-            /* tensor connections */
-            XLink::MakeLink(&a, &b, &c, MATH_DIV);
-            XLink::AddParamToHead(&c, alpha);
-            XLink::AddParamToHeadInt(&c, leadingDim);
-        }
-    }
-    else if (n >= 0 && n < a.order) {
-        /* call _DivDim function */
-        _DivDim(&a, &b, &c, n, alpha);
+        if (n == -1) {
+            CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");

-        if (a.enableGrad && b.enableGrad) {
-            /* tensor connections */
-            XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
-            XLink::AddParamToHeadInt(&c, n);
-            XLink::AddParamToHead(&c, alpha);
-        }
-    }
-    else {
-        ShowNTErrors("Something is wrong!");
+            /* call _Div function */
+            _Div(&a, &b, &c, alpha, leadingDim);
+
+            if (a.enableGrad && b.enableGrad) {
+                /* tensor connections */
+                XLink::MakeLink(&a, &b, &c, MATH_DIV);
+            }
+        }
+        else if (n >= 0 && n < a.order) {
+            /* call _DivDim function */
+            _DivDim(&a, &b, &c, n, alpha);
+
+            if (a.enableGrad && b.enableGrad) {
+                /* tensor connections */
+                XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
+                XLink::AddParamToHeadInt(&c, n);
+            }
+        }
+        else {
+            ShowNTErrors("Something is wrong!");
+        }
     }
 }

 } // namespace nts(NiuTrans.Tensor)
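
A note on the scalar branch of the void overload above: since c is both an input (scaled by alpha) and the output, the update

    c(i) = a(i) / b.Get0D() + alpha * c(i)

cannot be done in place. It is therefore staged through two temporary buffers: tmp1 receives a scaled by 1/b, tmp2 receives alpha * c, and Sum(*tmp2, *tmp1, c) writes the combined result back into c before both buffers are released with DelTensorBuf.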
@@ -48,7 +48,7 @@ make a new tensor to keep the result and return it
 c(i) = a(i)/b(i)
 where i is the index of the element
 */
-XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
+XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);

 /*
 element-wise division of two tensors:
 ...
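
The signature change above means existing call sites of the functional form must drop the alpha argument. A migration sketch (the call sites are hypothetical; the void overload is the one from the diff above, which still accepts alpha):

    /* before: alpha came before leadingDim */
    XTensor c = Div(a, b, 0.0F, 1);

    /* after: the functional form takes leadingDim only */
    XTensor c = Div(a, b, 1);

    /* callers that need a nonzero alpha use the void overload instead:
       c(i) = a(i)/b(i) + alpha * c(i) */
    Div(a, b, c, alpha, 1);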
@@ -23,6 +23,8 @@
 #include "../../XName.h"
 #include "../../XUtility.h"
 #include "../shape/IsSameShaped.h"
+#include "Sum.h"
+#include "../math/ScaleAndShift.h"
 #include "Multiply.h"
 #include "Multiply.cuh"
 #include "MultiplyDim.h"
@@ -158,35 +160,6 @@ void MultiplyMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
     _Multiply(&a, &b, &a, alpha, leadingDim);
 }

-/*
-return a dimension if the multiplication is performed as MultiplyDim (in more details in MultiplyDim.h)
->> a - a tensor
->> b - another tensor for multiplication
-*/
-int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
-{
-    if(a.order < b.order)
-        return -1;
-    if(IsSameShaped(a, b))
-        return -1;
-
-    int hitCount = 0;
-    int hitDim = -1;
-    for(int i = 0; i < b.order; i++){
-        if(b.dimSize[b.order - 1 - i] == 1)
-            continue;
-        else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
-            hitCount++;
-            hitDim = a.order - b.order + i;
-        }
-    }
-
-    if(hitCount == 1)
-        return hitDim;
-    else
-        return -1;
-}
-
 /*
 element-wise product of two tensors (return an XTensor structure)
 make a new tensor c to keep the result and return it
@@ -199,40 +172,43 @@ where i is the index of the item
 >> leadingDim - the dimension along which we perform broadcasting
 << return - the product of the tensors
 */
-XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
+XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim)
 {
     XTensor c(&a);
     c.SetTMPFlag();

-    int n = GetMultiplyDimIndex(a, b);
+    if (b.order == 0){
+        DTYPE scale = b.Get0D();
+        ScaleAndShift(a, c, scale, 0.0F);
+    }
+    else {
+        DTYPE alpha = 0.0F;
+        int n = GetBroadcastDimIndex(a, b);

-    if(n == -1){
-        CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
+        if(n == -1){
+            CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");

-        /* call _Multiply function */
-        _Multiply(&a, &b, &c, 0, leadingDim);
+            /* call _Multiply function */
+            _Multiply(&a, &b, &c, alpha, leadingDim);

-        /* tensor connections */
-        if (a.enableGrad && b.enableGrad) {
-            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
-            XLink::AddParamToHead(&c, alpha);
-            XLink::AddParamToHeadInt(&c, leadingDim);
-        }
-    }
-    else if(n >= 0 && n < a.order){
-        /* call _MultiplyDim function */
-        _MultiplyDim(&a, &b, &c, n, alpha);
+            /* tensor connections */
+            if (a.enableGrad && b.enableGrad) {
+                XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
+            }
+        }
+        else if(n >= 0 && n < a.order){
+            /* call _MultiplyDim function */
+            _MultiplyDim(&a, &b, &c, n, alpha);

-        /* tensor connections */
-        if (a.enableGrad && b.enableGrad) {
-            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
-            XLink::AddParamToHeadInt(&c, n);
-            XLink::AddParamToHead(&c, alpha);
-        }
-    }
-    else{
-        ShowNTErrors("Something is wrong!");
+            /* tensor connections */
+            if (a.enableGrad && b.enableGrad) {
+                XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
+                XLink::AddParamToHeadInt(&c, n);
+            }
+        }
+        else{
+            ShowNTErrors("Something is wrong!");
+        }
     }

     return c;
@@ -256,36 +232,46 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
         InitTensorV2(&c, &a);
     }

-    int n = GetMultiplyDimIndex(a, b);
+    if (b.order == 0){
+        DTYPE scale = b.Get0D();
+        XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
+        XTensor * tmp2 = NewTensorBufV2(&c, c.devID, c.mem);

-    if (n == -1) {
-        CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
+        ScaleAndShift(a, *tmp1, scale, 0.0F);
+        ScaleAndShift(c, *tmp2, alpha, 0.0F);
+        Sum(*tmp2, *tmp1, c);

-        /* call _Multiply function */
-        _Multiply(&a, &b, &c, 0, leadingDim);
+        DelTensorBuf(tmp1);
+        DelTensorBuf(tmp2);
+    }
+    else {
+        int n = GetBroadcastDimIndex(a, b);

-        if (a.enableGrad && b.enableGrad) {
-            /* tensor connections */
-            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
-            XLink::AddParamToHead(&c, alpha);
-            XLink::AddParamToHeadInt(&c, leadingDim);
-        }
-    }
-    else if (n >= 0 && n < a.order) {
-        /* call _MultiplyDim function */
-        _MultiplyDim(&a, &b, &c, n, alpha);
+        if (n == -1) {
+            CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");

-        if (a.enableGrad && b.enableGrad) {
-            /* tensor connections */
-            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
-            XLink::AddParamToHeadInt(&c, n);
-            XLink::AddParamToHead(&c, alpha);
-        }
-    }
-    else {
-        ShowNTErrors("Something is wrong!");
+            /* call _Multiply function */
+            _Multiply(&a, &b, &c, alpha, leadingDim);
+
+            if (a.enableGrad && b.enableGrad) {
+                /* tensor connections */
+                XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
+            }
+        }
+        else if (n >= 0 && n < a.order) {
+            /* call _MultiplyDim function */
+            _MultiplyDim(&a, &b, &c, n, alpha);
+
+            if (a.enableGrad && b.enableGrad) {
+                /* tensor connections */
+                XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
+                XLink::AddParamToHeadInt(&c, n);
+            }
+        }
+        else {
+            ShowNTErrors("Something is wrong!");
+        }
     }
 }

 } // namespace nts(NiuTrans.Tensor)
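
The two helpers removed above, GetDivDimIndex and GetMultiplyDimIndex, had identical bodies; the patch replaces both call sites with a single shared GetBroadcastDimIndex. Its definition is not part of this diff, so the sketch below simply restates the removed logic under the assumed shared name (the actual helper may differ):

    int GetBroadcastDimIndex(const XTensor &a, const XTensor &b)
    {
        /* b cannot broadcast over a tensor of lower order,
           and same-shaped tensors need no broadcasting */
        if(a.order < b.order)
            return -1;
        if(IsSameShaped(a, b))
            return -1;

        int hitCount = 0;
        int hitDim = -1;
        for(int i = 0; i < b.order; i++){
            /* size-1 dimensions of b broadcast trivially */
            if(b.dimSize[b.order - 1 - i] == 1)
                continue;
            /* count dimensions of b matching the aligned dimension of a */
            else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
                hitCount++;
                hitDim = a.order - b.order + i;
            }
        }

        /* the *Dim kernels handle exactly one broadcast dimension */
        if(hitCount == 1)
            return hitDim;
        else
            return -1;
    }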
@@ -48,7 +48,7 @@ make a new tensor to keep the result and return it
 c(i) = a(i)*b(i)
 where i is the index of the element
 */
-XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
+XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim = 0);

 /*
 element-wise product of two tensors:
 ...
@@ -27,7 +27,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 /*
 case 1: element-wise product of two tensors
 c(i) = a(i)*b(i) + \alpha * c(i)
-In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=0, alpha=0.
+In this case, (2, 2) * (2, 2) -> (2, 2), leadingDim=0, alpha=0.
 */
 bool TestMultiply1()
 {
 ...