/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
 * $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-30 float16/int/int8 added
 */

#include "../../XTensor.h"
#include "../../XDevice.h"
#include "ConvertDataType.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

/*
convert an array of floats into half-precision values (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> s - source data array (float)
>> t - target data array (__half)
>> size - number of the items in s (and t)
*/
__global__
void KernelFloatToFloat16(float * s, __half * t, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    t[idx] = __float2half(s[idx]);
}

/*
convert an array of half-precision values into floats (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> s - source data array (__half)
>> t - target data array (float)
>> size - number of the items in s (and t)
*/
__global__
void KernelFloat16ToFloat(__half * s, float * t, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    t[idx] = __half2float(s[idx]);
}

/*
convert an array of floats into ints with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> inputData - source data array (float)
>> outputData - target data array (int)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelFloatToInt(float * inputData, int * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    /* the cast drops the fractional part of the value */
    float value = inputData[idx];
    outputData[idx] = (int)(value);
}

/*
convert an array of ints into floats with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> inputData - source data array (int)
>> outputData - target data array (float)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelIntToFloat(int * inputData, float * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    int value = inputData[idx];
    outputData[idx] = (float)(value);
}

/*
convert an array of floats into __int8 values with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
NOTE(review): __int8 is not declared in this file - presumably an 8-bit
integer type provided by the compiler or a project header; confirm
>> inputData - source data array (float)
>> outputData - target data array (__int8)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelFloatToInt8(float * inputData, __int8 * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    float value = inputData[idx];
    outputData[idx] = (__int8)(value);
}

/*
convert an array of __int8 values into floats with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> inputData - source data array (__int8)
>> outputData - target data array (float)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelInt8ToFloat(__int8 * inputData, float * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    __int8 value = inputData[idx];
    outputData[idx] = (float)(value);
}

/*
convert an array of ints into __int8 values with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
NOTE(review): this is a narrowing cast - values that do not fit in __int8
are presumably not preserved; confirm callers keep values in range
>> inputData - source data array (int)
>> outputData - target data array (__int8)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelIntToInt8(int * inputData, __int8 * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    int value = inputData[idx];
    outputData[idx] = (__int8)(value);
}

/*
convert an array of __int8 values into ints with a plain C cast (cuda kernel)
each thread converts at most one item, chosen by its global index along x
>> inputData - source data array (__int8)
>> outputData - target data array (int)
>> size - number of the items in inputData (and outputData)
*/
__global__
void KernelInt8ToInt(__int8 * inputData, int * outputData, int size)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    /* threads that fall beyond the end of the array do nothing */
    if (idx >= size)
        return;

    __int8 value = inputData[idx];
    outputData[idx] = (int)(value);
}


