Commit 32e87681 by liyinqiao

Support sorted TopK function.

Add a new parameter, isSorted, to the TopK function to sort the items in the top-k pool.
parent 48bdcb49
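
A minimal usage sketch of the new flag (the shapes and the InitTensor2D/X_FLOAT helpers are illustrative assumptions, not code from this commit):

    XTensor a, b, index;
    InitTensor2D(&a, 8, 4, X_FLOAT);      /* input: 8 x 4 */
    InitTensor2D(&b, 4, 4, X_FLOAT);      /* top-4 values along dimension 0 */
    InitTensor2D(&index, 4, 4, X_INT);    /* their positions in a */
    TopK(a, b, index, 0, 4, true);        /* isSorted = true: each top-4 slice
                                             comes back ordered, largest first */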
@@ -35,8 +35,9 @@ get the top-k items along a given dimension
 >> index - index of the top-k items
 >> dim - the dimension along which the sorting is performed
 >> k - how many items returned after sorting
+>> isSorted - indicates whether the k items are sorted
 */
-void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
+void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted)
 {
     dim = MODX(dim, a->order);
@@ -58,7 +59,7 @@ void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
     if (a->devID >= 0 || b->devID >= 0) {
 #ifdef USE_CUDA
-        _CudaTopK(a, b, index, dim, k);
+        _CudaTopK(a, b, index, dim, k, isSorted);
 #else
         ShowNTErrors("Please specify USE_CUDA and recompile the code!");
 #endif
@@ -116,15 +117,16 @@ get the top-k items along a given dimension
 >> index - index of the top-k items
 >> dim - the dimension along which the sorting is performed
 >> k - how many items returned after sorting
+>> isSorted - indicates whether the k items are sorted
 */
-void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k)
+void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k, bool isSorted)
 {
     dim = MODX(dim, a.order);
     if(a.dimSize[dim] <= k)
         _Sort(&a, &b, &index, dim);
     else
-        _TopK(&a, &b, &index, dim, k);
+        _TopK(&a, &b, &index, dim, k, isSorted);
 
     /* tensor connection */
     //TensorList list(2);
...
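
Note that when k covers the whole dimension (a.dimSize[dim] <= k), the wrapper falls back to _Sort and fully sorts the slice, so that branch never consults isSorted and its output is sorted regardless; the new flag only changes behavior on the genuine top-k path in the kernels below.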
@@ -374,9 +374,10 @@ get the top-k items
 >> minValue - min value of an item
 >> output - the output data array
 >> index - the output index array
+>> isSorted - indicates whether the k items are sorted
 */
 template<class T> __global__
-void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
+void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index, bool isSorted)
 {
     __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 512 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
     __shared__ T eachHeapMaxValue[512];
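
For context on the shared-memory declarations above: each thread block reserves 512 slots of T for per-heap maxima and spends the rest of its shared-memory budget on heap nodes. As a worked example, assuming SHARED_MEMORY_SIZE is 49152 bytes (a common per-block limit) and T is float, so that a CudaHeapNode<float> (one value plus one index) occupies 8 bytes, the array holds (49152 - 512 * 4) / 8 = 5888 nodes.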
@@ -479,9 +480,22 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
         int offset = stride * k * blockIndex + offsetInBlock;
         T * dOutput = output + offset;
         int * indexOutput = index + offset;
-        for (int q = 0; q < k; ++q){
-            dOutput[stride * q] = ansHeapData.items[q].value;
-            indexOutput[stride * q] = ansHeapData.items[q].index;
+        if (isSorted)
+        {
+            for (int q = k - 1; q >= 0; q--) {
+                dOutput[stride * q] = ansHeapData.items[0].value;
+                indexOutput[stride * q] = ansHeapData.items[0].index;
+                ansHeapData.items[0] = ansHeapData.items[ansHeapData.count - 1];
+                ansHeapData.count--;
+                ansHeapData.Down(0);
+            }
+        }
+        else
+        {
+            for (int q = 0; q < k; ++q) {
+                dOutput[stride * q] = ansHeapData.items[q].value;
+                indexOutput[stride * q] = ansHeapData.items[q].index;
+            }
         }
     }
 }
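
The sorted branch treats the per-block pool as a min-heap: it repeatedly copies the root (the smallest surviving item) to the back of the output, moves the last item to the root, and sifts it down, so position 0 ends up holding the largest value. A host-side C++ sketch of the same extraction, as a standalone illustration rather than code from this commit:

    #include <functional>
    #include <queue>
    #include <vector>

    /* Fill out[] in descending order from a pool of surviving items,
       mirroring the back-to-front heap pops in KernelTopK3. */
    void SortedFromPool(const std::vector<float> &pool, std::vector<float> &out)
    {
        /* min-heap: the root is the smallest item in the pool */
        std::priority_queue<float, std::vector<float>,
                            std::greater<float>> heap(pool.begin(), pool.end());
        int k = (int)pool.size();
        out.resize(k);
        for (int q = k - 1; q >= 0; q--) {
            out[q] = heap.top();    /* smallest remaining item goes to the back */
            heap.pop();
        }
    }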
@@ -803,8 +817,9 @@ get the top-k items along a given dimension
 >> index - index of the top-k items
 >> dim - the dimension along which the sorting is performed
 >> k - how many items returned after sorting
+>> isSorted - indicates whether the k items are sorted
 */
-void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
+void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted)
 {
     CheckNTErrors((a->unitSize == b->unitSize), "Unmatched input tensors!");
     CheckNTErrors((a->order == b->order), "Unmatched input tensors!");
@@ -846,7 +861,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
     if (a->dataType == DEFAULT_DTYPE) {
         KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
                            ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
-                            (DTYPE*)b->data, (int*)index->data);
+                            (DTYPE*)b->data, (int*)index->data, isSorted);
     }
     else {
         ShowNTErrors("TODO!");
@@ -882,6 +897,10 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
         KernelTopKRadixSelect<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> (goutput, stride, strideNumA, blockNum, k, DTYPE_MIN, (DTYPE *)b->data, (int *)index->data, stride * strideNumA * blockNum);
         deconvert2floatV2 <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> ((unsigned int *)a->data, (float *)goutput, stride, strideNumA, blockNum, strideNumA*blockNum*stride);
+        if (isSorted)
+        {
+            ShowNTErrors("TODO!");
+        }
     }
 }
...
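
The radix-select path above still raises a TODO when isSorted is set, so sorted output is currently only available through the heap-based KernelTopK3. One hypothetical way to complete it (an assumption, not part of this commit) would be to sort each selected top-k segment on the device afterwards, for example with Thrust; the sketch below assumes a contiguous segment, whereas the kernel's strided layout would need a gather first:

    #include <thrust/device_ptr.h>
    #include <thrust/execution_policy.h>
    #include <thrust/functional.h>
    #include <thrust/sort.h>

    /* Hypothetical helper: sort one contiguous segment of k selected values
       in descending order, carrying the index array along. */
    void SortTopKSegment(float * values, int * indices, int k)
    {
        thrust::device_ptr<float> v(values);
        thrust::device_ptr<int> i(indices);
        thrust::sort_by_key(thrust::device, v, v + k, i, thrust::greater<float>());
    }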
@@ -29,7 +29,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 #ifdef USE_CUDA
 /* get the top-k items along a given dimension */
-void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k);
+void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted);
 #endif // USE_CUDA
...
@@ -27,10 +27,10 @@
 namespace nts { // namespace nts(NiuTrans.Tensor)
 /* get the top-k items along a given dimension */
-void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k);
+void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted = false);
 /* get the top-k items along a given dimension */
-void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k);
+void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k, bool isSorted = false);
 } // namespace nts(NiuTrans.Tensor)
...
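
Because both declarations default isSorted to false, existing call sites compile and behave exactly as before; callers opt in explicitly:

    _TopK(&a, &b, &index, 0, 4);          /* old behavior: unsorted top-k pool */
    _TopK(&a, &b, &index, 0, 4, true);    /* new: the k items come back sorted */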