#include <math.h>
#include "../../XName.h"
#include "Unary.h"
#include "Unary.cuh"

namespace nts{
    

#ifdef USE_CUDA
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc)          \
void _funcName(const XTensor * a, XTensor * b)                              \
{                                                                           \
    /* run it on GPUs */                                                    \
    if (a->devID >= 0) {                                                    \
        _cudaFuncName(a, b);                                                \
    return;                                                                 \
    }                                                                       \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");                 \
    DTYPE * d = (DTYPE*)a->data;                                            \
    DTYPE * db = (DTYPE*)b->data;                                           \
    for (int i = 0; i < a->unitNum; i++)                                    \
        db[i] = (DTYPE)origFunc(d[i]);                                      \
}

#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName)                   \
void _funcNameMe(XTensor * a)                                               \
{                                                                           \
    _funcName(a, a);                                                        \
}        

#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId)             \
XTensor funcName(const XTensor &a)                                          \
{                                                                           \
    XTensor b(&a);                                                          \
    b.SetTMP();                                                             \
    _funcName(&a, &b);                                                      \
    XLink::MakeLink(&a, NULL, &b, operationId);                             \
    return b;                                                               \
}

_SIMPLE_UNARY_FUNCTION(_Absolute, _CudaAbsolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)

_SIMPLE_UNARY_FUNCTION(_Exp, _CudaExp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)

_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)

_SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)

_SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)

_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)

/*_SIMPLE_UNARY_FUNCTION(_Round, _CudaRound, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)*/
#else
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, origFunc)          \
void _funcName(const XTensor * a, XTensor * b)                              \
{                                                                           \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");                 \
    DTYPE * d = (DTYPE*)a->data;                                            \
    DTYPE * db = (DTYPE*)b->data;                                           \
    for (int i = 0; i < a->unitNum; i++)                                    \
        db[i] = (DTYPE)origFunc(d[i]);                                      \
}

#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName)                   \
void _funcNameMe(XTensor * a)                                               \
{                                                                           \
    _funcName(a, a);                                                        \
}        

#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId)             \
XTensor funcName(const XTensor &a)                                          \
{                                                                           \
    XTensor b(&a);                                                          \
    b.SetTMP();                                                             \
    _funcName(&a, &b);                                                      \
    XLink::MakeLink(&a, NULL, &b, operationId);                             \
    return b;                                                               \
}

_SIMPLE_UNARY_FUNCTION(_Absolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)

_SIMPLE_UNARY_FUNCTION(_Exp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)

_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)

_SIMPLE_UNARY_FUNCTION(_Sin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)

_SIMPLE_UNARY_FUNCTION(_Cos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)

_SIMPLE_UNARY_FUNCTION(_Tan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)

/*_SIMPLE_UNARY_FUNCTION(_Round, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)*/
#endif

}