/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-04-05
*/

#include <math.h>
#include "../../XName.h"
#include "Binary.h"
#include "Binary.cuh"

namespace nts {

/*
scale an integer value
>> x - the input value
>> factor - the multiplier
<< return - x * factor
*/
int scale(int x, int factor)
{
    int result = x;
    result *= factor;
    return result;
}

/*
scale a floating-point value
>> x - the input value
>> factor - the multiplier
<< return - x * factor
*/
float scale(float x, float factor)
{
    float result = x;
    result *= factor;
    return result;
}

/*
divide an integer value (C++ integer division; truncates toward zero since C++11)
>> x - the dividend
>> divisor - the divisor; must be non-zero (integer division by zero is undefined)
<< return - x / divisor
*/
int descale(int x, int divisor)
{
    int quotient = x;
    quotient /= divisor;
    return quotient;
}

/*
divide a floating-point value
>> x - the dividend
>> divisor - the divisor (on IEEE-754 platforms a zero divisor yields inf/NaN)
<< return - x / divisor
*/
float descale(float x, float divisor)
{
    float quotient = x;
    quotient /= divisor;
    return quotient;
}

/*
shift an integer value by a constant offset
>> x - the input value
>> offset - the amount added to x
<< return - x + offset
*/
int shift(int x, int offset)
{
    return offset + x;
}

/*
shift a floating-point value by a constant offset
>> x - the input value
>> offset - the amount added to x
<< return - x + offset
*/
float shift(float x, float offset)
{
    return offset + x;
}

/*
remainder of integer division
>> x - the dividend
>> m - the divisor; must be non-zero (x % 0 is undefined)
<< return - x % m; a non-zero result takes the sign of x (e.g. -7 % 3 == -1)
*/
int mod(int x, int m)
{
    const int remainder = x % m;
    return remainder;
}

#ifdef USE_CUDA
/* macros that generate the element-wise functions and their wrappers; each takes the names of the functions it defines (GPU mode) */
/*
generate "void _funcName(const XTensor * a, XTensor * b, int num)", which sets
b[i] = origFunc(a[i], num) for every entry of an int tensor
>> _funcName - name of the generated function
>> _cudaFuncName - called instead when a is on a GPU (a->devID >= 0)
>> origFunc - the scalar function applied to each entry
NOTE(review): the shape/type checks below run only on the CPU path;
presumably _cudaFuncName validates its own inputs - confirm in the CUDA code
*/
#define _SIMPLE_BINARY_FUNCTION_INT(_funcName, _cudaFuncName, origFunc)     \
void _funcName(const XTensor * a, XTensor * b, int num)                     \
{                                                                           \
    if (a->devID >= 0) {                                                    \
        /* device tensor: hand the whole job to the CUDA implementation */  \
        _cudaFuncName(a, b, num);                                           \
    }                                                                       \
    else {                                                                  \
        CheckNTErrors((XTensor::IsSameShaped(a, b)),                        \
                    "Input tensors should have the same data type!");       \
        CheckNTErrors((a->dataType == X_INT&&b->dataType == X_INT), "TODO!");\
        const int * src = (const int*)a->data;                              \
        int * dst = (int*)b->data;                                          \
        for (int k = 0; k < a->unitNum; ++k)                                \
            dst[k] = (int)origFunc(src[k], num);                            \
    }                                                                       \
}

/*
generate "void _funcName(const XTensor * a, XTensor * b, float num)", which sets
b[i] = origFunc(a[i], num) for every entry of a float tensor
>> _funcName - name of the generated function
>> _cudaFuncName - called instead when a is on a GPU (a->devID >= 0)
>> origFunc - the scalar function applied to each entry
NOTE(review): the shape/type checks below run only on the CPU path;
presumably _cudaFuncName validates its own inputs - confirm in the CUDA code
*/
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc)         \
void _funcName(const XTensor * a, XTensor * b, float num)                   \
{                                                                           \
    if (a->devID >= 0) {                                                    \
        /* device tensor: hand the whole job to the CUDA implementation */  \
        _cudaFuncName(a, b, num);                                           \
    }                                                                       \
    else {                                                                  \
        CheckNTErrors((XTensor::IsSameShaped(a, b)),                        \
                    "Input tensors should have the same data type!");       \
        CheckNTErrors((a->dataType == X_FLOAT&&b->dataType == X_FLOAT), "TODO!");\
        const float * src = (const float*)a->data;                          \
        float * dst = (float*)b->data;                                      \
        for (int k = 0; k < a->unitNum; ++k)                                \
            dst[k] = (float)origFunc(src[k], num);                          \
    }                                                                       \
}

/*
generate the on-site int wrapper "void funcName(XTensor &a, int num)",
which applies _funcName with a as both input and output
*/
#define SIMPLE_BINARY_FUNCTION_ME_INT(funcName, _funcName)                  \
void funcName(XTensor &a, int num)                                          \
{                                                                           \
    XTensor * self = &a;                                                    \
    _funcName(self, self, num);                                             \
}

