/* Unary.cpp - element-wise unary math functions (Absolute, Exp, Log, Sin, Cos, Tan) on XTensor */
#include <math.h>
#include "../../XName.h"
#include "Unary.h"
#include "Unary.cuh"

namespace nts{
    

#ifdef USE_CUDA
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc)          \
void _funcName(const XTensor * a, XTensor * b)                              \
{                                                                           \
    /* run it on GPUs */                                                    \
    if (a->devID >= 0) {                                                    \
        _cudaFuncName(a, b);                                                \
    return;                                                                 \
    }                                                                       \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");                 \
    DTYPE * d = (DTYPE*)a->data;                                            \
    DTYPE * db = (DTYPE*)b->data;                                           \
    for (int i = 0; i < a->unitNum; i++)                                    \
        db[i] = (DTYPE)origFunc(d[i]);                                      \
}

#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName)                   \
void _funcNameMe(XTensor * a)                                               \
{                                                                           \
    _funcName(a, a);                                                        \
}        

#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId)             \
XTensor funcName(const XTensor &a)                                          \
{                                                                           \
    XTensor b(&a);                                                          \
    b.SetTMP();                                                             \
    _funcName(&a, &b);                                                      \
    XLink::MakeLink(&a, NULL, &b, operationId);                             \
    return b;                                                               \
}

_SIMPLE_UNARY_FUNCTION(_Absolute, _CudaAbsolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)

_SIMPLE_UNARY_FUNCTION(_Exp, _CudaExp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)

_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)

_SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)

_SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)

_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
67

xiaotong committed
68
/*_SIMPLE_UNARY_FUNCTION(_Round, _CudaRound, round)
69
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
xiaotong committed
70
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)*/
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
#else
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, origFunc)          \
void _funcName(const XTensor * a, XTensor * b)                              \
{                                                                           \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");                 \
    DTYPE * d = (DTYPE*)a->data;                                            \
    DTYPE * db = (DTYPE*)b->data;                                           \
    for (int i = 0; i < a->unitNum; i++)                                    \
        db[i] = (DTYPE)origFunc(d[i]);                                      \
}

#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName)                   \
void _funcNameMe(XTensor * a)                                               \
{                                                                           \
    _funcName(a, a);                                                        \
}        

#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId)             \
XTensor funcName(const XTensor &a)                                          \
{                                                                           \
    XTensor b(&a);                                                          \
    b.SetTMP();                                                             \
    _funcName(&a, &b);                                                      \
    XLink::MakeLink(&a, NULL, &b, operationId);                             \
    return b;                                                               \
}

_SIMPLE_UNARY_FUNCTION(_Absolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)

_SIMPLE_UNARY_FUNCTION(_Exp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)

_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)

_SIMPLE_UNARY_FUNCTION(_Sin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)

_SIMPLE_UNARY_FUNCTION(_Cos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)

_SIMPLE_UNARY_FUNCTION(_Tan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
124

xuchen committed
125
/*_SIMPLE_UNARY_FUNCTION(_Round, round)
126
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
xuchen committed
127
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)*/
128 129 130
#endif

}