/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../XDevice.h"
#include "../XTensor.h"
#include "Power.h"
#include "Power.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA

/*
replace every entry with its square root, in place (CUDA Kernel)
each thread handles at most one entry, addressed by its flat index along x;
threads whose index falls at or beyond "size" do nothing
>> d - data array (read and overwritten)
>> size - number of entries in the data array
*/
__global__
void KernelSqrtV2(DTYPE * d, int size)
{
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* the launch may cover more threads than entries */
    if (idx >= size)
        return;

    DTYPE * entry = d + idx;
    *entry = sqrt(*entry);
}

/*
replace every entry with its square root, in place (CUDA Kernel, half precision)
each thread handles at most one entry, addressed by its flat index along x;
threads whose index falls at or beyond "size" do nothing
>> d - data array of __half values (read and overwritten)
>> size - number of entries in the data array
*/
__global__
void KernelSqrtV2(__half * d, int size)
{
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* the launch may cover more threads than entries */
    if (idx >= size)
        return;

#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
    /* half-precision square root called directly on the __half value */
    d[idx] = hsqrt(d[idx]);
#else
    /* other targets: convert to float, take the root, convert back */
    d[idx] = __float2half(sqrt(__half2float(d[idx])));
#endif
}


/*
replace every entry d[i] with pow(d[i], p), in place (CUDA Kernel)
each thread handles at most one entry, addressed by its flat index along x;
threads whose index falls at or beyond "size" do nothing
>> d - data array (read and overwritten)
>> p - the exponent applied to every entry
>> size - number of entries in the data array
*/
__global__
void KernelPower(DTYPE * d, DTYPE p, int size)
{
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* the launch may cover more threads than entries */
    if (idx >= size)
        return;

    DTYPE * entry = d + idx;
    *entry = pow(*entry, p);
}

/*
replace every entry d[i] with pow(d[i], p), in place (CUDA Kernel, half precision)
each thread handles at most one entry, addressed by its flat index along x;
threads whose index falls at or beyond "size" do nothing

The power is evaluated in single precision on every architecture: both
operands are widened to float, powf is applied, and the result is narrowed
back to __half. The previous version compiled to an empty body when
__CUDA_ARCH__ >= 530 (its only statements were commented out because they
relied on a half-precision "hpow" routine), so the data was left untouched
on those devices.
>> d - data array of __half values (read and overwritten)
>> p - the exponent applied to every entry
>> size - number of entries in the data array
*/
__global__
void KernelPower(__half * d, __half p, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* the launch may cover more threads than entries */
    if (i < size)
        d[i] = __float2half(powf(__half2float(d[i]), __half2float(p)));
}

/*
raise every entry of a tensor to the power p, in place (host-side launcher)
dispatch rules, as implemented below:
  - p == 0.5 : the square-root kernel is launched
  - p == 1.0 : nothing is launched
  - other p  : the power kernel is launched for DEFAULT_DTYPE tensors;
               for X_FLOAT16 tensors ShowNTErrors("TODO!") is raised
  - any other data type raises ShowNTErrors("TODO!")
>> a - the tensor to process; its devID, unitNum, dataType and data
       fields are read, and the memory behind data is overwritten
>> p - the exponent
*/
extern "C"
void CudaPower(XTensor * a, DTYPE p)
{
    /* ask the device manager for a launch configuration;
       only the first grid/block dimension is used */
    int gridSize[3];
    int blockSize[3];
    GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);

    dim3 blocks(gridSize[0]);
    dim3 threads(blockSize[0]);

    /* classify the exponent once; both data-type branches need it */
    const bool isSqrt = (p == (DTYPE)0.5);
    const bool isIdentity = (p == (DTYPE)1.0);

    int devIDBackup;
    ProtectCudaDev(a->devID, devIDBackup);

    if (a->dataType == DEFAULT_DTYPE) {
        DTYPE * entries = (DTYPE*)a->data;

        if (isSqrt) {
            KernelSqrtV2<<<blocks, threads>>>(entries, a->unitNum);
        }
        else if (!isIdentity) {
            KernelPower<<<blocks, threads>>>(entries, p, a->unitNum);
        }
    }
    else if (a->dataType == X_FLOAT16) {
        if (isSqrt) {
            KernelSqrtV2<<<blocks, threads>>>((__half*)a->data, a->unitNum);
        }
        else if (!isIdentity) {
            /* a general exponent is not wired up for half precision yet */
            ShowNTErrors("TODO!");
        }
    }
    else {
        /* no kernel is launched for any other data type */
        ShowNTErrors("TODO!");
    }

    BacktoCudaDev(a->devID, devIDBackup);
}

#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)