/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


/*
 * $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
 */

#include <math.h>
#include <cuda_fp16.h>
#include "../../XDevice.h"
#include "../../XName.h"
#include "Unary.h"
#include "Unary.cuh"

namespace nts {

/* device-side helper: multiply a value by itself, i.e., compute x^2 */
__device__
DTYPE cudasquare(DTYPE x)
{
    DTYPE squared = x * x;

    return squared;
}

/*
device-side helper: round a value to the nearest integer, with halfway
cases rounded away from zero (e.g., 2.5 -> 3 and -2.5 -> -3)
r      - the value to round
return - the rounded value, converted back to DTYPE
*/
__device__
DTYPE cudaround(DTYPE r)
{
    /* Use the math library routine instead of the hand-written
       "floor(r + 0.5) / ceil(r - 0.5)" form: the latter sends +0 through
       the ceil branch and yields -0, and adding 0.5 before truncating can
       misround inputs just below 0.5 when no wider type is available. */
    return (DTYPE)round(r);
}

/* device-side helper: indicator of a non-zero value,
   i.e., 0 when r compares equal to zero and 1 otherwise */
__device__
DTYPE cudaisnonzero(DTYPE r)
{
    if (r == 0.0)
        return (DTYPE)0.0;

    return (DTYPE)1.0;
}

/* device-side helper: indicator of a zero value,
   i.e., 1 when r compares equal to zero and 0 otherwise */
__device__
DTYPE cudaiszero(DTYPE r)
{
    if (r != 0.0)
        return (DTYPE)0.0;

    return (DTYPE)1.0;
}


/*
generate the GPU implementation of an element-wise unary function
funcName - suffix of the generated names: the kernels are called
           Kernel<funcName> and the host entry point _Cuda<funcName>
origFunc - a device-callable scalar function applied to every element

The macro expands to three definitions:
1) a kernel over DTYPE data that computes b[i] = origFunc(a[i]) for every
   index i < size (threads with i >= size do nothing);
2) a kernel over __half data that does the same, evaluating origFunc in
   single precision: each element is converted to float, passed to
   origFunc, and the result is converted back to __half;
3) the host function _Cuda<funcName>(a, b) that checks the two tensors,
   asks GDevs for a launch configuration for a->unitNum elements on device
   a->devID, and launches the kernel that matches a->dataType. Any data
   type other than DEFAULT_DTYPE and X_FLOAT16 is reported as an error.

Every thread reads a[i] before it writes b[i] and touches no other element,
so the kernels remain correct when a and b refer to the same buffer.

NOTE: only block comments may be used inside the macro body; a line comment
would swallow the lines that are spliced after it by the trailing backslash.
*/
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc)                       \
__global__                                                                  \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size)                       \
{                                                                           \
    int i = blockDim.x * blockIdx.x + threadIdx.x;                          \
                                                                            \
    if (i < size)                                                           \
        b[i] = (DTYPE)origFunc(a[i]);                                       \
}                                                                           \
__global__                                                                  \
void Kernel##funcName(__half * a, __half * b, int size)                     \
{                                                                           \
    int i = blockDim.x * blockIdx.x + threadIdx.x;                          \
                                                                            \
    /* half values are widened to float for the computation */              \
    if (i < size)                                                           \
        b[i] = __float2half((float)origFunc(__half2float(a[i])));           \
}                                                                           \
void _Cuda##funcName(const XTensor * a, XTensor * b)                        \
{                                                                           \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->isSparse == false), "TODO!");                         \
                                                                            \
    int gridSize[3];                                                        \
    int blockSize[3];                                                       \
                                                                            \
    GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);         \
                                                                            \
    /* 1D launch: only the x dimension is used */                           \
    dim3 blocks(gridSize[0]);                                               \
    dim3 threads(blockSize[0]);                                             \
                                                                            \
    int devIDBackup;                                                        \
    ProtectCudaDev(a->devID, devIDBackup);                                  \
                                                                            \
    if (a->dataType == DEFAULT_DTYPE) {                                     \
        Kernel##funcName<<<blocks, threads>>>                               \
                         ((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);    \
    }                                                                       \
    else if (a->dataType == X_FLOAT16) {                                    \
        Kernel##funcName<<<blocks, threads>>>                               \
                         ((__half*)a->data, (__half*)b->data, a->unitNum);  \
    }                                                                       \
    else {                                                                  \
        ShowNTErrors("TODO!");                                              \
    }                                                                       \
                                                                            \
    BacktoCudaDev(a->devID, devIDBackup);                                   \
}

/* instantiate the kernels and the _Cuda<Name> entry point
   for every supported unary function */

/* absolute value and rounding */
SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Ceil, ceil)
SIMPLE_UNARY_FUNCTION_GPU(Floor, floor)
SIMPLE_UNARY_FUNCTION_GPU(Round, cudaround)

/* exponential, logarithm and powers */
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Sqrt, sqrt)
SIMPLE_UNARY_FUNCTION_GPU(Square, cudasquare)

/* zero / non-zero indicators */
SIMPLE_UNARY_FUNCTION_GPU(IsNonZero, cudaisnonzero)
SIMPLE_UNARY_FUNCTION_GPU(IsZero, cudaiszero)

/* trigonometric functions */
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, tan)

}