/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../XDevice.h"
#include "../XTensor.h"
#include "TopK.h"
#include "TopK.cuh"
#include "Sort.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA

/* heap item: an (index, value) pair stored in the device-side heap */
template <typename T>
struct CudaHeapNode
{
    /* node index */
    int index;

    /* value of the node */
    T value;

    /* default constructor: deliberately empty, both members are left unset */
    __device__ CudaHeapNode() {}

    /* build a node from an index and a value */
    __device__ CudaHeapNode(int i, T v) : index(i), value(v) {}
};

/*
heap (device code)
A binary heap laid out in a caller-provided array of CudaHeapNode<T>.
The heap does not own or allocate "items"; it only reorders them in place.
With hType == MIN_HEAP the smallest value sits at items[0]; otherwise the
largest value does.
*/
template<HeapType hType, typename T>
class CudaXHeap
{
public:
    /* number of the items the heap keeps */
    int size;

    /* number of the items that are already in the heap */
    int count;

    /* items (storage supplied by the caller) */
    CudaHeapNode<T> * items;

    /* value of the top-most item (a cached copy of items[0].value) */
    T topValue;

public:
    /*
    constructor of an empty heap
    >> mySize - capacity of the heap
    >> myItems - storage for the heap nodes
    */
    __device__ CudaXHeap(int mySize, CudaHeapNode<T> * myItems)
    {
        size = mySize;
        count = 0;
        items = myItems;
        topValue = 0;
    }

    /*
    constructor that wraps an array already holding myCount heap-ordered items
    >> mySize - capacity of the heap
    >> myCount - number of items already stored in myItems
    >> myItems - storage for the heap nodes (myItems[0] is read here)
    */
    __device__ CudaXHeap(int mySize, int myCount, CudaHeapNode<T> * myItems)
    {
        size = mySize;
        count = myCount;
        items = myItems;
        topValue = items[0].value;
    }

    /*
    compare node i and node j
    << return - true if node i should be placed above node j
    */
    __device__ bool Compare(int i, int j)
    {
        if (hType == MIN_HEAP)
            return items[i].value < items[j].value;
        else
            return items[j].value < items[i].value;
    }

    /* swap node i and node j (member-wise, without a temporary node) */
    __device__ void Swap(int i, int j)
    {
        int tmpIndex = items[i].index;
        T tmpValue = items[i].value;
        items[i] = items[j];
        items[j].index = tmpIndex;
        items[j].value = tmpValue;
    }

    /* replace the top-most item and update the heap */
    __device__ void ReplaceTop(CudaHeapNode<T> node)
    {
        items[0] = node;
        Down(0);
        topValue = items[0].value;
    }

    /* replace the top-most item and update the heap */
    __device__ void ReplaceTop(int index, T value)
    {
        items[0].index = index;
        items[0].value = value;
        Down(0);
        topValue = items[0].value;
    }

    /* push an item into the heap (no capacity check is made here) */
    __device__ void Push(CudaHeapNode<T> node)
    {
        items[count] = node;
        Up(count);
        count++;
        topValue = items[0].value;
    }

    /* push an item into the heap (no capacity check is made here) */
    __device__ void Push(int index, T value)
    {
        items[count].index = index;
        items[count].value = value;
        Up(count);
        count++;
        topValue = items[0].value;
    }

    /* move item k down the tree */
    __device__ void Down(int k)
    {
        int i = k;
        int i2 = i + i;
        while (i2 + 1 < count) {
            int l = i2 + 1;
            int r = i2 + 2;

            /* check r against count first: when the right child does not exist,
               items[r] lies outside the live heap (and, for a full heap, outside
               this heap's storage), so it must not be read */
            int m = (r >= count || Compare(l, r)) ? l : r;
            if (Compare(i, m))
                break;
            Swap(i, m);
            i = m;
            i2 = m << 1;
        }
    }

    /* move item k up the tree */
    __device__ void Up(int k)
    {
        int i = k;
        int parent = (i - 1) >> 1;

        /* i > 0 is tested first, so the parent slot is only read for non-root nodes */
        while (i > 0 && !Compare(parent, i)) {
            Swap(parent, i);
            i = parent;
            parent = (i - 1) >> 1;
        }
    }
};

