Commit f784b909 by xiaotong

implement sum broadcasting by unsqueezing

parent 6ba61279
......@@ -24,7 +24,9 @@
#include "Sum.h"
#include "SumDim.h"
#include "SumDim.cuh"
#include "Unsqueeze.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -174,38 +176,108 @@ c = a + b * \beta
>> c - the resulting tensor
>> beta - the scaling factor
*/
void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0)
void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a->order == b->order, "Wrong tensor orders!");
CheckNTErrors(a->order == c->order, "Wrong tensor orders!");
CheckNTErrors(a->order > 0, "TODO!");
int stride = 1;
int strideNum = 1;
int blockNum = 1;
int order = a->order;
int count = 0;
void * source = 0;
void * target = 0;
for(int i = a->order - 1; i >= 0; i--){
if(a->GetDim(i) == b->GetDim(i)){
stride *= b->GetDim(i);
for(int i = 0; i < order; i++){
if(a->GetDim(i) == b->GetDim(i))
continue;
if(b->GetDim(i) == 1){
int fitSize = 1;
int j = i + 1;
/* we define a range over dimensions. It is to be unsqueezed */
for(; j < order; j++){
if(a->GetDim(j) == b->GetDim(j))
break;
fitSize *= a->GetDim(j);
}
else if(b->GetDim(i) == 1){
strideNum *= b->GetDim(i);
if(i > 0 && b->GetDim(i - 1) == 1)
continue;
int dimsS[MAX_TENSOR_DIM_NUM];
int dimsT[MAX_TENSOR_DIM_NUM];
for(int k = 0; k < i; k++){
dimsS[k] = a->GetDim(k);
dimsT[k] = a->GetDim(k);
}
dimsS[i] = 1;
dimsT[i] = fitSize;
bool isLast = true;
for(int k = j; k < order; k++){
dimsS[i + k - j + 1] = b->GetDim(k);
dimsT[i + k - j + 1] = b->GetDim(k);
if(a->GetDim(k) != b->GetDim(k)){
if(b->GetDim(k) == 1)
isLast = false;
else{
ShowNTErrors("Wrong dimension size in broadcasting");
ShowNTErrors("Wrong dimension size!")
}
}
if(stride > 0){
blockNum = b->unitNum / (stride * strideNum);
stride *= strideNum;
}
dimsS[0] = -dimsS[0];
dimsT[0] = -dimsT[0];
XTensor * s = NewTensor(order - (j - i) + 1, dimsS, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * t = NewTensor(order - (j - i) + 1, dimsT, b->dataType, b->denseRatio, b->devID, b->mem);
if(count == 0)
source = a->data;
else{
source = target;
}
target = t->mem != NULL ?
t->mem->AllocBuf(t->devID, t->unitNum * t->unitSize):
XMemAlloc(t->devID, t->unitNum * t->unitSize);
s->data = source;
t->data = target;
_Unsqueeze(s, t, i, fitSize);
/* free the memory space of the one before the last allocation */
if(count > 0){
int size = s->unitNum * s->unitSize;
if(t->mem != NULL)
t->mem->ReleaseBuf(t->devID, size);
else
XMemFree(t->devID, source);
}
/* we do summation here */
if(isLast){
CheckNTErrors(t->unitNum == c->unitNum, "Wrong tensor size!");
_Sum(t, a, c, beta);
if(t->mem != NULL)
t->mem->ReleaseBuf(t->devID, t->unitNum * t->unitSize);
else
XMemFree(t->devID, target);
target = NULL;
}
s->data = NULL;
t->data = NULL;
DelTensor(s);
DelTensor(t);
i = j;
count++;
}
}
CheckNTErrors(target == NULL, "Something is wrong!");
}
}
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment