Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
314a02af
Commit
314a02af
authored
Nov 21, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve the implementation of _CudaSpreadForGather
parent
8f665e61
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
93 行增加
和
15 行删除
+93
-15
source/tensor/core/movement/Gather.cpp
+3
-0
source/tensor/core/movement/Spread.cu
+90
-15
没有找到文件。
source/tensor/core/movement/Gather.cpp
查看文件 @
314a02af
...
...
@@ -102,6 +102,9 @@ XTensor Gather(const XTensor &s, const XTensor &index)
srcIndex
[
i
]
=
(
int
)
tmp
[
i
];
delete
[]
tmp
;
}
else
{
ShowNTErrors
(
"Unsupported data type!"
);
}
XTensor
tensor
;
tensor
=
Gather
(
s
,
0
,
srcIndex
,
indexSize
);
...
...
source/tensor/core/movement/Spread.cu
查看文件 @
314a02af
...
...
@@ -24,6 +24,7 @@
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
...
...
@@ -146,6 +147,54 @@ void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int blockNum,
s[j] += c[j];
}
/*
This is core assignment for backward computation of gather function.
Care of the operator "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of source data block
>> stride - stride of a data block
>> subtensorNum - number of sub-tensors
>> srcIndex - index of the source sub-tensor
>> colIndex - index of the sub-tensor in the collection tensor
*/
__global__
void KernelSpreadForGatherFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride,
int subtensorNum,
int * srcIndex, int * colIndex)
{
__shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int offset = blockDim.y * blockIdx.y + threadIdx.y;
int blockId = i % blockNum;
int subtensorId = i / blockNum;
if(subtensorId >= subtensorNum || offset >= stride)
return;
if(threadIdx.y == 0){
sp[threadIdx.x] = sData + srcIndex[subtensorId] * stride;
cp[threadIdx.x] = cData + colIndex[subtensorId] * stride;
}
__syncthreads();
DTYPE * s = sp[threadIdx.x] + blockSizeSrc * blockId;
DTYPE * c = cp[threadIdx.x] + blockSizeColl * blockId;
s[offset] += c[offset];
}
/*
spread a collection tensor to source tensor (cuda version).
And this is a special spread function for backward computation of gather function.
...
...
@@ -172,9 +221,8 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++)
{
for (int i = dim + 1; i < order; i++)
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
...
...
@@ -183,23 +231,50 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
if(indexSize < 4){
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
else{
GDevs.GetCudaThread2D(source->devID, blockNum * indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * s = (DTYPE*)source->data;
DTYPE * c = (DTYPE*)collection->data;
XMem * mem = source->mem;
int * si = mem != NULL ?
(int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize * 2) :
(int*)XMemAlloc(mem->devID, sizeof(int) * indexSize * 2);
int * ci = si + indexSize;
XMemCopy(si, mem->devID, srcIndex, -1, sizeof(int) * indexSize);
XMemCopy(ci, mem->devID, collIndex, -1, sizeof(int) * indexSize);
KernelSpreadForGatherFuzed<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride, indexSize, si, ci);
KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
if(mem != NULL)
mem->ReleaseBuf(mem->devID, sizeof(int) * indexSize * 2);
else
XMemFree(mem->devID, si);
}
BacktoCudaDev(source->devID, devIDBackup);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论