Commit 32e78bda by xiaotong

the first version of _SumBroadcast

parent f784b909
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "Sum.h" #include "Sum.h"
#include "SumDim.h" #include "SumDim.h"
#include "SumDim.cuh" #include "SumDim.cuh"
#include "Unsqueeze.h" #include "../Shape/Unsqueeze.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../movement/CopyValues.h" #include "../movement/CopyValues.h"
...@@ -64,6 +64,20 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet ...@@ -64,6 +64,20 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet
return; return;
} }
/*int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < a->order; i++)
dims[i] = 1;
dims[n] = a->GetDim(n);
XTensor * b2 = NewTensor(a->order, dims, b->dataType, b->denseRatio, b->devID, b->mem);
_CopyValues(b, b2);
_SumBroadcast(a, b2, c, beta);
DelTensor(b2);
return;*/
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){ if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
_CudaSumDim(a, b, c, n, beta); _CudaSumDim(a, b, c, n, beta);
...@@ -192,7 +206,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta ...@@ -192,7 +206,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
continue; continue;
if(b->GetDim(i) == 1){ if(b->GetDim(i) == 1){
int fitSize = 1; int fitSize = a->GetDim(i);
int j = i + 1; int j = i + 1;
/* we define a range over dimensions. It is to be unsqueezed */ /* we define a range over dimensions. It is to be unsqueezed */
...@@ -210,12 +224,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta ...@@ -210,12 +224,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
dimsT[k] = a->GetDim(k); dimsT[k] = a->GetDim(k);
} }
dimsS[i] = 1;
dimsT[i] = fitSize; dimsT[i] = fitSize;
bool isLast = true; bool isLast = true;
for(int k = j; k < order; k++){ for(int k = j; k < order; k++){
dimsS[i + k - j + 1] = b->GetDim(k); dimsS[i + k - j + 0] = b->GetDim(k);
dimsT[i + k - j + 1] = b->GetDim(k); dimsT[i + k - j + 1] = b->GetDim(k);
if(a->GetDim(k) != b->GetDim(k)){ if(a->GetDim(k) != b->GetDim(k)){
if(b->GetDim(k) == 1) if(b->GetDim(k) == 1)
...@@ -229,11 +242,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta ...@@ -229,11 +242,11 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
dimsS[0] = -dimsS[0]; dimsS[0] = -dimsS[0];
dimsT[0] = -dimsT[0]; dimsT[0] = -dimsT[0];
XTensor * s = NewTensor(order - (j - i) + 1, dimsS, a->dataType, a->denseRatio, a->devID, a->mem); XTensor * s = NewTensor(order - (j - i), dimsS, a->dataType, a->denseRatio, a->devID, a->mem);
XTensor * t = NewTensor(order - (j - i) + 1, dimsT, b->dataType, b->denseRatio, b->devID, b->mem); XTensor * t = NewTensor(order - (j - i) + 1, dimsT, b->dataType, b->denseRatio, b->devID, b->mem);
if(count == 0) if(count == 0)
source = a->data; source = b->data;
else{ else{
source = target; source = target;
} }
...@@ -259,7 +272,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta ...@@ -259,7 +272,7 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
/* we do summation here */ /* we do summation here */
if(isLast){ if(isLast){
CheckNTErrors(t->unitNum == c->unitNum, "Wrong tensor size!"); CheckNTErrors(t->unitNum == c->unitNum, "Wrong tensor size!");
_Sum(t, a, c, beta); _Sum(a, t, c, beta);
if(t->mem != NULL) if(t->mem != NULL)
t->mem->ReleaseBuf(t->devID, t->unitNum * t->unitSize); t->mem->ReleaseBuf(t->devID, t->unitNum * t->unitSize);
else else
...@@ -276,6 +289,9 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta ...@@ -276,6 +289,9 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
count++; count++;
} }
} }
if(count == 0)
_Sum(a, b, c, beta);
CheckNTErrors(target == NULL, "Something is wrong!"); CheckNTErrors(target == NULL, "Something is wrong!");
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论