/* 
data conversion (cuda code) 
Only X_FLOAT -> X_FLOAT16 and X_FLOAT16 -> X_FLOAT are handled here.
Identical source and target types return immediately; any other pair
is reported through ShowNTErrors.
>> devID - device id (must be >= 0, i.e., a GPU device)
>> s - source data array
>> typeS - source data type
>> t - target data array
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void _CudaConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
{
    CheckNTErrors((devID >= 0), "This code must be run on GPUs!");

    /* nothing to do when both arrays already share the same type */
    if(typeS == typeT)
        return;

    int gridSize[3];
    int blockSize[3];

    /* ask the device manager for a launch configuration covering "size" items;
       only the first (x) dimension is used because the kernels index in 1D */
    GDevs.GetCudaThread(devID, size, gridSize, blockSize);

    dim3 blocks(gridSize[0]);
    dim3 threads(blockSize[0]);

    /* presumably records the currently active device in devIDBackup and
       switches to devID - NOTE(review): confirm against XDevice.h */
    int devIDBackup;
    ProtectCudaDev(devID, devIDBackup);

    if(typeS == X_FLOAT && typeT == X_FLOAT16)
        KernelFloatToFloat16<<<blocks, threads>>>((float*)s, (__half*)t, size);
    else if(typeS == X_FLOAT16 && typeT == X_FLOAT)
        KernelFloat16ToFloat<<<blocks, threads>>>((__half*)s, (float*)t, size);
    else{
        ShowNTErrors("Unsupported data types for conversion!");
    }

    /* switch back to the device that was active on entry. This must be the
       restoring counterpart of ProtectCudaDev: invoking ProtectCudaDev a
       second time would re-save the current device (now devID) into
       devIDBackup and leave the caller on the wrong device.
       NOTE(review): BacktoCudaDev is assumed to be provided by XDevice.h */
    BacktoCudaDev(devID, devIDBackup);
}
/*
convert data type (cuda code) 
Supported pairs (source -> target): X_FLOAT <-> X_INT, X_FLOAT <-> X_FLOAT16,
X_FLOAT <-> X_INT8 and X_INT <-> X_INT8. Identical types return immediately;
any other pair is reported through ShowNTErrors.
The device id and the number of items are both taken from the input tensor.
NOTE(review): the output tensor is assumed to live on the same device and to
hold at least input->unitNum items - neither is checked here; confirm callers
>> input - input tensor
>> output - output tensor
*/
void _CudaConvertDataType(const XTensor * input, XTensor * output)
{
    /* nothing to do when both tensors already share the same type */
    if (input->dataType == output->dataType)
        return;

    int gridSize[3];
    int blockSize[3];

    /* ask the device manager for a launch configuration covering all items of
       the input; only the first (x) dimension is used because the kernels
       index in 1D */
    GDevs.GetCudaThread(input->devID, input->unitNum, gridSize, blockSize);

    dim3 blocks(gridSize[0]);
    dim3 threads(blockSize[0]);

    /* presumably records the currently active device in devIDBackup and
       switches to input->devID - NOTE(review): confirm against XDevice.h */
    int devIDBackup;
    ProtectCudaDev(input->devID, devIDBackup);

    if(input->dataType == X_FLOAT && output->dataType == X_INT)
        KernelFloatToInt<<<blocks, threads>>>((float*)input->data, (int*)output->data, input->unitNum);
    else if(input->dataType == X_INT && output->dataType == X_FLOAT)
        KernelIntToFloat<<<blocks, threads>>>((int*)input->data, (float*)output->data, input->unitNum);
    else if(input->dataType == X_FLOAT && output->dataType == X_FLOAT16)
        KernelFloatToFloat16<<<blocks, threads>>>((float*)input->data, (__half*)output->data, input->unitNum);
    else if(input->dataType == X_FLOAT16 && output->dataType == X_FLOAT)
        KernelFloat16ToFloat<<<blocks, threads>>>((__half*)input->data, (float*)output->data, input->unitNum);
    else if(input->dataType == X_FLOAT && output->dataType == X_INT8)
        KernelFloatToInt8<<<blocks, threads>>>((float*)input->data, (__int8*)output->data, input->unitNum);
    else if(input->dataType == X_INT8 && output->dataType == X_FLOAT)
        KernelInt8ToFloat<<<blocks, threads>>>((__int8*)input->data, (float*)output->data, input->unitNum);
    else if(input->dataType == X_INT && output->dataType == X_INT8)
        KernelIntToInt8<<<blocks, threads>>>((int*)input->data, (__int8*)output->data, input->unitNum);
    else if(input->dataType == X_INT8 && output->dataType == X_INT)
        KernelInt8ToInt<<<blocks, threads>>>((__int8*)input->data, (int*)output->data, input->unitNum);
    else{
        ShowNTErrors("Unsupported data types for conversion!");
    }

    /* switch back to the device that was active on entry. This must be the
       restoring counterpart of ProtectCudaDev: invoking ProtectCudaDev a
       second time would re-save the current device (now input->devID) into
       devIDBackup and leave the caller on the wrong device.
       NOTE(review): BacktoCudaDev is assumed to be provided by XDevice.h */
    BacktoCudaDev(input->devID, devIDBackup);
}

#endif // USE_CUDA

} // namespace nts(NiuTrans.Tensor)