Commit 32e87681 by liyinqiao

Support sorted TopK function.

Add a new parameter `isSorted` to the TopK function so that the items in the top-k pool can optionally be returned in sorted order.
parent 48bdcb49
......@@ -35,8 +35,9 @@ get the top-k items along a given dimension
>> index - index of the top-k items
>> dim - the dimension along which the sorting is performed
>> k - how many items returned after sorting
>> isSorted - indicates whether the returned top-k items should be sorted
*/
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted)
{
dim = MODX(dim, a->order);
......@@ -58,7 +59,7 @@ void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
if (a->devID >= 0 || b->devID >= 0) {
#ifdef USE_CUDA
_CudaTopK(a, b, index, dim, k);
_CudaTopK(a, b, index, dim, k, isSorted);
#else
ShowNTErrors("Plesae specify USE_CUDA and recompile the code!");
#endif
......@@ -116,15 +117,16 @@ get the top-k items along a given dimension
>> index - index of the top-k items
>> dim - the dimension along which the sorting is performed
>> k - how many items returned after sorting
>> isSorted - indicates whether the returned top-k items should be sorted
*/
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k)
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k, bool isSorted)
{
dim = MODX(dim, a.order);
if(a.dimSize[dim] <= k)
_Sort(&a, &b, &index, dim);
else
_TopK(&a, &b, &index, dim, k);
_TopK(&a, &b, &index, dim, k, isSorted);
/* tensor connection */
//TensorList list(2);
......
......@@ -374,9 +374,10 @@ get the top-k items
>> minValue - min value of an item
>> output - the output data array
>> index - the output index array
>> isSorted - indicates whether the returned top-k items should be sorted
*/
template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index, bool isSorted)
{
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 512 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[512];
......@@ -479,11 +480,24 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
int offset = stride * k * blockIndex + offsetInBlock;
T * dOutput = output + offset;
int * indexOutput = index + offset;
for (int q = 0; q < k; ++q){
if (isSorted)
{
for (int q = k - 1; q >= 0; q--) {
dOutput[stride * q] = ansHeapData.items[0].value;
indexOutput[stride * q] = ansHeapData.items[0].index;
ansHeapData.items[0] = ansHeapData.items[ansHeapData.count - 1];
ansHeapData.count--;
ansHeapData.Down(0);
}
}
else
{
for (int q = 0; q < k; ++q) {
dOutput[stride * q] = ansHeapData.items[q].value;
indexOutput[stride * q] = ansHeapData.items[q].index;
}
}
}
}
......@@ -803,8 +817,9 @@ get the top-k items along a given dimension
>> index - index of the top-k items
>> dim - the dimension along which the sorting is performed
>> k - how many items returned after sorting
>> isSorted - indicates whether the returned top-k items should be sorted
*/
void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted)
{
CheckNTErrors((a->unitSize == b->unitSize), "Unmatched input tensors!");
CheckNTErrors((a->order == b->order), "Unmatched input tensors!");
......@@ -846,7 +861,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
if (a->dataType == DEFAULT_DTYPE) {
KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
(DTYPE*)b->data, (int*)index->data);
(DTYPE*)b->data, (int*)index->data, isSorted);
}
else {
ShowNTErrors("TODO!");
......@@ -882,6 +897,10 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
KernelTopKRadixSelect<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> (goutput, stride, strideNumA, blockNum, k, DTYPE_MIN, (DTYPE *)b->data, (int *)index->data, stride * strideNumA * blockNum);
deconvert2floatV2 <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> ((unsigned int *)a->data, (float *)goutput, stride, strideNumA, blockNum, strideNumA*blockNum*stride);
if (isSorted)
{
ShowNTErrors("TODO!");
}
}
}
......
......@@ -29,7 +29,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* get the top-k items along a given dimension */
void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k);
void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted);
#endif // USE_CUDA
......
......@@ -27,10 +27,10 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* get the top-k items along a given dimension */
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k);
void _TopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k, bool isSorted = false);
/* get the top-k items along a given dimension */
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k);
void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k, bool isSorted = false);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论