/*
generate the on-site float wrapper "void funcName(XTensor &a, float num)",
which applies _funcName with a as both input and output
*/
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName)                      \
void funcName(XTensor &a, float num)                                        \
{                                                                           \
    XTensor * self = &a;                                                    \
    _funcName(self, self, num);                                             \
}
    
/*
generate the reference-based int wrapper
"void funcName(const XTensor &a, XTensor &b, int num)", which forwards to
_funcName(&a, &b, num)
*/
#define SIMPLE_BINARY_FUNCTION_INT(funcName, _funcName)                     \
void funcName(const XTensor &a, XTensor &b, int num)                        \
{                                                                           \
    const XTensor * input = &a;                                             \
    XTensor * output = &b;                                                  \
    _funcName(input, output, num);                                          \
}

/*
generate "XTensor funcName(const XTensor &a, float num)": it builds a new tensor
from &a, calls SetTMPFlag() on it, fills it with _funcName and links it to a
through XLink::MakeLink
>> funcName - name of the generated function
>> _funcName - the float kernel to call
>> operationId - operation id passed to XLink::MakeLink
*/
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName, operationId)            \
XTensor funcName(const XTensor &a, float num)                               \
{                                                                           \
    XTensor result(&a);                                                     \
    result.SetTMPFlag();                                                    \
    _funcName(&a, &result, num);                                            \
    XLink::MakeLink(&a, NULL, &result, operationId);                        \
    return result;                                                          \
}

/*
generate "void funcName(const XTensor &a, XTensor &b, float num, bool requireLink)":
b is (re)initialized from a unless it is already initialized with a's shape,
then filled with _funcName; XLink::MakeLink is called only when requireLink is set
>> funcName - name of the generated function
>> _funcName - the float kernel to call
>> operationId - operation id passed to XLink::MakeLink
*/
#define SIMPLE_BINARY_FUNCTION_VOID(funcName, _funcName, operationId)       \
void funcName(const XTensor &a, XTensor &b, float num, bool requireLink)    \
{                                                                           \
    /* IsSameShaped is only evaluated when b is already initialized */      \
    bool reusable = b.isInit && XTensor::IsSameShaped(&a, &b);              \
    if (!reusable)                                                          \
        InitTensor(&b, &a);                                                 \
    _funcName(&a, &b, num);                                                 \
    if (requireLink)                                                        \
        XLink::MakeLink(&a, NULL, &b, operationId);                         \
}

/*
instantiate the functions for each operation. The int kernel and its wrappers
come first, then the float kernel and its wrappers:
  _Op(const XTensor *, XTensor *, int/float)      - kernel (CPU loop or CUDA)
  _OpMe(XTensor &, int/float)                     - on-site version
  Op(const XTensor &, XTensor &, int)             - reference-based int version
  XTensor Op(const XTensor &, float)              - returns a new linked tensor
  Op(const XTensor &, XTensor &, float, bool)     - optional-link version
*/

/* scale: b = a * num */
_SIMPLE_BINARY_FUNCTION_INT(_Scale, _CudaScale, scale)
SIMPLE_BINARY_FUNCTION_ME_INT(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION_INT(Scale, _Scale)

_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScaleFloat, scale)
SIMPLE_BINARY_FUNCTION_ME(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale, MATH_SCALE)
SIMPLE_BINARY_FUNCTION_VOID(Scale, _Scale, MATH_SCALE)

/* descale: b = a / num */
_SIMPLE_BINARY_FUNCTION_INT(_Descale, _CudaDescale, descale)
SIMPLE_BINARY_FUNCTION_ME_INT(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION_INT(Descale, _Descale)

_SIMPLE_BINARY_FUNCTION(_Descale, _CudaDescaleFloat, descale)
SIMPLE_BINARY_FUNCTION_ME(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale, MATH_DESCALE)
SIMPLE_BINARY_FUNCTION_VOID(Descale, _Descale, MATH_DESCALE)

/* shift: b = a + num */
_SIMPLE_BINARY_FUNCTION_INT(_Shift, _CudaShift, shift)
SIMPLE_BINARY_FUNCTION_ME_INT(_ShiftMe, _Shift)
SIMPLE_BINARY_FUNCTION_INT(Shift, _Shift)

_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShiftFloat, shift)
SIMPLE_BINARY_FUNCTION_ME(_ShiftMe, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift, MATH_SHIFT)
SIMPLE_BINARY_FUNCTION_VOID(Shift, _Shift, MATH_SHIFT)

/* mod: b = a % num (int only; no float variant is generated) */
_SIMPLE_BINARY_FUNCTION_INT(_Mod, _CudaMod, mod)
SIMPLE_BINARY_FUNCTION_ME_INT(_ModMe, _Mod)
SIMPLE_BINARY_FUNCTION_INT(Mod, _Mod)