/*
get the top-k items (the per-worker heaps are merged serially by worker 0)

Thread layout: the x-dimension enumerates the workers that share one data
array (only the threads of blockIdx.x == 0 do any work); the y-dimension
enumerates the data arrays, i.e., the (block, offset-in-block) pairs.
Every working thread keeps a min-heap of k nodes in shared memory at
heapData + k * (threadIdx.y * blockDim.x + threadIdx.x).

The results of a data array are written in descending order of value with
a step of "stride". If the array has fewer than k items, the missing
entries are (minValue, -1).

>> input - the input data array
>> stride - number of items we go over when we move to the next item along a given dimension
>> strideNum - size of the given dimension
>> blockNum - number of data blocks
>> k - as it is
>> minValue - min value of an item
>> output - the output data array
>> index - the output index array
*/
template<class T> __global__
void KernelTopK(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
    __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE) / sizeof(CudaHeapNode<T>)];

    /* worker index */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* index of the data array along the given dimension */
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    /* Threads without work must not return here: they still have to reach
       the __syncthreads() below together with the rest of the block. */
    bool active = (i < strideNum && i < blockDim.x && j < stride * blockNum);

    if (active) {
        int blockIndex = j / stride;
        int offsetInBlock = j % stride;
        T * d = input + stride * strideNum * blockIndex + offsetInBlock;

        CudaXHeap<MIN_HEAP, T> heap(k, heapData + k * (threadIdx.y * blockDim.x + threadIdx.x));

        /* go over the data array and build the heap */
        int indexOffset = blockDim.x;
        int dataOffset = stride * blockDim.x;

        if (i + (heap.size - 1) * indexOffset < strideNum) {
            /* this worker sees at least k items: fill the heap first ... */
            int p = i;
            int q = i * stride;
            for (int m = 0; m < heap.size; m++) {
                heap.Push(p, d[q]);
                p += indexOffset;
                q += dataOffset;
            }

            /* ... and then keep an item only if it beats the current minimum */
            for (; p < strideNum; p += indexOffset, q += dataOffset) {
                T v = d[q];
                if (v > heap.topValue) {
                    heap.ReplaceTop(p, v);
                }
            }
        }
        else {
            for (int p = i, q = i * stride; p < strideNum; p += indexOffset, q += dataOffset) {
                heap.Push(p, d[q]);
            }
        }

        /* fill the heap if no enough items are processed */
        while (heap.count < heap.size) {
            heap.Push(-1, minValue);
        }
    }

    /* all the heaps of the block must be complete before they are merged */
    __syncthreads();

    if (active && threadIdx.x == 0) {
        int blockIndex = j / stride;
        int offsetInBlock = j % stride;

        CudaXHeap<MIN_HEAP, T> heapFinal(k, k, heapData + k * threadIdx.y * blockDim.x);

        /* merge the result over the workers.
        This can be improved by parallel merging */
        if (blockDim.x > 1) {
            /* workers with an index >= strideNum have not built a heap */
            for (int p = 1; p < blockDim.x && p < strideNum; p++) {
                CudaHeapNode<T> * hd = heapData + k * (threadIdx.y * blockDim.x + p);
                for (int q = 0; q < k; q++) {
                    if (hd[q].value > heapFinal.topValue)
                        heapFinal.ReplaceTop(hd[q]);
                }
            }
        }

        int offset = stride * k * blockIndex + offsetInBlock;
        T * dOutput = output + offset;
        int * indexOutput = index + offset;

        /* pop for the final result: the min-heap yields the smallest item
           first, so the output is filled from the last position backwards */
        for (int q = k - 1; q >= 0; q--) {
            dOutput[stride * q] = heapFinal.items[0].value;
            indexOutput[stride * q] = heapFinal.items[0].index;
            heapFinal.items[0] = heapFinal.items[heapFinal.count - 1];
            heapFinal.count--;
            heapFinal.Down(0);
        }
    }
}

