/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northestern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "Loss.h"
#include "Loss.cuh"
#include "../XDevice.h"

namespace nts{ // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA

/*
loss function to measure the "number" of errors
*/


/* 
compute the loss (Cuda version)
NOTE: this is a placeholder. No kernel is launched and none of the
arguments is read; the function always returns 0.
>> gold - gold standard
>> y - model prediction
>> LFName - name of loss function
>> isLogOutput - is the output in log scale?
>> leadDim - the leading dimension for the output
>> gBeg - where to start in the gold standard (along the leading dimension)
>> gLen - segment length from gBeg (along the leading dimension)
>> yBeg - where to start in the model output (along the leading dimension)
<< return - error in model prediction with respect to gold standard
            (always 0 for now)
*/
DTYPE CudaLossCompute(XTensor * gold, XTensor * y, LOSS_FUNCTION_NAME LFName,
                      bool isLogOutput, int leadDim, int gBeg, int gLen, int yBeg)
{
    // TODO: call cuda kernels for computing the errors

    /* no GPU implementation yet, so the reported loss is a constant zero */
    const DTYPE loss = 0;

    return loss;
}

/* 
the log version of loss computation (Cuda version)
NOTE: this is a placeholder. No kernel is launched and none of the
arguments is read; the function always returns 0.

>> gold - gold standard
>> y - model prediction
>> LFName - name of loss function
>> leadDim - the leading dimension for the output
>> gBeg - where to start in the gold standard (along the leading dimension)
>> gLen - segment length from gBeg (along the leading dimension)
>> yBeg - where to start in the model output (along the leading dimension)
<< return - error in model prediction with respect to gold standard
            (always 0 for now)
*/
DTYPE CudaLossComputeForLogScale(XTensor * gold, XTensor * y, 
                                 LOSS_FUNCTION_NAME LFName,
                                 int leadDim, int gBeg, int gLen, int yBeg)
{
    // TODO: call cuda kernels for computing the errors

    /* no GPU implementation yet, so the reported loss is a constant zero */
    const DTYPE loss = 0;

    return loss;
}

/* 
backward computation for a single element (Cuda version)
dE/dy
where E is the error(loss) function that measures the errors in y
with respect to gold standard, and y is the model output.
No kernel is launched here: the value is obtained by calling
LossBackwardPoint with the same arguments.
>> t - gold standard
>> y - model output
>> LFName - name of loss function
<< return dE/dy
*/
DTYPE CudaLossBackward(DTYPE t, DTYPE y, LOSS_FUNCTION_NAME LFName)
{
    // TODO: call cuda kernels for computing the errors

    /* delegate to the point-wise routine until a device version exists */
    const DTYPE dedy = LossBackwardPoint(t, y, LFName);

    return dedy;
}

/* 
backward computation for squared error (Cuda kernel)
dedy_i = y_i - t_i
Each thread handles one element; the element index is built from the 
x components of the block/thread indices only (1D launch).
>> dedy - dE/dy (for return)
>> t - gold standard (in vector)
>> y - model output (in vector)
>> size - size of the vector (dedy)
*/
extern "C" __global__ 
void KernelLossBackwardSquaredError(DTYPE * dedy, DTYPE * t, DTYPE * y, int size)
{
    /* flat index of the element owned by this thread */
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* threads that fall beyond the end of the vector have nothing to do */
    if (idx >= size)
        return;

    const DTYPE pred = y[idx];
    const DTYPE gold = t[idx];

    dedy[idx] = pred - gold;
}

/* 
backward computation of blocks for squared error (Cuda kernel)
dedy_i = y_i - t_i, written only for the elements whose position inside 
their block (i % blockSize) lies in [begInBlock, begInBlock + lenInBlock).
All other entries of dedy are left untouched.
Each thread handles one element; the element index is built from the 
x components of the block/thread indices only (1D launch).
>> dedy - dE/dy (for return)
>> t - gold standard (in vector)
>> y - model output (in vector)
>> blockSize - size of a block
>> begInBlock - the beginning position in a block for computation 
>> lenInBlock - number of items in a block for computation 
>> size - size of the vector (dedy)
*/
extern "C" __global__ 
void KernelLossBackwardSquaredErrorBlock(DTYPE * dedy, DTYPE * t, DTYPE * y, 
                                         int blockSize, int begInBlock, int lenInBlock, int size)
{
    /* flat index of the element owned by this thread */
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* threads that fall beyond the end of the vector have nothing to do */
    if (idx >= size)
        return;

    /* position of the element inside its block */
    const int posInBlock = idx % blockSize;

    const bool inSegment = posInBlock >= begInBlock && 
                           posInBlock < begInBlock + lenInBlock;

    if (inSegment)
        dedy[idx] = y[idx] - t[idx];
}

/* 
backward computation for cross entropy (Cuda kernel)
dedy_i = -t_i / y_i
Each thread handles one element; the element index is built from the 
x components of the block/thread indices only (1D launch).
The division is not guarded, i.e., the kernel does not treat y_i == 0 
specially.
>> dedy - dE/dy (for return)
>> t - gold standard (in vector)
>> y - model output (in vector)
>> size - size of the vector (dedy)
*/
extern "C" __global__ 
void KernelLossBackwardCrossEntropy(DTYPE * dedy, DTYPE * t, DTYPE * y, int size)
{
    /* flat index of the element owned by this thread */
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* threads that fall beyond the end of the vector have nothing to do */
    if (idx >= size)
        return;

    const DTYPE gold = t[idx];
    const DTYPE pred = y[idx];

    dedy[idx] = -gold / pred;
}

/* 
backward computation of blocks for cross entropy (Cuda kernel)
dedy_i = -t_i / y_i, written only for the elements whose position inside 
their block (i % blockSize) lies in [begInBlock, begInBlock + lenInBlock).
All other entries of dedy are left untouched.
Each thread handles one element; the element index is built from the 
x components of the block/thread indices only (1D launch).
>> dedy - dE/dy (for return)
>> t - gold standard (in vector)
>> y - model output (in vector)
>> blockSize - size of a block
>> begInBlock - the beginning position in a block for computation 
>> lenInBlock - number of items in a block for computation 
>> size - size of the vector (dedy)
*/
extern "C" __global__ 
void KernelLossBackwardCrossEntropyBlock(DTYPE * dedy, DTYPE * t, DTYPE * y, 
                                         int blockSize, int begInBlock, int lenInBlock, int size)
{
    /* flat index of the element owned by this thread */
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;

    /* threads that fall beyond the end of the vector have nothing to do */
    if (idx >= size)
        return;

    /* position of the element inside its block */
    const int posInBlock = idx % blockSize;

    const bool inSegment = posInBlock >= begInBlock && 
                           posInBlock < begInBlock + lenInBlock;

    if (inSegment)
        dedy[idx] =  -t[idx]/y[idx];
}

/* 
backward computation for (dense) vectors (Cuda version)
dE/dy
where E is the error(loss) function that measures the errors in y
with respect to gold standard, and y is the model output.
The result is written to dedy for the elements that fall in the segment
[tBeg, tBeg + tLen) along the leading dimension; other entries of dedy
are not written.
Only SQUAREDERROR and CROSSENTROPY launch a kernel; for any other 
LFName nothing is computed.
>> dedy - dE/dy (for return)
>> t - gold standard (in vector)
>> y - model output (in vector)
>> LFName - name of loss function
>> leadDim - the leading dimension for the output. A value outside 
             [0, y->order) selects the whole tensor: tBeg, tLen and yBeg 
             are then ignored
>> tBeg - where to start in the gold standard (along the leading dimension)
>> tLen - segment length from tBeg (along the leading dimension)
>> yBeg - where to start in the model output (along the leading dimension),
          must be equal to tBeg for now
*/
void CudaLossBackward(XTensor * dedy, XTensor * t, XTensor * y, 
                      LOSS_FUNCTION_NAME LFName, 
                      int leadDim, int tBeg, int tLen, int yBeg)
{
    CheckNTErrors((XTensor::IsIdentical(t, y)&& XTensor::IsIdentical(dedy, y)), 
                        "The input tensors must be of the same size!");
    /* NOTE(review): t and y are checked on dimSizeRDI[0] but dedy on dimSizeRDI[1];
       this looks like a typo but is kept as is - confirm the intended restriction */
    CheckNTErrors((t->dimSizeRDI[0] == 1 && y->dimSizeRDI[0] == 1 && dedy->dimSizeRDI[1] == 1), "TODO!");
    CheckNTErrors((t->dataType == DEFAULT_DTYPE && 
                         y->dataType == DEFAULT_DTYPE && 
                         dedy->dataType == DEFAULT_DTYPE),
                         "Input vectors are not in default type.");

    CheckNTErrors((dedy->devID >= 0 && t->devID >= 0 && y->devID >= 0),
                         "The backward compuation must be performed on GPUs.");

    CheckNTErrors((dedy->devID == t->devID && dedy->devID == y->devID),
                        "The vectors must be on the same GPU.");
    CheckNTErrors((tBeg == yBeg), "TODO!");

    /* position of the leading dimension in dimSizeRDI (reversed indexing).
       A negative leadDim is mapped to -1 so that it takes the same fallback
       as a leadDim that is too large, instead of producing an index >= order. */
    int leadDimRDI = leadDim >= 0 ? y->order - leadDim - 1 : -1;
    if(leadDimRDI < 0){
        /* fallback: take the last entry of dimSizeRDI over its full length.
           leadDimRDI is an index here, NOT a dimension size. */
        leadDimRDI = y->order - 1;
        tBeg = 0;
        yBeg = 0;
        tLen = y->dimSizeRDI[leadDimRDI];
    }

    /* stride = number of elements covered by one step along the leading dimension */
    int stride = 1;
    for(int i = 0; i < leadDimRDI; i++)
        stride *= y->dimSizeRDI[i];

    /* blockSize = number of elements covered by the whole leading dimension;
       the block kernels use (index % blockSize) to locate an element inside 
       its block, so this must not be left at 1 */
    int blockSize = stride * y->dimSizeRDI[leadDimRDI];

    /* number of elements of one block that fall in the requested segment */
    int size = tLen * stride;

    int cudaGridSize[3], cudaBlockSize[3];

    /* launch configuration for y->unitNum elements; only the first (x) 
       component is used, which matches the 1D indexing in the kernels */
    GDevs.GetCudaThread(dedy->devID, y->unitNum, cudaGridSize, cudaBlockSize);

    dim3 blocks(cudaGridSize[0]);
    dim3 threads(cudaBlockSize[0]);

    DTYPE * tp = (DTYPE*)t->data;
    DTYPE * yp = (DTYPE*)y->data;
    DTYPE * dedyp = (DTYPE*)dedy->data;

    /* presumably switches to the device of y and remembers the previous one
       in devIDBackup - see ProtectCudaDev/BacktoCudaDev */
    int devIDBackup;
    ProtectCudaDev(y->devID, devIDBackup);

    /* 
    squared error 
    loss = sum_{i} 0.5*(t_i - y_i)^2, where t_i is the gold standard and y_i is the model output
    dloss/dy_i = y_i - t_i
    */
    if(LFName == SQUAREDERROR){
        if(t->isSparse){
            ShowNTErrors("TODO!");
        }
        else if(size == y->unitNum){
            /* the segment covers every element: no per-block filtering needed */
            KernelLossBackwardSquaredError<<<blocks, threads>>>(dedyp, tp, yp, y->unitNum);
        }
        else{
            KernelLossBackwardSquaredErrorBlock<<<blocks, threads>>>(dedyp, tp, yp, blockSize, tBeg * stride, tLen * stride, y->unitNum);
        }
    }

    /* 
    cross entropy
    loss = sum_{i} (-t_i * log(y_i)), where t and y are distributions 
    dloss/dy_i = -t_i / y_i
    */
    if(LFName == CROSSENTROPY){
        if(t->isSparse){
            ShowNTErrors("TODO!");
        }
        else if(size == y->unitNum){
            /* the segment covers every element, so the kernel must be given the
               total element count (y->unitNum), not the segment length tLen */
            KernelLossBackwardCrossEntropy<<<blocks, threads>>>(dedyp, tp, yp, y->unitNum);
        }
        else{
            KernelLossBackwardCrossEntropyBlock<<<blocks, threads>>>(dedyp, tp, yp, blockSize, tBeg * stride, tLen * stride, y->unitNum);
        }
    }

    BacktoCudaDev(y->devID, devIDBackup);
}

#endif

} // namespace nts(NiuTrans.Tensor)