#else
/* macros that generate the element-wise functions and their wrappers; each takes the names of the functions it defines (CPU mode) */
/*
generate "void _funcName(const XTensor * a, XTensor * b, int num)" for builds
without CUDA: it rejects GPU tensors and sets b[i] = origFunc(a[i], num) for
every entry of an int tensor
>> _funcName - name of the generated function
>> origFunc - the scalar function applied to each entry
*/
#define _SIMPLE_BINARY_FUNCTION_INT(_funcName, origFunc)                    \
void _funcName(const XTensor * a, XTensor * b, int num)                     \
{                                                                           \
    CheckNTErrors(a->devID < 0, "No GPU code is supported");                \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                "Input tensors should have the same data type!");           \
    CheckNTErrors((a->dataType == X_INT&&b->dataType == X_INT), "TODO!");   \
    const int * src = (const int*)a->data;                                  \
    const int * srcEnd = src + a->unitNum;                                  \
    int * dst = (int*)b->data;                                              \
    /* walk both buffers in step; src and dst may be the same buffer */     \
    while (src != srcEnd)                                                   \
        *dst++ = (int)origFunc(*src++, num);                                \
}

/*
generate "void _funcName(const XTensor * a, XTensor * b, float num)" for builds
without CUDA: it rejects GPU tensors and sets b[i] = origFunc(a[i], num) for
every entry of a float tensor
>> _funcName - name of the generated function
>> origFunc - the scalar function applied to each entry
*/
#define _SIMPLE_BINARY_FUNCTION(_funcName, origFunc)                        \
void _funcName(const XTensor * a, XTensor * b, float num)                   \
{                                                                           \
    CheckNTErrors(a->devID < 0, "No GPU code is supported");                \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                "Input tensors should have the same data type!");           \
    CheckNTErrors((a->dataType == X_FLOAT&&b->dataType == X_FLOAT), "TODO!");\
    const float * src = (const float*)a->data;                              \
    const float * srcEnd = src + a->unitNum;                                \
    float * dst = (float*)b->data;                                          \
    /* walk both buffers in step; src and dst may be the same buffer */     \
    while (src != srcEnd)                                                   \
        *dst++ = (float)origFunc(*src++, num);                              \
}

/*
generate the on-site int wrapper "void funcName(XTensor &a, int num)",
which applies _funcName with a as both input and output
*/
#define SIMPLE_BINARY_FUNCTION_ME_INT(funcName, _funcName)                  \
void funcName(XTensor &a, int num)                                          \
{                                                                           \
    XTensor * self = &a;                                                    \
    _funcName(self, self, num);                                             \
}

/*
generate the on-site float wrapper "void funcName(XTensor &a, float num)",
which applies _funcName with a as both input and output
*/
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName)                      \
void funcName(XTensor &a, float num)                                        \
{                                                                           \
    XTensor * self = &a;                                                    \
    _funcName(self, self, num);                                             \
}

/*
generate the reference-based int wrapper
"void funcName(const XTensor &a, XTensor &b, int num)", which forwards to
_funcName(&a, &b, num)
*/
#define SIMPLE_BINARY_FUNCTION_INT(funcName, _funcName)                     \
void funcName(const XTensor &a, XTensor &b, int num)                        \
{                                                                           \
    const XTensor * input = &a;                                             \
    XTensor * output = &b;                                                  \
    _funcName(input, output, num);                                          \
}

#define SIMPLE_BINARY_FUNCTION(funcName, _funcName)                         \
void funcName(const XTensor &a, XTensor &b, float num)                      \
{                                                                           \
    _funcName(&a, &b, num);                                                 \
}                                                                           \

    
_SIMPLE_BINARY_FUNCTION_INT(_Scale, scale)
SIMPLE_BINARY_FUNCTION_ME_INT(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION_INT(Scale, _Scale)

_SIMPLE_BINARY_FUNCTION(_Scale, scale)
SIMPLE_BINARY_FUNCTION_ME(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale)
    
_SIMPLE_BINARY_FUNCTION_INT(_Descale, descale)
SIMPLE_BINARY_FUNCTION_ME_INT(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION_INT(Descale, _Descale)

_SIMPLE_BINARY_FUNCTION(_Descale, descale)
SIMPLE_BINARY_FUNCTION_ME(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale)
    
_SIMPLE_BINARY_FUNCTION_INT(_Shift, shift)
SIMPLE_BINARY_FUNCTION_ME_INT(_Shift, _Shift)
SIMPLE_BINARY_FUNCTION_INT(Shift, _Shift)

_SIMPLE_BINARY_FUNCTION(_Shift, shift)
SIMPLE_BINARY_FUNCTION_ME(_ShiftMe, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift)
    
_SIMPLE_BINARY_FUNCTION_INT(_Mod, mod)
SIMPLE_BINARY_FUNCTION_ME_INT(_ModMe, _Mod)
SIMPLE_BINARY_FUNCTION_INT(Mod, _Mod)

    
#endif

} // namespace nts(NiuTrans.Tensor)