/*
get the top-k items (the per-worker heaps are merged in parallel)

Thread layout: the x-dimension enumerates the workers that share one data
array (only the threads of blockIdx.x == 0 do any work); the y-dimension
enumerates the data arrays, i.e., the (block, offset-in-block) pairs.
Every working thread keeps a min-heap of k nodes in shared memory at
heapData + k * (threadIdx.y * blockDim.x + threadIdx.x).

The results of a data array are written in descending order of value with
a step of "stride". If the array has fewer than k items, the missing
entries are (minValue, -1).

>> input - the input data array
>> stride - number of items we go over when we move to the next item along a given dimension
>> strideNum - size of the given dimension
>> blockNum - number of data blocks
>> k - as it is
>> minValue - min value of an item
>> output - the output data array
>> index - the output index array
*/
template<class T> __global__
void KernelTopK2(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
    __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE) / sizeof(CudaHeapNode<T>)];

    /* worker index */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* index of the data array along the given dimension */
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    /* Threads without work must not return here: they still have to reach
       every __syncthreads() below together with the rest of the block. */
    bool active = (i < strideNum && i < blockDim.x && j < stride * blockNum);

    /* position of the first heap of this data array in the shared memory */
    int heapOffset = threadIdx.y * blockDim.x;

    if (active) {
        int blockIndex = j / stride;
        int offsetInBlock = j % stride;
        T * d = input + stride * strideNum * blockIndex + offsetInBlock;

        CudaXHeap<MIN_HEAP, T> heap(k, heapData + k * (heapOffset + threadIdx.x));

        /* go over the data array and build the heap */
        int indexOffset = blockDim.x;
        int dataOffset = stride * blockDim.x;

        if (i + (heap.size - 1) * indexOffset < strideNum) {
            /* this worker sees at least k items: fill the heap first ... */
            int p = i;
            int q = i * stride;
            for (int m = 0; m < heap.size; m++) {
                heap.Push(p, d[q]);
                p += indexOffset;
                q += dataOffset;
            }

            /* ... and then keep an item only if it beats the current minimum */
            for (; p < strideNum; p += indexOffset, q += dataOffset) {
                T v = d[q];
                if (v > heap.topValue) {
                    heap.ReplaceTop(p, v);
                }
            }
        }
        else {
            for (int p = i, q = i * stride; p < strideNum; p += indexOffset, q += dataOffset) {
                heap.Push(p, d[q]);
            }
        }

        /* fill the heap if no enough items are processed */
        while (heap.count < heap.size) {
            heap.Push(-1, minValue);
        }
    }

    /* all the heaps of the block must be complete before they are merged */
    __syncthreads();

    /* parallel merging.
       In every round the first (n - half) workers absorb the heap that lies
       "half" workers to their right, which leaves "half" heaps for the next
       round. For a power-of-two blockDim.x this is the usual halving
       schedule; rounding "half" up keeps every heap in the merge for any
       other blockDim.x as well. The trip count depends on blockDim.x only,
       so the barrier inside the loop is reached by the whole block. */
    int worker = threadIdx.x;
    for (int n = blockDim.x; n > 1; ) {
        int half = (n + 1) / 2;

        /* workers with an index >= strideNum have not built a heap */
        if (active && worker + half < n && i + half < strideNum) {
            CudaHeapNode<T> * heapLocalData = heapData + k * (heapOffset + i);
            CudaXHeap<MIN_HEAP, T> heapLocal(k, k, heapLocalData);
            CudaHeapNode<T> * hd = heapLocalData + k * half;
            for (int q = 0; q < k; q++) {
                if (hd[q].value > heapLocal.topValue)
                    heapLocal.ReplaceTop(hd[q]);
            }
        }

        /* the heaps written in this round are read in the next one */
        __syncthreads();
        n = half;
    }

    if (active && threadIdx.x == 0) {
        int blockIndex = j / stride;
        int offsetInBlock = j % stride;

        CudaXHeap<MIN_HEAP, T> heapFinal(k, k, heapData + k * heapOffset);
        int offset = stride * k * blockIndex + offsetInBlock;
        T * dOutput = output + offset;
        int * indexOutput = index + offset;

        /* pop for the final result: the min-heap yields the smallest item
           first, so the output is filled from the last position backwards */
        for (int q = k - 1; q >= 0; q--) {
            dOutput[stride * q] = heapFinal.items[0].value;
            indexOutput[stride * q] = heapFinal.items[0].index;
            heapFinal.items[0] = heapFinal.items[heapFinal.count - 1];
            heapFinal.count--;
            heapFinal.Down(0);
        }
    }
}

