#include <math.h>
#include "../../XDevice.h"
#include "../../XName.h"
#include "Unary.cuh"

namespace nts {

/*
Generator for element-wise unary functions on the GPU.

SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) defines three things:
  1. Kernel<funcName>(DTYPE * a, DTYPE * b, int size)
       b[i] = origFunc(a[i]) for every i < size, one element per thread.
  2. Kernel<funcName>(__half * a, __half * b, int size)
       the same for half-precision data; each element is widened to float,
       passed through origFunc and rounded back to half.
  3. _Cuda<funcName>(const XTensor * a, XTensor * b)
       host wrapper: validates the tensors, asks GDevs for a 1D launch
       configuration covering a->unitNum elements, switches to device
       a->devID, launches the kernel that matches a->dataType and then
       switches back. Any data type other than DEFAULT_DTYPE and X_FLOAT16
       is reported through ShowNTErrors.

Parameters:
  funcName - suffix used to build the kernel and wrapper names
  origFunc - unary math function callable from device code (e.g. exp)

Notes:
  - Only block comments may be used inside the macro body: a "//" comment
    would swallow the backslash-continued lines that follow it.
  - NOTE(review): the kernel launch status is not checked here
    (no cudaGetLastError); confirm whether the callers rely on a later check.
*/
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc)                       \
__global__                                                                  \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size)                       \
{                                                                           \
    int i = blockDim.x * blockIdx.x + threadIdx.x;                          \
                                                                            \
    /* the grid may be larger than the data, so guard the tail */           \
    if (i < size)                                                           \
        b[i] = (DTYPE)origFunc(a[i]);                                       \
}                                                                           \
__global__                                                                  \
void Kernel##funcName(__half * a, __half * b, int size)                     \
{                                                                           \
    int i = blockDim.x * blockIdx.x + threadIdx.x;                          \
                                                                            \
    /* evaluate in float so that no half arithmetic is needed */            \
    if (i < size)                                                           \
        b[i] = __float2half((float)origFunc(__half2float(a[i])));           \
}                                                                           \
void _Cuda##funcName(const XTensor * a, XTensor * b)                        \
{                                                                           \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->isSparse == false), "TODO!");                         \
                                                                            \
    int gridSize[3];                                                        \
    int blockSize[3];                                                       \
                                                                            \
    /* 1D launch configuration for a->unitNum elements */                   \
    GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);         \
                                                                            \
    dim3 blocks(gridSize[0]);                                               \
    dim3 threads(blockSize[0]);                                             \
                                                                            \
    /* run on the device that holds a, restore the previous one afterwards */ \
    int devIDBackup;                                                        \
    ProtectCudaDev(a->devID, devIDBackup);                                  \
                                                                            \
    if (a->dataType == DEFAULT_DTYPE) {                                     \
        Kernel##funcName<<<blocks, threads>>>                               \
                     ((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);        \
    }                                                                       \
    else if (a->dataType == X_FLOAT16) {                                    \
        Kernel##funcName<<<blocks, threads>>>                               \
                     ((__half*)a->data, (__half*)b->data, a->unitNum);      \
    }                                                                       \
    else {                                                                  \
        ShowNTErrors("TODO!");                                              \
    }                                                                       \
                                                                            \
    BacktoCudaDev(a->devID, devIDBackup);                                   \
}

// Instantiate the GPU kernels and host wrappers for each simple unary function.
// Every line below expands to Kernel<Name> (DTYPE and __half overloads) and
// _Cuda<Name>(const XTensor * a, XTensor * b); the second macro argument is the
// math function applied to each element.
SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, tan)
// Round is currently disabled (the reason is not visible in this file).
//SIMPLE_UNARY_FUNCTION_GPU(Round, round)

}