Commit 32e78bda by xiaotong

Add the first version of _SumBroadcast

parent f784b909
......@@ -24,7 +24,7 @@
#include "Sum.h"
#include "SumDim.h"
#include "SumDim.cuh"
#include "Unsqueeze.h"
#include "../Shape/Unsqueeze.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
......@@ -64,6 +64,20 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet
return;
}
/*int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < a->order; i++)
dims[i] = 1;
dims[n] = a->GetDim(n);
XTensor * b2 = NewTensor(a->order, dims, b->dataType, b->denseRatio, b->devID, b->mem);
_CopyValues(b, b2);
_SumBroadcast(a, b2, c, beta);
DelTensor(b2);
return;*/
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaSumDim(a, b, c, n, beta);
......@@ -192,7 +206,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
continue;
if(b->GetDim(i) == 1){
int fitSize = 1;
int fitSize = a->GetDim(i);
int j = i + 1;
/* we define a range over dimensions. It is to be unsqueezed */
......@@ -210,12 +224,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
dimsT[k] = a->GetDim(k);
}
dimsS[i] = 1;
dimsT[i] = fitSize;
bool isLast = true;
for(int k = j; k < order; k++){
dimsS[i + k - j + 1] = b->GetDim(k);
dimsS[i + k - j + 0] = b->GetDim(k);
dimsT[i + k - j + 1] = b->GetDim(k);
if(a->GetDim(k) != b->GetDim(k)){
if(b->GetDim(k) == 1)
......@@ -229,11 +242,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
dimsS[0] = -dimsS[0];
dimsT[0] = -dimsT[0];
XTensor * s = NewTensor(order - (j - i) + 1, dimsS, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * s = NewTensor(order - (j - i), dimsS, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * t = NewTensor(order - (j - i) + 1, dimsT, b->dataType, b->denseRatio, b->devID, b->mem);
if(count == 0)
source = a->data;
source = b->data;
else{
source = target;
}
......@@ -259,7 +272,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
/* we do summation here */
if(isLast){
CheckNTErrors(t->unitNum == c->unitNum, "Wrong tensor size!");
_Sum(t, a, c, beta);
_Sum(a, t, c, beta);
if(t->mem != NULL)
t->mem->ReleaseBuf(t->devID, t->unitNum * t->unitSize);
else
......@@ -276,6 +289,9 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
count++;
}
}
if(count == 0)
_Sum(a, b, c, beta);
CheckNTErrors(target == NULL, "Something is wrong!");
}
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论