/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
 */

#include <math.h>
#include "CrossEntropy.h"
#include "CrossEntropy.cuh"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/math/Unary.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/arithmetic/Negate.h"
#include "../core/reduce/ReduceSum.h"
#include "../core/reduce/ReduceSumAll.h"

namespace nts{ // namespace nts(NiuTrans.Tensor)

/*
compute the cross entropy loss

loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions 
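For a one-hot gold distribution (gold_c = 1 at the correct class c),
this reduces to loss = -log(output_c).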
        
>> output - model prediction
>> gold - gold standard
>> loss - the computed loss (for return)
>> weight - a rescaling weight given to each class
>> padding - a mask tensor; positions where it is 0 are ignored and do not contribute to the loss
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropy(const XTensor * output, const XTensor * gold,
                   XTensor * loss, const XTensor * weight, 
                   const XTensor * padding, int leadingDim)
{
    int n = leadingDim < 0 ? output->order - 1 : leadingDim;
    int unitNum = output->dimSize[n];

    CheckNTErrors(n >= 0 && n < output->order, "Wrong leadingDim!");
    CheckNTErrors(XTensor::IsSameShaped(output, gold), 
                 "The output tensor and gold tensor must be of the same size!");
    CheckNTErrors(weight == NULL || weight->unitNum == unitNum, "Wrong weight tensor!");
    CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss), 
                 "The loss tensor and padding tensor must be same shape!");
    CheckNTErrors(loss->order == output->order - 1, "Wrong loss dimension!");
    CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");

    XTensor * interBuf1 = NewTensorBuf(output, output->devID, output->mem);
    XTensor * interBuf2 = NewTensorBuf(output, output->devID, output->mem);
    
    _Log(output, interBuf1);
    _Multiply(gold, interBuf1, interBuf2);

    if(weight != NULL)
        _MultiplyDimMe(interBuf2, weight, n);
    _NegateMe(interBuf2);
    _ReduceSum(interBuf2, loss, n);
    
    if(padding != NULL)
        _MultiplyMe(loss, padding);

    DelTensorBuf(interBuf2);
    DelTensorBuf(interBuf1);
}
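
/*
A minimal usage sketch (illustrative only): NewTensor1D, NewTensor2D, and
DelTensor are assumed from the library's tensor-creation helpers, and the
output/gold tensors are assumed to hold probability distributions along
the last dimension.

    XTensor * output = NewTensor2D(batchSize, classNum);   // model predictions
    XTensor * gold   = NewTensor2D(batchSize, classNum);   // target distributions
    XTensor * loss   = NewTensor1D(batchSize);             // per-sample loss
    // ... fill output and gold with data ...
    _CrossEntropy(output, gold, loss, NULL, NULL, -1);     // no weight, no padding
    DelTensor(loss);
    DelTensor(gold);
    DelTensor(output);
*/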

/*
compute the cross entropy loss (faster implementation with optimized code)

loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions 
        
>> output - model prediction
>> gold - gold standard
>> loss - the computed loss (for return)
>> weight - a rescaling weight given to each class
>> padding - a mask tensor; positions where it is 0 are ignored and do not contribute to the loss
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropyFast(const XTensor * output, const XTensor * gold,
                       XTensor * loss, const XTensor * weight,
                       const XTensor * padding, int leadingDim)
{
    int order = output->order;
    int n = leadingDim < 0 ? output->order - 1 : leadingDim;
    int leadingDimSize = output->GetDim(n);

    CheckNTErrors(n >= 0 && n < output->order, 
                 "Wrong leading dimension!");
    CheckNTErrors(XTensor::IsSameShaped(output, gold), 
                 "The output tensor and gold tensor must be of the same size!");
    CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize, 
                 "Wrong weight tensor!");
    CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss), 
                 "The loss tensor and padding tensor must be same shape!");
    CheckNTErrors(loss->order == output->order - 1, 
                 "Wrong loss dimension!");
    CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, 
                 "TODO!");
    
    for(int i = 0; i < order; i++){
        if(i < n){
            CheckNTErrors((output->GetDim(i) == loss->GetDim(i)), "Unmatched tensors!");
        }
        else if(i > n){
            CheckNTErrors((output->GetDim(i) == loss->GetDim(i - 1)), "Unmatched tensors!");
        }
    }

#ifdef USE_CUDA
    if(output->devID >= 0) {
        _CudaCrossEntropyFast(output, gold, loss, weight, padding, leadingDim);
        return;
    }
#endif

    int blockNum = 1;
    int blockSize = 1;
    int stride = 1;

    for(int i = n + 1; i < order; i++)
        stride *= output->GetDim(i);
    
    blockSize = stride * leadingDimSize;
    blockNum = output->unitNum / blockSize;
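
    /* Layout example (illustrative numbers): for an output of shape (2, 3, 4)
       with leading dimension n = 1, we get stride = 4 (the product of the
       dimensions after n), blockSize = 3 * 4 = 12, and blockNum = 24 / 12 = 2.
       The k-th class value of the (i, j)-th distribution then sits at offset
       i * blockSize + k * stride + j, which is the goldPos used below. */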

    DTYPE * outputData = (DTYPE*)output->data;
    DTYPE * goldData = (DTYPE*)gold->data;
    DTYPE * lossData = (DTYPE*)loss->data;

    DTYPE tmpLoss;
    int lossPos;
    int goldPos;

    if(weight == NULL) {
        if(padding == NULL) {
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    tmpLoss = 0;
                    lossPos = i * stride + j;
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        tmpLoss += -(*(goldData + goldPos)) * 
                                    (DTYPE)log(*(outputData + goldPos));
                    }
                    *(lossData + lossPos) = tmpLoss;
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    lossPos = i * stride + j;
                    if(*(paddingData + lossPos) == 0)
                        *(lossData + lossPos) = 0;
                    else {
                        tmpLoss = 0;
                        for(int k = 0; k < leadingDimSize; k++) {
                            goldPos = i * blockSize + j + k * stride;
                            tmpLoss += -(*(goldData + goldPos)) * 
                                        (DTYPE)log(*(outputData + goldPos));
                        }
                        *(lossData + lossPos) = tmpLoss;
                    }
                }
            }            
        }
    }
    else {
        DTYPE * weightData = (DTYPE*)weight->data;
        if(padding == NULL) {
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    tmpLoss = 0;
                    lossPos = i * stride + j;
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        tmpLoss += -(*(goldData + goldPos)) * 
                                    (DTYPE)log(*(outputData + goldPos)) *
                                    (*(weightData + k));
                    }
                    *(lossData + lossPos) = tmpLoss;                    
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    lossPos = i * stride + j;
                    if(*(paddingData + lossPos) == 0)
                        *(lossData + lossPos) = 0;
                    else {
                        tmpLoss = 0;
                        for(int k = 0; k < leadingDimSize; k++) {
                            goldPos = i * blockSize + j + k * stride;
                            tmpLoss += -(*(goldData + goldPos)) * 
                                        (DTYPE)log(*(outputData + goldPos)) *
                                        (*(weightData + k));
                        }
                        *(lossData + lossPos) = tmpLoss;
                    }
                }
            }              
        }
    }
}

/*
compute the cross entropy loss

loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions 
        
>> output - model prediction
>> gold - gold standard
>> reduceWay - the reduction method, sum or mean
>> weight - a rescaling weight given to each class
>> padding - a mask tensor; positions where it is 0 are ignored and do not contribute to the loss
>> leadingDim - the leading dimension for the output
<< return - the cross entropy loss (a scalar)
*/
DTYPE _CrossEntropy(const XTensor * output, const XTensor * gold,
                    LOSS_COMPUTE_WAY reduceWay, const XTensor * weight, 
                    const XTensor * padding, int leadingDim)
{
    DTYPE loss = 0;
    
    int order = output->order;
    int n = leadingDim < 0 ? output->order - 1 : leadingDim;
    int unitNum = output->dimSize[n];
    
    CheckNTErrors(n >= 0 && n < output->order, "Wrong leadingDim!");
    CheckNTErrors(XTensor::IsSameShaped(output, gold), 
                 "The output tensor and gold tensor must be of the same size!");
    CheckNTErrors(weight == NULL || weight->unitNum == unitNum, "Wrong weight tensor!");
    CheckNTErrors(padding == NULL || padding->order == output->order - 1, 
                 "Wrong padding tensor!");
    CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");

    int * dimSize = new int[order - 1];
    for (int i = 0; i < order; i++) {
        if(i < n)
            dimSize[i] = output->dimSize[i];
        else if(i > n)
            dimSize[i - 1] = output->dimSize[i];
    }

    XTensor * lossBuf = NewTensorBuf(output->order - 1, dimSize, output->dataType, output->denseRatio, 
                                     output->devID, output->mem);

    _CrossEntropy(output, gold, lossBuf, weight, padding, leadingDim);

    loss = _ReduceSumAll(lossBuf);

    if(reduceWay == REDUCE_MEAN) {
        int nonZeroNum;
        if(padding == NULL) {
            nonZeroNum = lossBuf->unitNum;
        }
        else {
            XTensor * tmp = NewTensorBuf(padding, padding->devID, padding->mem);
            _IsNonZero(padding, tmp);
            nonZeroNum = (int)_ReduceSumAll(tmp);
            DelTensorBuf(tmp);
        }

        loss = loss / (DTYPE)nonZeroNum;
    }
    else if(reduceWay == REDUCE_SUM) {
        /* don't need to do anything */
    }
    else {
        ShowNTErrors("TODO");
    }

    delete[] dimSize;
    DelTensorBuf(lossBuf);

    return loss;
}
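
/*
A minimal usage sketch (illustrative only; output and gold are assumed to be
prepared as in the earlier example):

    DTYPE sumLoss  = _CrossEntropy(output, gold, REDUCE_SUM,  NULL, NULL, -1);
    DTYPE meanLoss = _CrossEntropy(output, gold, REDUCE_MEAN, NULL, NULL, -1);
*/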

/*
compute the cross entropy loss (faster implementation with optimized code)

loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions 
        
>> output - model prediction
>> gold - gold standard
>> reduceWay - the reduction method, sum or mean
>> weight - a rescaling weight given to each class
>> padding - a mask tensor; positions where it is 0 are ignored and do not contribute to the loss
>> leadingDim - the leading dimension for the output
<< return - the cross entropy loss (a scalar)
*/
DTYPE _CrossEntropyFast(const XTensor * output, const XTensor * gold,
                        LOSS_COMPUTE_WAY reduceWay, const XTensor * weight,
                        const XTensor * padding, int leadingDim)
{
    DTYPE loss = 0;

    int order = output->order;
    int n = leadingDim < 0 ? output->order - 1 : leadingDim;
    int leadingDimSize = output->GetDim(n);

    CheckNTErrors(n >= 0 && n < output->order, 
                 "Wrong leadingDim!");
    CheckNTErrors(XTensor::IsSameShaped(output, gold), 
                 "The output tensor and gold tensor must be of the same size!");
    CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize, 
                 "Wrong weight tensor!");
    CheckNTErrors(padding == NULL || padding->order == output->order - 1, 
                 "Wrong padding tensor!");
    CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, 
                 "TODO!");
    
    if(padding != NULL) {
        for(int i = 0; i < order; i++){
            if(i < n){
                CheckNTErrors((output->GetDim(i) == padding->GetDim(i)), "Unmatched tensors!");
            }
            else if(i > n){
                CheckNTErrors((output->GetDim(i) == padding->dimSize[i - 1]), "Unmatched tensors!");
            }
        }
    }

#ifdef USE_CUDA
    if(output->devID >= 0) {
        return _CudaCrossEntropyFast(output, gold, reduceWay, weight, padding, leadingDim);
    }
#endif

    int blockNum = 1;
    int blockSize = 1;
    int stride = 1;

    for(int i = n + 1; i < order; i++)
        stride *= output->GetDim(i);
    
    blockSize = stride * leadingDimSize;
    blockNum = output->unitNum / blockSize;

    DTYPE * outputData = (DTYPE*)output->data;
    DTYPE * goldData = (DTYPE*)gold->data;

    int paddingPos;
    int goldPos;
    int nonZeroNum = 0;

    if(weight == NULL) {
        if(padding == NULL) {
            nonZeroNum = blockNum * stride;

            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        loss += -(*(goldData + goldPos)) * 
                                 (DTYPE)log(*(outputData + goldPos));
                    }
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    paddingPos = i * stride + j;
                    if(*(paddingData + paddingPos) == 0)
                        continue;
                    else {
                        nonZeroNum += 1;
                        for(int k = 0; k < leadingDimSize; k++) {
                            goldPos = i * blockSize + j + k * stride;
                            loss += -(*(goldData + goldPos)) * 
                                     (DTYPE)log(*(outputData + goldPos));
                        }    
                    }
                }
            }
        }
    }
    else {
        DTYPE * weightData = (DTYPE*)weight->data;
        if(padding == NULL) {
            nonZeroNum = blockNum * stride;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        loss += -(*(goldData + goldPos)) * 
                                 (DTYPE)log(*(outputData + goldPos)) *
                                 (*(weightData + k));
                    }
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    paddingPos = i * stride + j;
                    if(*(paddingData + paddingPos) == 0)
                        continue;
                    else {
                        nonZeroNum += 1;
                        for(int k = 0; k < leadingDimSize; k++) {
                            goldPos = i * blockSize + j + k * stride;
                            loss += -(*(goldData + goldPos)) * 
                                     (DTYPE)log(*(outputData + goldPos)) *
                                     (*(weightData + k));
                        }    
                    }
                }
            }
        }
    }

    if(reduceWay == REDUCE_MEAN) {
        loss = loss / (DTYPE)nonZeroNum;
    }
    else if(reduceWay == REDUCE_SUM) {
        /* don't need to do anything */
    }
    else {
        ShowNTErrors("TODO");
    }

    return loss;
}

/* 
backward computation for the cross entropy function

loss = sum_{i} (-t_i * log(y_i))
dE/dy_i = -t_i / y_i
where E is the error (loss) function that measures the errors in y
with respect to the gold standard, and y is the model output.
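For example, with a one-hot gold vector (t_c = 1 at the correct class c),
dE/dy_c = -1/y_c and dE/dy_i = 0 for all i != c.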

>> dedy - dE/dy (for return)
>> output - model prediction
>> gold - gold standard
>> weight - a rescaling weight given to each class
>> padding - a mask tensor; positions where it is 0 are ignored and do not contribute to the loss
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropyBackward(XTensor * dedy, const XTensor * output, 
                           const XTensor * gold, const XTensor * weight,
                           XTensor * padding, int leadingDim)
{
    int order = output->order;
    int n = leadingDim < 0 ? output->order - 1 : leadingDim;
    int leadingDimSize = output->GetDim(n);

    CheckNTErrors(n >= 0 && n < output->order, 
                 "Wrong leading dimension!");
    CheckNTErrors(XTensor::IsSameShaped(dedy, output, gold), 
                 "The dedy, output, and gold tensors must be of the same size!");
    CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize, 
                 "Wrong weight tensor!");
    CheckNTErrors(padding == NULL || padding->order == output->order - 1, 
                 "Wrong padding tensor!");
    CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, 
                 "TODO!");

    if(padding != NULL) {
        for(int i = 0; i < order; i++){
            if(i < n){
                CheckNTErrors((output->GetDim(i) == padding->GetDim(i)), "Unmatched tensors!");
            }
            else if(i > n){
                CheckNTErrors((output->GetDim(i) == padding->dimSize[i - 1]), "Unmatched tensors!");
            }
        }    
    }

#ifdef USE_CUDA
    if(output->devID >= 0) {
        _CudaCrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
        return;
    }
#endif

    int blockNum = 1;
    int blockSize = 1;
    int stride = 1;

    for(int i = n + 1; i < order; i++)
        stride *= output->GetDim(i);
    
    blockSize = stride * leadingDimSize;
    blockNum = output->unitNum / blockSize;

    DTYPE * dedyData = (DTYPE*)dedy->data;
    DTYPE * outputData = (DTYPE*)output->data;
    DTYPE * goldData = (DTYPE*)gold->data;

    int paddingPos;
    int goldPos;

    if(weight == NULL) {
        if(padding == NULL) {
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        *(dedyData + goldPos) = -(*(goldData + goldPos)) / 
                                                 (*(outputData + goldPos));
                    }
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    paddingPos = i * stride + j;
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        if(*(paddingData + paddingPos) == 0)
                            *(dedyData + goldPos) = 0;
                        else
                            *(dedyData + goldPos) = -(*(goldData + goldPos)) / 
                                                     (*(outputData + goldPos));
                    }
                }
            }
        }
    }
    else {
        DTYPE * weightData = (DTYPE*)weight->data;
        if(padding == NULL) {
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        *(dedyData + goldPos) = -(*(weightData + k)) * 
                                                 (*(goldData + goldPos)) / 
                                                 (*(outputData + goldPos));
                    }
                }
            }
        }
        else {
            DTYPE * paddingData = (DTYPE*)padding->data;
            for(int i = 0; i < blockNum; i++) {
                for(int j = 0; j < stride; j++) {
                    paddingPos = i * stride + j;
                    for(int k = 0; k < leadingDimSize; k++) {
                        goldPos = i * blockSize + j + k * stride;
                        if(*(paddingData + paddingPos) == 0)
                            *(dedyData + goldPos) = 0;
                        else
                            *(dedyData + goldPos) = -(*(weightData + k)) * 
                                                     (*(goldData + goldPos)) / 
                                                     (*(outputData + goldPos));
                    }
                }
            }
        }
    }

    //if(padding != NULL) {
    //    XTensor * tmp = NewTensor(padding);
    //    _IsNonZero(padding, tmp);
    //    int nonZeroNum = (int)_ReduceSumAll(tmp);
    //    _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
    //    delete tmp;
    //}
    //else {
    //    _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
    //}
}
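
/*
A minimal usage sketch (illustrative only; NewTensor(const XTensor *) and
DelTensor are assumed from the library's constructors, as in the
commented-out normalization code above):

    XTensor * dedy = NewTensor(output);   // gradient buffer, same shape as output
    _CrossEntropyBackward(dedy, output, gold, NULL, NULL, -1);
    DelTensor(dedy);
*/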

} // namespace nts(NiuTrans.Tensor)