/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Div.h"
#include "Div.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA
/*
division of data arrays in an element-wise manner c(i) = a(i)/b(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
*/
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < size)
        c[i] = a[i] / b[i];
}

/*
division of data arrays in an element-wise manner c(i) = a(i)/b(i) + \alpha*c(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
>> alpha - the coefficient
*/
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < size)
        c[i] = a[i] / b[i] + alpha * c[i];
}

/*
division of two tensors in an element-wise manner c(i) = a(i)/b(i).
Note that a and b can be of different sizes here, i.e.,
|a_lead| <= |c_lead| and |b_lead| <= |c_lead|
where |a_lead| means the size of the leading dimension of a
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> stride - the number of items we skip over when we move to the next position along the leading dimension in a block
>> ldSizeA - size of the leading dimension of a
>> ldSizeB - size of the leading dimension of b
>> ldSizeC - size of the leading dimension of c
>> blockNum - number of blocks
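E.g., if a is of size (2, 4), b is of size (2, 1), c is of size (2, 4) and the
last dimension is the leading dimension, then stride = 1, ldSizeA = ldSizeC = 4,
ldSizeB = 1 and blockNum = 2, so b is broadcast along the leading dimension.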
*/
template<int nonZeroAlpha> __global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha,
    int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum)
{
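    /* base pointers of the data block that each thread column (threadIdx.x) works on,
       cached in shared memory and shared by all threads in that column */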
    __shared__ DTYPE* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ DTYPE* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ DTYPE* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];

    int i = blockDim.x * blockIdx.x + threadIdx.x;
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    if (i >= blockNum * stride || j >= ldSizeC)
        return;

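    /* the first row of threads (threadIdx.y == 0) locates the data block of
       column threadIdx.x and caches its base pointers in a, b and c */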
    if (threadIdx.y == 0) {
        int block = i / stride;
        int size = block * stride;
        ap[threadIdx.x] = a + size * ldSizeA;
        bp[threadIdx.x] = b + size * ldSizeB;
        cp[threadIdx.x] = c + size * ldSizeC;
    }

    __syncthreads();

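    /* broadcasting: if a (or b) has a smaller leading dimension than c, its
       index along the leading dimension wraps around via modulo */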
    int aj = j >= ldSizeA ? j % ldSizeA : j;
    int bj = j >= ldSizeB ? j % ldSizeB : j;
    int offseti = i % stride;

    if (nonZeroAlpha == 0)
        cp[threadIdx.x][j * stride + offseti] = ap[threadIdx.x][aj * stride + offseti]
                                                / bp[threadIdx.x][bj * stride + offseti];
    else
        cp[threadIdx.x][j * stride + offseti] = ap[threadIdx.x][aj * stride + offseti]
                                                / bp[threadIdx.x][bj * stride + offseti]
                                                + alpha * cp[threadIdx.x][j * stride + offseti];
}

/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the item index
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - dimension along which we perform broadcasting
*/
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
    int leadingDimRDI = a->order - leadingDim - 1;
    CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
                  "Unmatched tensors in division!");
    CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");

    int stride = 1;
    int blockSizeA = 1;
    int blockNum = 1;
    int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
    int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
    int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];

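    /* stride is the number of elements between two adjacent positions along the
       leading dimension; all other dimensions must match across a, b and c */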
    for (int i = 0; i < a->order; i++) {
        if (i != leadingDimRDI) {
            CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] &&
                           a->dimSizeRDI[i] == c->dimSizeRDI[i]),
                          "Unmatched tensors!");
        }
        if (i < leadingDimRDI)
            stride *= a->dimSizeRDI[i];
    }

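    /* a block covers the leading dimension and all faster-changing dimensions */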
    blockSizeA = stride * dimensionSizeA;
    blockNum = a->unitNum / blockSizeA;

    int devIDBackup;
    ProtectCudaDev(a->devID, devIDBackup);

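    /* only dense tensors of the default data type are handled here */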
    if (!a->isSparse && !b->isSparse) {
        if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
            int cudaGridSize[3];
            int cudaBlockSize[3];

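            /* if a, b and c have the same number of elements, simple 1D element-wise
               kernels are enough; otherwise the 2D kernel broadcasts a and b along
               the leading dimension */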
            if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
                GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
                dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);

                if (alpha == 0)
                    KernelDivElementWise<<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
                else
                    KernelDivElementWiseV2<<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
            }
            else {
                GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
                dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);

                if (alpha == 0) {
                    KernelDivElementWiseTensorDynamic<0><<<blocks, threads>>>
                        ((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
                        stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
                }
                else {
                    KernelDivElementWiseTensorDynamic<1><<<blocks, threads>>>
                        ((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
                        stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
                }
            }
        }
        else {
            // TODO!!
            ShowNTErrors("TODO!");
        }
    }
    else {
        // TODO!!
        ShowNTErrors("TODO!");
    }

    BacktoCudaDev(a->devID, devIDBackup);
}

#endif // USE_CUDA

} // namespace nts(NiuTrans.Tensor)