/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-04-05
*/

#include <math.h>
#include "../../XDevice.h"
#include "../../XName.h"
#include "Binary.h"
#include "Binary.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA

/*
multiply an integer by an integer factor (device function)
>> x - the input value
>> scale - the factor
<< return - x * scale
*/
__device__
int cudascale(int x, int scale)
{
    const int product = scale * x;
    return product;
}

/*
divide an integer by an integer divisor (device function).
This is plain integer division, so the quotient is truncated.
>> x - the input value
>> descale - the divisor
<< return - x / descale
*/
__device__
int cudadescale(int x, int descale)
{
    const int quotient = x / descale;
    return quotient;
}

/*
add an integer offset to an integer (device function)
>> x - the input value
>> shift - the offset
<< return - x + shift
*/
__device__
int cudashift(int x, int shift)
{
    const int shifted = shift + x;
    return shifted;
}

/*
remainder of an integer division (device function).
The built-in % operator is used, so the result takes the sign of x.
>> x - the input value
>> mod - the modulus
<< return - x % mod
*/
__device__
int cudamod(int x, int mod)
{
    const int remainder = x % mod;
    return remainder;
}


/*
define the GPU implementation of an element-wise "tensor (op) integer" function.
For a given funcName the macro generates
  1) Kernel##funcName - a CUDA kernel computing b[i] = origFunc(a[i], num) for
     every i in [0, size). It uses a one-dimensional launch and one thread per
     item.
  2) _Cuda##funcName - a host function that validates the tensors, launches the
     kernel on the device of tensor a, and reports a failed launch.
Only dense tensors of type X_INT are handled; other cases raise "TODO!".
Only block comments may be used inside the macro body because of the line
continuations.
>> funcName - suffix of the generated function names
>> origFunc - the device function applied to (a[i], num)
*/
#define SIMPLE_BINARY_FUNCTION_GPU(funcName, origFunc)                      \
__global__                                                                  \
void Kernel##funcName(int * a, int * b, int size, int num)                  \
{                                                                           \
    int i = blockDim.x * blockIdx.x + threadIdx.x;                          \
                                                                            \
    /* threads past the end of the data do nothing */                       \
    if (i < size)                                                           \
        b[i] = (int)origFunc(a[i], num);                                    \
}                                                                           \
                                                                            \
void _Cuda##funcName(const XTensor * a, XTensor * b, int num)               \
{                                                                           \
    CheckNTErrors((XTensor::IsSameShaped(a, b)),                            \
                  "Input tensors should have the same type!");              \
    CheckNTErrors((a->isSparse == false), "TODO!");                         \
                                                                            \
    /* check the data type before the device is switched, so the */         \
    /* error path cannot leave another device selected */                   \
    if (a->dataType != X_INT) {                                             \
        ShowNTErrors("TODO!");                                              \
        return;                                                             \
    }                                                                       \
                                                                            \
    /* an empty tensor has nothing to compute: skip the launch */           \
    if (a->unitNum <= 0)                                                    \
        return;                                                             \
                                                                            \
    int gridSize[3];                                                        \
    int blockSize[3];                                                       \
                                                                            \
    GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);         \
                                                                            \
    dim3 blocks(gridSize[0]);                                               \
    dim3 threads(blockSize[0]);                                             \
                                                                            \
    int devIDBackup;                                                        \
    ProtectCudaDev(a->devID, devIDBackup);                                  \
                                                                            \
    Kernel##funcName<<<blocks, threads>>>                                   \
        ((int*)a->data, (int*)b->data, a->unitNum, num);                    \
                                                                            \
    /* a launch returns no status itself, so query it right away */         \
    cudaError_t launchStatus = cudaGetLastError();                          \
                                                                            \
    BacktoCudaDev(a->devID, devIDBackup);                                   \
                                                                            \
    CheckNTErrors((launchStatus == cudaSuccess),                            \
                  "Failed to launch the CUDA kernel!");                     \
}

/*
instantiate the kernel (Kernel<Name>) and the host function (_Cuda<Name>)
for each integer operation, where num is the integer argument:
  Scale   : b[i] = a[i] * num
  Descale : b[i] = a[i] / num
  Shift   : b[i] = a[i] + num
  Mod     : b[i] = a[i] % num
*/
SIMPLE_BINARY_FUNCTION_GPU(Scale, cudascale)
SIMPLE_BINARY_FUNCTION_GPU(Descale, cudadescale)
SIMPLE_BINARY_FUNCTION_GPU(Shift, cudashift)
SIMPLE_BINARY_FUNCTION_GPU(Mod, cudamod)

#endif // USE_CUDA

} // namespace nts(NiuTrans.Tensor)