Commit 314a02af by xiaotong

improve the implementation of _CudaSpreadForGather

parent 8f665e61
......@@ -102,6 +102,9 @@ XTensor Gather(const XTensor &s, const XTensor &index)
            srcIndex[i] = (int)tmp[i];

        delete[] tmp;
    }
    else{
        ShowNTErrors("Unsupported data type!");
    }

    XTensor tensor;
    tensor = Gather(s, 0, srcIndex, indexSize);
......
......@@ -24,6 +24,7 @@
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -147,6 +148,54 @@ void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int blockNum,
}
/*
This is the core assignment for the backward computation of the gather function.
Note that the operator "+=" is used instead of "=".

>> sData - the data pointer of the source tensor
>> cData - the data pointer of the collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of a source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
>> subtensorNum - number of sub-tensors
>> srcIndex - index of the source sub-tensor
>> colIndex - index of the sub-tensor in the collection tensor
*/
__global__
void KernelSpreadForGatherFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
                                int blockSizeSrc, int blockSizeColl, int stride,
                                int subtensorNum,
                                int * srcIndex, int * colIndex)
{
    __shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];

    /* flattened index over all (sub-tensor, block) pairs */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* offset in each block */
    int offset = blockDim.y * blockIdx.y + threadIdx.y;

    int blockId = i % blockNum;
    int subtensorId = i / blockNum;

    if(subtensorId >= subtensorNum || offset >= stride)
        return;

    /* cache the base addresses of the source and collection sub-tensors */
    if(threadIdx.y == 0){
        sp[threadIdx.x] = sData + srcIndex[subtensorId] * stride;
        cp[threadIdx.x] = cData + colIndex[subtensorId] * stride;
    }

    __syncthreads();

    DTYPE * s = sp[threadIdx.x] + blockSizeSrc * blockId;
    DTYPE * c = cp[threadIdx.x] + blockSizeColl * blockId;

    s[offset] += c[offset];
}
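
For reference, the fused kernel computes the same result as a plain host-side triple loop over sub-tensors, blocks, and offsets: the x dimension of the launch enumerates the (sub-tensor, block) pairs and the y dimension enumerates offsets within the stride. A minimal CPU sketch of the same accumulation (a hypothetical reference helper, not part of this commit) is:

/* CPU reference of what KernelSpreadForGatherFuzed computes (illustration only):
   for every pair of sub-tensor indices (srcIndex[j], colIndex[j]) and every data
   block, accumulate the collection block into the source block with "+=". */
void SpreadForGatherReference(DTYPE * sData, DTYPE * cData, int blockNum,
                              int blockSizeSrc, int blockSizeColl, int stride,
                              int subtensorNum, int * srcIndex, int * colIndex)
{
    for(int j = 0; j < subtensorNum; j++){
        /* base addresses of the j-th source and collection sub-tensors */
        DTYPE * sp = sData + srcIndex[j] * stride;
        DTYPE * cp = cData + colIndex[j] * stride;

        for(int blockId = 0; blockId < blockNum; blockId++){
            DTYPE * s = sp + blockSizeSrc * blockId;
            DTYPE * c = cp + blockSizeColl * blockId;

            for(int offset = 0; offset < stride; offset++)
                s[offset] += c[offset];
        }
    }
}

Fusing all sub-tensors into one launch is what removes the per-sub-tensor kernel launches from the small-index path in _CudaSpreadForGather below.
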
/*
spread a collection tensor to the source tensor (cuda version).
This is a special spread function for the backward computation of the gather function.
......@@ -172,9 +221,8 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
    int blockNum = 1;
    int stride = 1;

    for (int i = dim + 1; i < order; i++)
        stride *= source->GetDim(i);

    blockSizeSrc = stride * source->GetDim(dim);
    blockSizeColl = stride * collection->GetDim(dim);
......@@ -183,14 +231,15 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
    int cudaGrids[3];
    int cudaBlocks[3];

    int devIDBackup;
    ProtectCudaDev(source->devID, devIDBackup);

    if(indexSize < 4){
        GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);

        dim3 blocks(cudaGrids[0], cudaGrids[1]);
        dim3 threads(cudaBlocks[0], cudaBlocks[1]);

        DTYPE * sData = (DTYPE*)source->data;
        DTYPE * cData = (DTYPE*)collection->data;

        for(int i = 0; i < indexSize; i++) {
......@@ -201,6 +250,32 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
            KernelSpreadForGather<<<blocks, threads>>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
        }
    }
    else{
        GDevs.GetCudaThread2D(source->devID, blockNum * indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);

        dim3 blocks(cudaGrids[0], cudaGrids[1]);
        dim3 threads(cudaBlocks[0], cudaBlocks[1]);

        DTYPE * s = (DTYPE*)source->data;
        DTYPE * c = (DTYPE*)collection->data;

        /* device buffer holding srcIndex and collIndex back to back */
        XMem * mem = source->mem;
        int * si = mem != NULL ?
                   (int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize * 2) :
                   (int*)XMemAlloc(source->devID, sizeof(int) * indexSize * 2);
        int * ci = si + indexSize;

        XMemCopy(si, source->devID, srcIndex, -1, sizeof(int) * indexSize);
        XMemCopy(ci, source->devID, collIndex, -1, sizeof(int) * indexSize);

        KernelSpreadForGatherFuzed<<<blocks, threads>>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride, indexSize, si, ci);

        if(mem != NULL)
            mem->ReleaseBuf(mem->devID, sizeof(int) * indexSize * 2);
        else
            XMemFree(source->devID, si);
    }

    BacktoCudaDev(source->devID, devIDBackup);
}
......
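
The dispatch in _CudaSpreadForGather follows a common CUDA pattern: for a handful of sub-tensors (indexSize < 4) one kernel launch per sub-tensor is cheap, but for many sub-tensors the launch overhead dominates, so both index arrays are copied to the device once and a single fused kernel handles all of them. A small, self-contained sketch of that pattern outside NiuTrans.Tensor (the kernel, helper, and all names below are hypothetical, for illustration only):

#include <cuda_runtime.h>

/* toy fused kernel: add row srcIdx[j] of src into row dstIdx[j] of dst, for all j in one launch */
__global__ void FusedRowAdd(float * dst, const float * src, int rowLen,
                            int rowNum, const int * dstIdx, const int * srcIdx)
{
    int j = blockIdx.y;                                   /* which index pair        */
    int offset = blockIdx.x * blockDim.x + threadIdx.x;   /* position inside the row */

    if (j >= rowNum || offset >= rowLen)
        return;

    dst[dstIdx[j] * rowLen + offset] += src[srcIdx[j] * rowLen + offset];
}

void SpreadRows(float * dst, const float * src, int rowLen,
                const int * hostDstIdx, const int * hostSrcIdx, int rowNum)
{
    /* copy both index arrays into one device buffer, as the else-branch above does */
    int * devIdx = NULL;
    cudaMalloc((void**)&devIdx, sizeof(int) * rowNum * 2);
    cudaMemcpy(devIdx, hostDstIdx, sizeof(int) * rowNum, cudaMemcpyHostToDevice);
    cudaMemcpy(devIdx + rowNum, hostSrcIdx, sizeof(int) * rowNum, cudaMemcpyHostToDevice);

    dim3 threads(256);
    dim3 blocks((rowLen + 255) / 256, rowNum);
    FusedRowAdd<<<blocks, threads>>>(dst, src, rowLen, rowNum, devIdx, devIdx + rowNum);

    cudaFree(devIdx);
}

The extra cost of the fused path is the two small host-to-device copies of the index arrays, which is why the per-sub-tensor launches are kept for very small indexSize.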