Commit a26caf40 by xiaotong

better code of SumDim

parent daa2f801
...@@ -41,8 +41,8 @@ int main( int argc, const char ** argv ) ...@@ -41,8 +41,8 @@ int main( int argc, const char ** argv )
//TransposeTest(); //TransposeTest();
//return 0; //return 0;
SumDimTest(); //SumDimTest();
return 0; //return 0;
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
1;//Test(); 1;//Test();
......
...@@ -37,6 +37,8 @@ void XMathGrad::MakeGrad(XTensor * node) ...@@ -37,6 +37,8 @@ void XMathGrad::MakeGrad(XTensor * node)
if(operID == MATH_SUM) if(operID == MATH_SUM)
GradSum(node); GradSum(node);
else if(operID == MATH_SUMDIM)
GradSumDim(node);
else if(operID == MATH_MULTIPLY) else if(operID == MATH_MULTIPLY)
GradMultiply(node); GradMultiply(node);
else if(operID == MATH_MATRIXMUL) else if(operID == MATH_MATRIXMUL)
...@@ -80,6 +82,90 @@ void XMathGrad::GradSum(XTensor * node) ...@@ -80,6 +82,90 @@ void XMathGrad::GradSum(XTensor * node)
} }
/* /*
gradient for sum with one dimension
c = a + b * \beta
where the size of b is equal to dimension n of a, i.e., |b| = a.dimSize[n]
dE/da = dE/dc
dE/db = dE/dc * b.reduce(0,...,n-1,n+1,...) * \beta
*/
void XMathGrad::GradSumDim(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for SUM!");
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
DTYPE beta = income.GetParam(1);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
_Sum(a->grad, node->grad, a->grad);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
if(n == order - 1){
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum/dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
}
else
_ReduceSum(node->grad, b->grad, 0);
node->grad->Reshape(order, dimSize);
}
else{
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
reshapedSize[0] *= dimSize[i];
}
reshapedSize[2] = a->unitNum / (reshapedSize[0] * reshapedSize[1]);
/* we reshape dE/dc to a 3D tensor of size (x, y, z) where y = |b|.
Then reduce along with z and x to obtain dE/db. */
node->grad->Reshape(3, reshapedSize);
XTensor * interGrad = NewTensorBuf(2, reshapedSize, b->devID, b->mem, b->dataType, b->denseRatio);
_ReduceSum(node->grad, interGrad, 2);
if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
}
else
_ReduceSum(interGrad, b->grad, 0);
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
node->visitMark = NODE_FINISHED;
}
/*
gradient for multiply (dot production) gradient for multiply (dot production)
for for
c = a * b c = a * b
......
...@@ -44,6 +44,11 @@ private: ...@@ -44,6 +44,11 @@ private:
static static
void GradSum(XTensor * node); void GradSum(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot production): c = a * b */ /* gradient for multiply (dot production): c = a * b */
static static
void GradMultiply(XTensor * node); void GradMultiply(XTensor * node);
......
...@@ -999,15 +999,11 @@ void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model) ...@@ -999,15 +999,11 @@ void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model)
hidden = Merge(hidden, 2, 0); hidden = Merge(hidden, 2, 0);
/* hidden layers */ /* hidden layers */
for(int i = 0; i < depth; i++){ for(int i = 0; i < depth; i++)
b = Unsqueeze(model.hiddenB[i], 0, batchSize); hidden = MMul(hidden, model.hiddenW[i]) + model.hiddenB[i];
hidden = MMul(hidden, model.hiddenW[i]) + b;
}
b = Unsqueeze(model.outputB, 0, batchSize);
/* output layer */ /* output layer */
output = LogSoftmax(MMul(hidden, model.outputW) + b, 1); output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
//XLink::ShowNetwork(stderr, &output); //XLink::ShowNetwork(stderr, &output);
} }
......
...@@ -41,6 +41,8 @@ const char * GetOPName(int type) ...@@ -41,6 +41,8 @@ const char * GetOPName(int type)
return "M_SIGN"; return "M_SIGN";
else if (type == MATH_SUM) else if (type == MATH_SUM)
return "M_SUM"; return "M_SUM";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == MATH_LOG) else if (type == MATH_LOG)
return "M_LOG"; return "M_LOG";
else if (type == MATH_NORMALIZE) else if (type == MATH_NORMALIZE)
......
...@@ -1885,12 +1885,13 @@ generate a XTensor which allocates data on the buffer ...@@ -1885,12 +1885,13 @@ generate a XTensor which allocates data on the buffer
>> myDimSize - the size of each dimension >> myDimSize - the size of each dimension
>> myMem - memory pool used to allocating the data array. >> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with we actually allocate the data on the buffer associated with
the memory pool. the memory pool
>> devID - device id
>> myDataType - unit size (e.g., int, float, and double) >> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value >> myDenseRatio - how often an element has non-zero value
*/ */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem, XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, int devID, XMem * myMem,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio) const TENSOR_DATA_TYPE myDataType, const float myDenseRatio)
{ {
CheckNTErrors(myMem != NULL, "No memory pool specified!"); CheckNTErrors(myMem != NULL, "No memory pool specified!");
...@@ -1901,12 +1902,31 @@ XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem, ...@@ -1901,12 +1902,31 @@ XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
dims[0] = -abs(dims[0]); dims[0] = -abs(dims[0]);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, -1, myMem); XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, -1, myMem);
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
if(myMem != NULL)
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
else
tensor->data = XMemAlloc(devID, tensor->unitNum * tensor->unitSize);
return tensor; return tensor;
} }
/* /*
generate a XTensor which allocates data on the buffer
>> reference - reference tensor
>> devID - device id
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool
*/
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem)
{
return NewTensorBuf(reference->order, reference->dimSize,
devID, myMem,
reference->dataType, reference->denseRatio);
}
/*
generate a dense vector generate a dense vector
>> num - number of entries >> num - number of entries
>> myDataType - unit size (e.g., int, float, and double) >> myDataType - unit size (e.g., int, float, and double)
...@@ -2056,7 +2076,7 @@ XTensor * NewTensor(XTensor * a, bool isFilledData) ...@@ -2056,7 +2076,7 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
free the data space of a given tensor free the data space of a given tensor
>> tensor - pointer to the tensor >> tensor - pointer to the tensor
*/ */
void DelTensor(const XTensor * tensor) void DelTensor(XTensor * tensor)
{ {
delete tensor; delete tensor;
} }
...@@ -2065,10 +2085,13 @@ void DelTensor(const XTensor * tensor) ...@@ -2065,10 +2085,13 @@ void DelTensor(const XTensor * tensor)
free the data space of a given tensor (on the buffer) free the data space of a given tensor (on the buffer)
>> tensor - pointer to the tensor >> tensor - pointer to the tensor
*/ */
void DelTensorBuf(const XTensor * tensor) void DelTensorBuf(XTensor * tensor)
{ {
CheckNTErrors(tensor->mem != NULL, "No memory pool found!"); if(tensor->mem != NULL)
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize); tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
else
XMemFree(tensor->devID, tensor->data);
tensor->data = NULL;
delete tensor; delete tensor;
} }
......
...@@ -391,9 +391,12 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_ ...@@ -391,9 +391,12 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL); const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */ /* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem, XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, int devID, XMem * myMem,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F); const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem);
/* generate a dense vector */ /* generate a dense vector */
XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1,
XMem * myMem = NULL); XMem * myMem = NULL);
...@@ -422,10 +425,10 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co ...@@ -422,10 +425,10 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
XTensor * NewTensor(XTensor * a, bool isFilledData = true); XTensor * NewTensor(XTensor * a, bool isFilledData = true);
/* free the data space of a given tensor */ /* free the data space of a given tensor */
void DelTensor(const XTensor * tensor); void DelTensor(XTensor * tensor);
/* free the data space of a given tensor (on the buffer) */ /* free the data space of a given tensor (on the buffer) */
void DelTensorBuf(const XTensor * tensor); void DelTensorBuf(XTensor * tensor);
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "../../XUtility.h" #include "../../XUtility.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
#include "SumDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -123,6 +124,33 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta) ...@@ -123,6 +124,33 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
{ {
_Sum(a, b, a, beta); _Sum(a, b, a, beta);
} }
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/* /*
tensor summation c = a + b * \beta (return a XTensor structure) tensor summation c = a + b * \beta (return a XTensor structure)
...@@ -138,12 +166,28 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -138,12 +166,28 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor c(&a); XTensor c(&a);
c.SetTMP(); c.SetTMP();
/* call _Sum function */ int n = GetSumDimIndex(a, b);
_Sum(&a, &b, &c, beta);
if(n == -1){
/* call _Sum function */
_Sum(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM); XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _Sum function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
return c; return c;
} }
......
...@@ -151,7 +151,7 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta) ...@@ -151,7 +151,7 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
c.SetTMP(); c.SetTMP();
/* call _Sum function */ /* call _Sum function */
_Sum(&a, &b, &c, beta); _SumDim(&a, &b, &c, n, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM); XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
......
...@@ -49,7 +49,7 @@ void KernelAddWithRow(T * a, T * b, T * c, int rowNum, int colNum, T beta) ...@@ -49,7 +49,7 @@ void KernelAddWithRow(T * a, T * b, T * c, int rowNum, int colNum, T beta)
return; return;
if(threadIdx.y == 0) if(threadIdx.y == 0)
bv[threadIdx.x] = b[col]; bv[threadIdx.x] = b[col];
__syncthreads(); __syncthreads();
...@@ -89,7 +89,7 @@ void KernelAddWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize ...@@ -89,7 +89,7 @@ void KernelAddWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize
return; return;
if(threadIdx.x == 0) if(threadIdx.x == 0)
bv[threadIdx.y] = b[row]; bv[threadIdx.y] = b[row];
__syncthreads(); __syncthreads();
...@@ -138,6 +138,9 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE ...@@ -138,6 +138,9 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
int cudaGrids[3]; int cudaGrids[3];
int cudaBlocks[3]; int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE){ if (a->dataType == DEFAULT_DTYPE){
if(stride > 1){ if(stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks); GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
...@@ -168,6 +171,8 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE ...@@ -168,6 +171,8 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
else { else {
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
} }
BacktoCudaDev(a->devID, devIDBackup);
} }
#endif #endif
......
...@@ -37,8 +37,8 @@ copy indexed sub-tensors ...@@ -37,8 +37,8 @@ copy indexed sub-tensors
>> indexSize - length of srcIndex (and tgtIndex) >> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors >> tgtIndex - index of the target sub-tensors
>> copyNum - number of the sub-tensors we copy for each source index, >> copyNum - number of the sub-tensors we copy for each source index,
e.g., for srcIndex = [1,4] and copyNum = 2, e.g., for srcIndex = [1,4] and copyNum = 2,
we actually copy the source sub-tensors 1, 2, 4, 5 we actually copy the source sub-tensors 1, 2, 4, 5
*/ */
void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum) void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
{ {
...@@ -73,17 +73,23 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int i ...@@ -73,17 +73,23 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int i
int * realSrcIndex = new int[realIndexSize]; int * realSrcIndex = new int[realIndexSize];
int * realTgtIndex = new int[realIndexSize]; int * realTgtIndex = new int[realIndexSize];
for (int i = 0; i < indexOffsetNum; i++) { for (int i = 0; i < indexOffsetNum; i++) {
int base = i * indexSize * copyNum;
int baseSrc = i * leadDimSizeSrc;
int baseTgt = i * leadDimSizeTgt;
for (int j = 0; j < indexSize; j++) { for (int j = 0; j < indexSize; j++) {
int offset = base + j * copyNum;
int * rsi = realSrcIndex + offset;
int * rti = realTgtIndex + offset;
for (int k = 0; k < copyNum; k++) { for (int k = 0; k < copyNum; k++) {
realSrcIndex[i * indexSize * copyNum + j * copyNum + k] = i * leadDimSizeSrc + srcIndex[j] + k; rsi[k] = baseSrc + srcIndex[j] + k;
realTgtIndex[i * indexSize * copyNum + j * copyNum + k] = i * leadDimSizeTgt + tgtIndex[j] + k; rti[k] = baseTgt + tgtIndex[j] + k;
} }
} }
} }
for (int i = 0; i < indexSize; i++) { for (int i = 0; i < indexSize; i++) {
CheckNTErrors((srcIndex[i] < blockNumSrc), "Index is out of range!"); CheckNTErrors((srcIndex[i] < blockNumSrc), "Index is out of scope!");
CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of range!"); CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of scope!");
} }
_CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID); _CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);
......
...@@ -32,12 +32,108 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -32,12 +32,108 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension) insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array >> s - pointer to the source data array
>> blockSize - size of a block >> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlat(void * s, int blockSize, int totalSize, void * t, int n)
{
/* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= blockSize)
return;
T value = ((T*)s)[i];
T * tData = (T*)t;
__syncthreads();
for (int k = i; k < totalSize; k += blockSize)
tData[k] = value;
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlatBigram(void * s, int blockSize, int totalSize, void * t, int n)
{
/* index of data items */
int i = (blockDim.x * blockIdx.x + threadIdx.x) * 2;
if (i >= blockSize)
return;
T value = ((T*)s)[i];
T value2 = ((T*)s)[i + 1];
T * tData = (T*)t;
__syncthreads();
for (int k = i; k < totalSize; k += blockSize){
tData[k] = value;
tData[k + 1] = value2;
}
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlat2D(void * s, int blockSize, int totalSize, void * t, int n)
{
__shared__ T data[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ int offsets[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* index of data items */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= blockSize || j >= n)
return;
if(threadIdx.y == 0)
data[threadIdx.x] = ((T*)s)[i];
if(threadIdx.x == 0)
offsets[threadIdx.y] = blockSize * j;
__syncthreads();
((T*)t)[offsets[threadIdx.y] + i] = data[threadIdx.x];
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> blockNum - number of the blocks >> blockNum - number of the blocks
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array >> t - pointer to the target data array
>> n - number of blocks to copy data
*/ */
template<class T> template<class T>
__global__ __global__
void KernelUnsqueeze(void * s, int blockSize, int blockNum, void * t, int n) void KernelUnsqueeze(void * s, int blockSize, int blockNum, int totalSize, void * t, int n)
{ {
/* index of data items */ /* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.x * blockIdx.x + threadIdx.x;
...@@ -51,11 +147,10 @@ void KernelUnsqueeze(void * s, int blockSize, int blockNum, void * t, int n) ...@@ -51,11 +147,10 @@ void KernelUnsqueeze(void * s, int blockSize, int blockNum, void * t, int n)
MTYPE offset = blockSize * j; MTYPE offset = blockSize * j;
T value = ((T*)s)[offset + i]; T value = ((T*)s)[offset + i];
T * tData = (T*)t + offset * n; T * tData = (T*)t + offset * n;
int length = blockSize * n;
__syncthreads(); __syncthreads();
for (int k = i; k < length; k += blockSize) for (int k = i; k < totalSize; k += blockSize)
tData[k] = value; tData[k] = value;
} }
...@@ -83,21 +178,71 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize) ...@@ -83,21 +178,71 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
int cudaGrids[3]; int cudaGrids[3];
int cudaBlocks[3]; int cudaBlocks[3];
GDevs.GetCudaThread2D(a->devID, blockSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
int devIDBackup = 0; int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup); ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) { if(blockNumA > 1){
KernelUnsqueeze<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> > GDevs.GetCudaThread2D(a->devID, blockSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
(a->data, blockSize, blockNumA, b->data, dSize);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueeze<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueeze<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA == 1 && blockSize < MAX_CUDA_THREAD_NUM_PER_BLOCK){
GDevs.GetCudaThread2D(a->devID, blockSize, dSize, MAX_CUDA_THREAD_NUM_PER_BLOCK/4, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlat2D<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlat2D<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA == 1 && blockSize % 2 == 0){
GDevs.GetCudaThread(a->devID, blockSize/2, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlatBigram<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlatBigram<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
} }
else if (a->dataType == X_INT && a->dataType == X_INT) { else if(blockNumA == 1){
KernelUnsqueeze<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> > GDevs.GetCudaThread(a->devID, blockSize, cudaGrids, cudaBlocks);
(a->data, blockSize, blockNumA, b->data, dSize);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlat<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlat<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
} }
else { else{
ShowNTErrors("TODO!"); ShowNTErrors("Something is wrong!");
} }
BacktoCudaDev(a->devID, devIDBackup); BacktoCudaDev(a->devID, devIDBackup);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论