/* 
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-18
* I'm surprised that I did not write this file till today.
*/

#include <atomic>
#include <curand.h>
#include <time.h>
#include "SetData.cuh"
#include <curand_kernel.h>
#include "../../XDevice.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
fill an integer data array with the constant value p
(one thread per item; threads whose index is past the end exit early)
>> d - pointer to the data array
>> size - number of items in the array
>> p - the value written to every item
*/
__global__ 
void KernelSetDataFixedInt(int * d, int size, int p)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx >= size)
        return;

    d[idx] = p;
}

/*
fill every item of an X_INT tensor on the GPU with the value p
>> tensor - the tensor for initialization (must be X_INT)
>> p - the value written to every item
*/
void _CudaSetDataFixedInt(XTensor * tensor, int p)
{
    CheckNTErrors(tensor->dataType == X_INT, "the tensor must be in X_INT!");

    /* let the device pick a 1-D launch configuration covering unitNum items */
    int grid[3];
    int block[3];
    GDevs.GetCudaThread(tensor->devID, tensor->unitNum, grid, block);

    /* run on the tensor's device and switch back to the caller's device afterwards */
    int previousDev;
    ProtectCudaDev(tensor->devID, previousDev);

    int * items = (int*)tensor->data;
    KernelSetDataFixedInt<<<dim3(grid[0]), dim3(block[0])>>>(items, tensor->unitNum, p);

    BacktoCudaDev(tensor->devID, previousDev);
}

/*
fill a float data array with the constant value p (in float)
(one thread per item; threads whose index is past the end exit early)
>> d - pointer to the data array
>> size - number of items in the array
>> p - the value written to every item
*/
__global__ 
void KernelSetDataFixedFloat(float * d, int size, float p)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx >= size)
        return;

    d[idx] = p;
}

/*
fill every item of an X_FLOAT tensor on the GPU with the value p
>> tensor - the tensor for initialization (must be X_FLOAT)
>> p - the value written to every item
*/
void _CudaSetDataFixedFloat(XTensor * tensor, float p)
{
    CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");

    /* let the device pick a 1-D launch configuration covering unitNum items */
    int grid[3];
    int block[3];
    GDevs.GetCudaThread(tensor->devID, tensor->unitNum, grid, block);

    /* run on the tensor's device and switch back to the caller's device afterwards */
    int previousDev;
    ProtectCudaDev(tensor->devID, previousDev);

    float * items = (float*)tensor->data;
    KernelSetDataFixedFloat<<<dim3(grid[0]), dim3(block[0])>>>(items, tensor->unitNum, p);

    BacktoCudaDev(tensor->devID, previousDev);
}

/*
fill a double data array with the constant value p (in double)
(one thread per item; threads whose index is past the end exit early)
>> d - pointer to the data array
>> size - number of items in the array
>> p - the value written to every item
*/
__global__ 
void KernelSetDataFixedDouble(double * d, int size, double p)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx >= size)
        return;

    d[idx] = p;
}

/*
fill every item of an X_DOUBLE tensor on the GPU with the value p
>> tensor - the tensor for initialization (must be X_DOUBLE)
>> p - the value written to every item
*/
void _CudaSetDataFixedDouble(XTensor * tensor, double p)
{
    CheckNTErrors(tensor->dataType == X_DOUBLE, "the tensor must be in X_DOUBLE!");

    /* let the device pick a 1-D launch configuration covering unitNum items */
    int grid[3];
    int block[3];
    GDevs.GetCudaThread(tensor->devID, tensor->unitNum, grid, block);

    /* run on the tensor's device and switch back to the caller's device afterwards */
    int previousDev;
    ProtectCudaDev(tensor->devID, previousDev);

    double * items = (double*)tensor->data;
    KernelSetDataFixedDouble<<<dim3(grid[0]), dim3(block[0])>>>(items, tensor->unitNum, p);

    BacktoCudaDev(tensor->devID, previousDev);
}

/*
rescale a float array in place: d[i] <- d[i] * variance + lower
(_CudaSetDataRand calls this after cuRAND has filled d with uniform samples,
so the samples are mapped onto the requested range)
>> d - float datatype pointer to the data array
>> size - size of the array
>> lower - low value of the range
>> variance - width of the range (the caller passes upper - lower);
              despite the name, this is not a statistical variance
*/
__global__
void KernelSetDataRandFloat(float * d, int size, DTYPE lower, DTYPE variance)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < size) {
        float sample = d[idx];
        d[idx] = sample * variance + lower;
    }
}
/*
rescale a double array in place: d[i] <- d[i] * variance + lower
(_CudaSetDataRand calls this after cuRAND has filled d with uniform samples,
so the samples are mapped onto the requested range)
>> d - double datatype pointer to the data array
>> size - size of the array
>> lower - low value of the range
>> variance - width of the range (the caller passes upper - lower);
              despite the name, this is not a statistical variance
*/
__global__
void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < size) {
        double sample = d[idx];
        d[idx] = sample * variance + lower;
    }
}

/*
generate data items with a uniform distribution in [lower, upper]
(cuRAND samples from (0, 1], so strictly the values lie in (lower, upper])
>> tensor - the tensor whose data array would be initialized (X_FLOAT or X_DOUBLE)
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
{
    CheckNTErrors(upper > lower, "the high value must be greater than low value!");
    CheckNTErrors(tensor->dataType == X_FLOAT || tensor->dataType == X_DOUBLE,
                  "the tensor must be in X_FLOAT or X_DOUBLE!");

    int gridSize[3];
    int blockSize[3];

    GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);

    dim3 blocks(gridSize[0]);
    dim3 threads(blockSize[0]);

    int devIDBackup;
    ProtectCudaDev(tensor->devID, devIDBackup);

    /* a generator seeded with time(NULL) alone yields the same sequence for every
       call made within the same second (e.g. several equally-sized weight matrices
       initialized back to back). Putting a process-wide call counter in the high
       bits keeps the seeds distinct while time(NULL) fits in 32 bits. */
    static std::atomic<unsigned long long> callCount(0);
    unsigned long long seed = (callCount.fetch_add(1) << 32) ^
                              (unsigned long long)time(NULL);

    curandGenerator_t gen;
    curandStatus_t status = curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
    CheckNTErrors(status == CURAND_STATUS_SUCCESS, "cannot create the curand generator!");

    status = curandSetPseudoRandomGeneratorSeed(gen, seed);
    CheckNTErrors(status == CURAND_STATUS_SUCCESS, "cannot set the curand seed!");

    /* draw samples in the tensor's own precision: the float generator would fill
       only half of the bytes of a double array and leave garbage behind */
    if (tensor->dataType == X_FLOAT)
        status = curandGenerateUniform(gen, (float*)tensor->data, tensor->unitNum);
    else
        status = curandGenerateUniformDouble(gen, (double*)tensor->data, tensor->unitNum);
    CheckNTErrors(status == CURAND_STATUS_SUCCESS, "cannot generate uniform random numbers!");

    curandDestroyGenerator(gen);

    /* map the (0, 1] samples onto (lower, upper] */
    DTYPE variance = upper - lower;

    if (tensor->dataType == X_FLOAT)
        KernelSetDataRandFloat <<<blocks, threads >>>((float*)tensor->data, tensor->unitNum, lower, variance);
    else
        KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);

    BacktoCudaDev(tensor->devID, devIDBackup);
}

} // namespace nts(NiuTrans.Tensor)