/*
get the top-k items along a given dimension
>> a - input tensor
>> b - output tensor (top-k result)
>> index - index of the top-k items (required: it is always written, so it must not be NULL)
>> dim - the dimension along which the sorting is performed
>> k - how many items returned after sorting
*/
void CudaTopK(XTensor * a, XTensor * b, XTensor * index, int dim, int k)
{
    CheckNTErrors((a->unitSize == b->unitSize), "Unmatched input tensors!");
    CheckNTErrors((a->order == b->order), "Unmatched input tensors!");

    /* both the kernel and the sorting fallback write to the index tensor,
       so a missing one is reported here instead of being dereferenced */
    CheckNTErrors((index != NULL), "The index tensor must be provided!");
    CheckNTErrors((a->order == index->order), "Unmatched input tensors!");
    CheckNTErrors((index->dataType == X_INT), "Wrong data type!");

    /* dim is used to address the dimension arrays below */
    CheckNTErrors((dim >= 0 && dim < a->order), "Illegal dimension!");
    CheckNTErrors((b->dimSize[dim] == k), "A too large K");

    int dimRDI = a->order - dim - 1;
    int stride = 1;
    int strideNumA = a->dimSizeRDI[dimRDI];
    for (int i = 0; i < dimRDI; i++)
        stride *= a->dimSizeRDI[i];

    int blockNum = 1;
    for (int i = dimRDI + 1; i < a->order; i++)
        blockNum *= a->dimSizeRDI[i];

    int workerNum = blockNum < 16 ? 64 : 32; // should be tuned for better performance

    int cudaGrids[3];
    int cudaBlocks[3];

    GDevs.GetCudaThread2D(a->mem->devID,
        workerNum, stride * blockNum, MAX_INT,
        cudaGrids, cudaBlocks);

    /* bytes taken by one heap node (a value plus an int index) */
    size_t nodeSize = a->unitSize + sizeof(int);

    /* shrink the thread block (at most two halvings per dimension) until
       one heap per thread, plus one spare heap, fits into the shared memory */
    for (int i = 0; i < 2; i++) {
        if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * nodeSize >= SHARED_MEMORY_SIZE) {
            if (cudaBlocks[1] >= 2 && cudaBlocks[1] % 2 == 0) {
                cudaBlocks[1] /= 2;
                cudaGrids[1] *= 2;
            }
        }

        if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * nodeSize >= SHARED_MEMORY_SIZE) {
            if (cudaBlocks[0] >= 2 && cudaBlocks[0] % 2 == 0) {
                cudaBlocks[0] /= 2;
                cudaGrids[0] *= 2;
            }
        }
    }

    int devIDBackup = 0;
    ProtectCudaDev(a->devID, devIDBackup);

    /* we run the kernel if the heaps can fit into the shared memory */
    if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * nodeSize < SHARED_MEMORY_SIZE) {
        if (a->dataType == DEFAULT_DTYPE) {
            KernelTopK2<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
                                 ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
                                 (DTYPE*)b->data, (int*)index->data);
        }
        else {
            ShowNTErrors("TODO!");
        }

    }
    /* we resort to sorting if the data cannot fit inside the shared memory */
    else {
        int dimSize[MAX_TENSOR_DIM_NUM];
        memcpy(dimSize, a->dimSize, sizeof(int) * a->order);

        /* NOTE(review): the negated leading dimension presumably tells the
           XTensor constructor not to allocate the data itself - confirm */
        dimSize[0] = -dimSize[0];
        XTensor * indexA = new XTensor(a->order, dimSize, X_INT, 1.0F, a->mem);
        indexA->data = a->mem->AllocBuf(a->devID, a->unitNum * sizeof(int));

        /* make the index tensor */
        indexA->SetAscendingOrder(dim);

        CudaSortBig(a, b, indexA, index, dim, k);

        /* give the buffer back before the temporary tensor is destroyed */
        a->mem->ReleaseBuf(a->devID, a->unitNum * sizeof(int));
        delete indexA;
    }

    BacktoCudaDev(a->devID, devIDBackup);
}

#endif // USE_CUDA

} // namespace nts(NiuTrans.Tensor)