Commit 6ba61279 by xiaotong

add broadcasting

parent 947db8fb
......@@ -17,6 +17,8 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
* &Updated by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-12-26
* Add summation by broadcasting.
*/
#include "Sum.h"
......@@ -162,5 +164,48 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
return c;
}
/*
tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1
c = a + b * \beta
>> a - a tensor
>> b - another tensor that would be broadcasted
>> c - the resulting tensor
>> beta - the scaling factor
*/
void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0)
{
CheckNTErrors(a->order == b->order, "Wrong tensor orders!");
CheckNTErrors(a->order == c->order, "Wrong tensor orders!");
CheckNTErrors(a->order > 0, "TODO!");
int stride = 1;
int strideNum = 1;
int blockNum = 1;
int count = 0;
for(int i = a->order - 1; i >= 0; i--){
if(a->GetDim(i) == b->GetDim(i)){
stride *= b->GetDim(i);
continue;
}
else if(b->GetDim(i) == 1){
strideNum *= b->GetDim(i);
if(i > 0 && b->GetDim(i - 1) == 1)
continue;
}
else{
ShowNTErrors("Wrong dimension size in broadcasting");
}
if(stride > 0){
blockNum = b->unitNum / (stride * strideNum);
stride *= strideNum;
}
else{
}
}
}
}
......@@ -17,6 +17,8 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
* &Updated by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-12-26
* Add summation by broadcasting.
*/
#include "SumDim.cuh"
......@@ -28,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor summation of a tensor and a row vector
c = a + b * \beta
c = a + b * \beta
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
......
......@@ -17,6 +17,8 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
* &Updated by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-12-26
* Add summation by broadcasting.
*/
#ifndef __SUMDIM_CUH__
......
......@@ -18,6 +18,9 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
* It reached to 39 centigrade around 3:00 pm in Shenyang
* &Updated by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-12-26
* Add summation by broadcasting.
* Four of my master students graduated. Good luck to them for their future work!
*/
#ifndef __SUMDIM_H__
......@@ -38,6 +41,9 @@ void _SumDim(XTensor * a, const XTensor * b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1 */
void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论