Commit 314a02af by xiaotong

improve the implementation of _CudaSpreadForGather

parent 8f665e61
@@ -102,6 +102,9 @@ XTensor Gather(const XTensor &s, const XTensor &index)
             srcIndex[i] = (int)tmp[i];
         delete[] tmp;
     }
+    else{
+        ShowNTErrors("Unsupported data type!");
+    }
 
     XTensor tensor;
     tensor = Gather(s, 0, srcIndex, indexSize);
...
@@ -24,6 +24,7 @@
 #include "../../XTensor.h"
 #include "../../XDevice.h"
+#include "../../XUtility.h"
 #include "Spread.cuh"
 
 namespace nts { // namespace nts(NiuTrans.Tensor)
@@ -146,6 +147,54 @@ void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int blockNum,
         s[j] += c[j];
 }
 
+/*
+This is the core computation for the backward pass of the gather function,
+fused so that all sub-tensors are handled in a single kernel launch.
+Note the use of the operator "+=" instead of "=".
+>> sData - the data pointer of the source tensor
+>> cData - the data pointer of the collection tensor
+>> blockNum - number of data blocks
+>> blockSizeSrc - size of a source data block
+>> blockSizeColl - size of a collection data block
+>> stride - stride of a data block
+>> subtensorNum - number of sub-tensors
+>> srcIndex - index of the source sub-tensor
+>> colIndex - index of the sub-tensor in the collection tensor
+*/
+__global__
+void KernelSpreadForGatherFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
+                                int blockSizeSrc, int blockSizeColl, int stride,
+                                int subtensorNum,
+                                int * srcIndex, int * colIndex)
+{
+    __shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
+    __shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
+
+    /* combined (sub-tensor, block) id */
+    int i = blockDim.x * blockIdx.x + threadIdx.x;
+
+    /* offset in each block */
+    int offset = blockDim.y * blockIdx.y + threadIdx.y;
+
+    int blockId = i % blockNum;
+    int subtensorId = i / blockNum;
+
+    if(subtensorId >= subtensorNum || offset >= stride)
+        return;
+
+    /* one thread per x-slot caches the sub-tensor base pointers in shared memory */
+    if(threadIdx.y == 0){
+        sp[threadIdx.x] = sData + srcIndex[subtensorId] * stride;
+        cp[threadIdx.x] = cData + colIndex[subtensorId] * stride;
+    }
+
+    __syncthreads();
+
+    DTYPE * s = sp[threadIdx.x] + blockSizeSrc * blockId;
+    DTYPE * c = cp[threadIdx.x] + blockSizeColl * blockId;
+
+    s[offset] += c[offset];
+}
+
 /*
 spread a collection tensor to source tensor (cuda version).
 And this is a special spread function for backward computation of gather function.
@@ -172,9 +221,8 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
     int blockNum = 1;
     int stride = 1;
 
-    for (int i = dim + 1; i < order; i++) {
+    for (int i = dim + 1; i < order; i++)
         stride *= source->GetDim(i);
-    }
 
     blockSizeSrc = stride * source->GetDim(dim);
     blockSizeColl = stride * collection->GetDim(dim);
@@ -183,23 +231,50 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
     int cudaGrids[3];
     int cudaBlocks[3];
 
-    GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
-
-    dim3 blocks(cudaGrids[0], cudaGrids[1]);
-    dim3 threads(cudaBlocks[0], cudaBlocks[1]);
-
     int devIDBackup;
     ProtectCudaDev(source->devID, devIDBackup);
 
-    DTYPE * sData = (DTYPE*)source->data;
-    DTYPE * cData = (DTYPE*)collection->data;
-
-    for(int i = 0; i < indexSize; i++) {
-        int src = srcIndex[i];
-        int tgt = collIndex[i];
-        DTYPE * s = sData + src * stride;
-        DTYPE * c = cData + tgt * stride;
-
-        KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
-    }
+    if(indexSize < 4){
+        GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
+
+        dim3 blocks(cudaGrids[0], cudaGrids[1]);
+        dim3 threads(cudaBlocks[0], cudaBlocks[1]);
+
+        DTYPE * sData = (DTYPE*)source->data;
+        DTYPE * cData = (DTYPE*)collection->data;
+
+        for(int i = 0; i < indexSize; i++) {
+            int src = srcIndex[i];
+            int tgt = collIndex[i];
+            DTYPE * s = sData + src * stride;
+            DTYPE * c = cData + tgt * stride;
+
+            KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
+        }
+    }
+    else{
+        GDevs.GetCudaThread2D(source->devID, blockNum * indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
+
+        dim3 blocks(cudaGrids[0], cudaGrids[1]);
+        dim3 threads(cudaBlocks[0], cudaBlocks[1]);
+
+        DTYPE * s = (DTYPE*)source->data;
+        DTYPE * c = (DTYPE*)collection->data;
+
+        XMem * mem = source->mem;
+        int * si = mem != NULL ?
+                   (int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize * 2) :
+                   (int*)XMemAlloc(source->devID, sizeof(int) * indexSize * 2);
+        int * ci = si + indexSize;
+
+        XMemCopy(si, source->devID, srcIndex, -1, sizeof(int) * indexSize);
+        XMemCopy(ci, source->devID, collIndex, -1, sizeof(int) * indexSize);
+
+        KernelSpreadForGatherFuzed<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride, indexSize, si, ci);
+
+        if(mem != NULL)
+            mem->ReleaseBuf(mem->devID, sizeof(int) * indexSize * 2);
+        else
+            XMemFree(source->devID, si);
+    }
 
     BacktoCudaDev(source->devID, devIDBackup);
...
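
Below is a minimal, self-contained sketch of the fused launch strategy used in this commit, reduced to the simplest case (blockNum = 1, i.e. gathering along dimension 0). All names here (fusedSpread, hSrcIdx, hColIdx, ...) are illustrative and not part of NiuTrans.Tensor. The point it demonstrates is the one _CudaSpreadForGather now makes when indexSize >= 4: copy the index arrays to the device once and cover every (index pair, offset) cell with a single 2D kernel launch, instead of launching KernelSpreadForGather once per index. The sketch assumes the source indices are distinct; with repeated entries the per-index launches serialize the updates, while a fused launch would need atomicAdd.

#include <cstdio>
#include <cuda_runtime.h>

/* One thread handles one (index pair, offset) cell:
   the x dimension picks the index pair (sub-tensor), the y dimension the offset
   within a slice of length `stride`. */
__global__ void fusedSpread(float * src, const float * coll, int stride,
                            int indexNum, const int * srcIndex, const int * colIndex)
{
    int i      = blockDim.x * blockIdx.x + threadIdx.x;  /* which index pair */
    int offset = blockDim.y * blockIdx.y + threadIdx.y;  /* position inside the slice */

    if (i >= indexNum || offset >= stride)
        return;

    /* accumulate, not overwrite ("+=", as in the backward pass of gather) */
    src[srcIndex[i] * stride + offset] += coll[colIndex[i] * stride + offset];
}

int main()
{
    const int stride = 8, indexNum = 3, rows = 6;
    float hSrc[rows * stride] = {0.0f};
    float hColl[indexNum * stride];
    int hSrcIdx[indexNum] = {1, 3, 5};   /* rows of src that receive the gradients */
    int hColIdx[indexNum] = {0, 1, 2};   /* rows of coll that hold them */
    for (int i = 0; i < indexNum * stride; i++)
        hColl[i] = 1.0f;

    float *dSrc, *dColl;
    int *dSrcIdx, *dColIdx;
    cudaMalloc(&dSrc, sizeof(hSrc));       cudaMemcpy(dSrc, hSrc, sizeof(hSrc), cudaMemcpyHostToDevice);
    cudaMalloc(&dColl, sizeof(hColl));     cudaMemcpy(dColl, hColl, sizeof(hColl), cudaMemcpyHostToDevice);
    cudaMalloc(&dSrcIdx, sizeof(hSrcIdx)); cudaMemcpy(dSrcIdx, hSrcIdx, sizeof(hSrcIdx), cudaMemcpyHostToDevice);
    cudaMalloc(&dColIdx, sizeof(hColIdx)); cudaMemcpy(dColIdx, hColIdx, sizeof(hColIdx), cudaMemcpyHostToDevice);

    /* a single launch covers all index pairs, instead of indexNum separate launches */
    dim3 threads(4, 8);
    dim3 blocks((indexNum + threads.x - 1) / threads.x, (stride + threads.y - 1) / threads.y);
    fusedSpread<<<blocks, threads>>>(dSrc, dColl, stride, indexNum, dSrcIdx, dColIdx);
    cudaDeviceSynchronize();

    cudaMemcpy(hSrc, dSrc, sizeof(hSrc), cudaMemcpyDeviceToHost);
    printf("src[1][0] = %g (expected 1)\n", hSrc[1 * stride]);

    cudaFree(dSrc); cudaFree(dColl); cudaFree(dSrcIdx); cudaFree(dColIdx);
    return 0;
}

The indexSize < 4 threshold in the commit keeps the original per-index path for very small gathers, presumably because staging the index arrays on the device is not worth it when only a few launches would be saved.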