Commit ca1f1843 by liyinqiao

Support scalar tensors & bug fixes.

1. Support scalar tensor for Multiply and Div operation;
2. Fix backward bugs in Multiply and Div function;
3. Minor bugs fixed.
parent 275a812a
......@@ -23,6 +23,8 @@
#include "../../XName.h"
#include "../../XUtility.h"
#include "../shape/IsSameShaped.h"
#include "Sum.h"
#include "../math/ScaleAndShift.h"
#include "Div.h"
#include "Div.cuh"
#include "DivDim.h"
......@@ -127,7 +129,7 @@ void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int le
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
......@@ -158,39 +160,10 @@ void DivMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
}
/*
return the dimension index of a along which the division can be performed
as DivDim (see more details in DivDim.h), or -1 if DivDim does not apply

b is aligned to the trailing dimensions of a; DivDim applies when exactly
one non-1 dimension of b matches the corresponding dimension of a

>> a - a tensor
>> b - another tensor for division
<< return - the matching dimension index of a, or -1 (element-wise case,
            shape mismatch, or more than one matching dimension)
*/
int GetDivDimIndex(const XTensor &a, const XTensor &b)
{
    /* b cannot broadcast over a when it has more dimensions */
    if(a.order < b.order)
        return -1;

    /* identical shapes mean plain element-wise division, not DivDim */
    if(IsSameShaped(a, b))
        return -1;

    int hitCount = 0;
    int hitDim = -1;
    for(int i = 0; i < b.order; i++){
        if(b.dimSize[b.order - 1 - i] == 1)
            continue;
        else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
            hitCount++;

            /* the dimension of a compared above is a.order - 1 - i
               (trailing alignment). The previous formula
               a.order - b.order + i indexed b from the front while the
               comparison indexed it from the back, returning a wrong
               dimension whenever b.order > 1. */
            hitDim = a.order - 1 - i;
        }
    }

    /* DivDim applies only when exactly one dimension of b hits a */
    if(hitCount == 1)
        return hitDim;
    else
        return -1;
}
/*
element-wise division of two tensors (return an XTensor structure)
make a new tensor c to keep the result and return it
c(i) = a(i)*b(i)
c(i) = a(i)/b(i)
where i is the index of the item
>> a - tensor a
......@@ -199,12 +172,18 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
{
XTensor c(&a);
c.SetTMPFlag();
int n = GetDivDimIndex(a, b);
if (b.order == 0){
DTYPE scale = 1.0F / b.Get0D();
ScaleAndShift(a, c, scale, 0.0F);
}
else {
DTYPE alpha = 0.0F;
int n = GetBroadcastDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
......@@ -215,8 +194,6 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
/* tensor connections */
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if(n >= 0 && n < a.order){
......@@ -227,12 +204,12 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
return c;
}
......@@ -255,19 +232,30 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
InitTensorV2(&c, &a);
}
int n = GetDivDimIndex(a, b);
if (b.order == 0){
DTYPE scale = 1.0F / b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&c, c.devID, c.mem);
ScaleAndShift(a, *tmp1, scale, 0.0F);
ScaleAndShift(c, *tmp2, alpha, 0.0F);
Sum(*tmp2, *tmp1, c);
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &c, 0, leadingDim);
_Div(&a, &b, &c, alpha, leadingDim);
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if (n >= 0 && n < a.order) {
......@@ -278,13 +266,12 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -48,7 +48,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
/*
element-wise division of two tensors:
......
......@@ -23,6 +23,8 @@
#include "../../XName.h"
#include "../../XUtility.h"
#include "../shape/IsSameShaped.h"
#include "Sum.h"
#include "../math/ScaleAndShift.h"
#include "Multiply.h"
#include "Multiply.cuh"
#include "MultiplyDim.h"
......@@ -159,35 +161,6 @@ void MultiplyMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
}
/*
return the dimension index of a along which the multiplication can be
performed as MultiplyDim (see more details in MultiplyDim.h), or -1 if
MultiplyDim does not apply

b is aligned to the trailing dimensions of a; MultiplyDim applies when
exactly one non-1 dimension of b matches the corresponding dimension of a

>> a - a tensor
>> b - another tensor for multiplication
<< return - the matching dimension index of a, or -1 (element-wise case,
            shape mismatch, or more than one matching dimension)
*/
int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
{
    /* b cannot broadcast over a when it has more dimensions */
    if(a.order < b.order)
        return -1;

    /* identical shapes mean plain element-wise product, not MultiplyDim */
    if(IsSameShaped(a, b))
        return -1;

    int hitCount = 0;
    int hitDim = -1;
    for(int i = 0; i < b.order; i++){
        if(b.dimSize[b.order - 1 - i] == 1)
            continue;
        else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
            hitCount++;

            /* the dimension of a compared above is a.order - 1 - i
               (trailing alignment). The previous formula
               a.order - b.order + i indexed b from the front while the
               comparison indexed it from the back, returning a wrong
               dimension whenever b.order > 1. */
            hitDim = a.order - 1 - i;
        }
    }

    /* MultiplyDim applies only when exactly one dimension of b hits a */
    if(hitCount == 1)
        return hitDim;
    else
        return -1;
}
/*
element-wise product of two tensors (return an XTensor structure)
make a new tensor c to keep the result and return it
......@@ -199,25 +172,28 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim)
{
XTensor c(&a);
c.SetTMPFlag();
int n = GetMultiplyDimIndex(a, b);
if (b.order == 0){
DTYPE scale = b.Get0D();
ScaleAndShift(a, c, scale, 0.0F);
}
else {
DTYPE alpha = 0.0F;
int n = GetBroadcastDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
_Multiply(&a, &b, &c, alpha, leadingDim);
/* tensor connections */
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if(n >= 0 && n < a.order){
......@@ -228,12 +204,12 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
return c;
}
......@@ -256,19 +232,30 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
InitTensorV2(&c, &a);
}
int n = GetMultiplyDimIndex(a, b);
if (b.order == 0){
DTYPE scale = b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&c, c.devID, c.mem);
ScaleAndShift(a, *tmp1, scale, 0.0F);
ScaleAndShift(c, *tmp2, alpha, 0.0F);
Sum(*tmp2, *tmp1, c);
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
_Multiply(&a, &b, &c, alpha, leadingDim);
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if (n >= 0 && n < a.order) {
......@@ -279,13 +266,12 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -48,7 +48,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)*b(i)
where i is the index of the element
*/
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim = 0);
/*
element-wise product of two tensors:
......
......@@ -27,7 +27,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=0, alpha=0.
In this case, (2, 2) * (2, 2) -> (2, 2), leadingDim=0, alpha=0.
*/
bool TestMultiply1()
{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论