Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
965bf9d6
Commit
965bf9d6
authored
Jul 24, 2019
by
张裕浩
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
topk BUG fix
parent
6d05d6c8
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
11 行增加
和
9 行删除
+11
-9
source/tensor/core/sort/TopK.cu
+11
-9
没有找到文件。
source/tensor/core/sort/TopK.cu
查看文件 @
965bf9d6
...
...
@@ -377,8 +377,8 @@ get the top-k items
template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE -
1024
* sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[
1024
];
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE -
512
* sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[
512
];
/*optimization k size the parameter must more than half of k*/
int parameter = 0;
...
...
@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
}
__syncthreads();
/*
to merge the heap use another way
*/
/*
to merge the heap use another way
*/
T minData = minValue;
int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
...
...
@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
minData = heap.items[counter].value;
}
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimation
if (i == 0) {
int threadLimit =
(threadIdx.y + 1) * blockDim.x
;
int threadLimit =
threadIdx.y * blockDim.x + min(blockDim.x,strideNum)
;
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x +
k
; ++counter) {
for (; counter < threadIdx.y * blockDim.x +
min(k, blockDim.x)
; ++counter) {
chooseHeap.Push(counter, eachHeapMaxValue[counter]);
}
for (; counter < threadLimit; ++counter) {
...
...
@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
}
}
int heapNum = chooseHeap.count;
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter;
for (counter = 1; counter <
k
; ++counter) {
for (counter = 1; counter <
heapNum
; ++counter) {
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--;
chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0;
if (counter + heapLimit <= k - parameter){
if (counter + heapLimit <= k - parameter
&& heapNum == k
){
cmpHeapLimit = heapLimit;
}
/* take the max data from the minHeap,so start search from the leaf node */
...
...
@@ -826,7 +828,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
else if (k < 22) workerNum = 128;
else if (k < 44) workerNum = 64;
else workerNum = 32;
int cudaGrids[3];
int cudaBlocks[3];
...
...
@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
/* we run the kernel if the heaps can fit into the shared memory */
cudaGrids[1] *= cudaBlocks[1];
cudaBlocks[1] = 1;
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) {
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int))
+ (512 * sizeof(int))
< SHARED_MEMORY_SIZE) {
if (a->dataType == DEFAULT_DTYPE) {
KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论