/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../XTensor.h"
#include "../XUtility.h"
#include "Merge.h"
#include "MakeMergeBlockIndex.h"
#include "CopyBlocksOnSite.h"

namespace nts { // namespace nts(NiuTrans.Tensor)


/*
transform a tensor by merging it along with a dimension, e.g., (3, N/3, M) -> (N, M).
The leading dimension is folded into the merging dimension: the i-th slice of the
leading dimension is placed as the i-th segment of the merged dimension.
>> s - the source tensor
>> t - the target tensor (for return). It has one dimension less than s and the
       same number of elements
>> whereToMerge - the merging operation is along with which dimension
>> leadingDim - the leading dimension of merging. It must come before whereToMerge.
                A negative value (or a value >= s->order) means dimension 0.
   take (3, N/3, M) -> (N, M) for example:
   whereToMerge = 1 (i.e., the dimension for "N/3")
   leadingDim = 0 (i.e., the dimension for "3")
*/
void Merge(XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
{
    /* validate the pointers before anything reads through them */
    CheckNTErrors((s != NULL && t != NULL), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
        "the data must be kept on the same device!");

    int whereToMergeRDI = s->order - whereToMerge - 1;
    int leadingDimRDI = s->order - leadingDim - 1;

    /* an out-of-range leading dimension (e.g., the default -1) falls back to
       dimension 0, i.e., the outermost dimension (RDI = order - 1) */
    if (leadingDim < 0 || leadingDim >= s->order)
        leadingDimRDI = s->order - 1;

    CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
    CheckNTErrors((s->order == t->order + 1), "Unmatched tensors!");
    CheckNTErrors((whereToMergeRDI >= 0 && whereToMergeRDI < s->order), "Invalid merging dimension!");
    CheckNTErrors((leadingDimRDI > whereToMergeRDI), "Invalid leading dimension!");

    /* in RDI order, the dimensions below the leading dimension keep their
       positions in t, while those above it move down by one */
    for (int i = 0; i < s->order; i++) {
        if (i == whereToMergeRDI) {
            CheckNTErrors((t->dimSizeRDI[i] == s->dimSizeRDI[i] * s->dimSizeRDI[leadingDimRDI]),
                "Unmatched tensor sizes!");
        }
        else if (i < leadingDimRDI) {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
                "Unmatched tensor sizes!");
        }
        else if (i > leadingDimRDI) {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i - 1]),
                "Unmatched tensor sizes!");
        }
    }

    /* blockSize: elements in the dimensions up to (and including) the merging one;
       blockNum: product of the dimensions above it up to the leading one */
    int blockSize = 1;
    int blockNum = 1;
    int mergedNum = s->dimSizeRDI[leadingDimRDI];

    for (int i = 0; i <= leadingDimRDI; i++) {
        if (i <= whereToMergeRDI)
            blockSize *= s->dimSizeRDI[i];
        else
            blockNum *= s->dimSizeRDI[i];
    }

    CheckNTErrors((s->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    int gridSize = blockNum;
    int gridNum = s->unitNum / (blockSize * blockNum);

    if (mergedNum * gridNum <= MIN_TENSOR_SPLIT_NUM) {
        /* few copies: one 2D copy per (grid, leading-dimension slice) */
        int sPitch = blockSize * s->unitSize;
        int tPitch = blockSize * mergedNum * t->unitSize;
        int mSize = blockSize * t->unitSize;
        int n = blockNum / mergedNum;
        int sStep = n * sPitch;
        int tStep = blockSize * t->unitSize;
        for (int g = 0; g < gridNum; g++) {
            char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
            char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
            for (int k = 0; k < mergedNum; k++) {
                XMemCopy2D(tData + k * tStep, tPitch, t->devID,
                    sData + k * sStep, sPitch, s->devID,
                    mSize, n);
            }
        }
    }
    else {
        XMem * mem = s->mem;
        int size = s->unitNum * s->unitSize;

        /* given the device check above this is always true; the staging path
           below is kept only as a safeguard */
        bool isOnSameDevice = (s->devID < 0 && t->devID < 0) || (s->devID == t->devID);

        /* blocks are written straight into t unless a staging buffer on s's device is needed */
        void * dataTMP = t->data;

        if (!isOnSameDevice)
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(s->devID, size);

        int blockNumInMerge = s->dimSizeRDI[leadingDimRDI];
        int splitSizeInGrid = gridSize / blockNumInMerge;
        int realBlockSize = blockSize * t->unitSize;

        /* one target index per block over all grids.
           NOTE(review): without a memory pool the index is allocated on s->devID;
           confirm MakeMergeBlockIndex writes it on that device when mem == NULL */
        int * blockIndex = (int*)(mem != NULL ?
            mem->AllocBuf(mem->devID, blockNum * gridNum * sizeof(int)) :
            XMemAlloc(s->devID, blockNum * gridNum * sizeof(int)));

        MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, mem);

        /* copy the blocks of every grid, not only those of the first one */
        CopyBlocksOnSite(s->data, realBlockSize, blockNum * gridNum, dataTMP, blockIndex, mem);

        if (mem != NULL)
            mem->ReleaseBuf(mem->devID, blockNum * gridNum * sizeof(int));
        else
            XMemFree(s->devID, blockIndex);

        /* copy from tmp to target only when a separate staging buffer was used
           (otherwise dataTMP is t->data and a copy would be a self-copy) */
        if (!isOnSameDevice) {
            XMemCopy(t->data, t->devID, dataTMP, s->devID, size);

            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(s->devID, dataTMP);
        }
    }
}

/*
merge small tensors into a big tensor
>> smalls - the list of the small tensors (all with the same number of elements)
>> big - the merged tensor (for return)
>> whereToMerge - the merging operation is along with which dimension
*/
void Merge(XList * smalls, XTensor * big, int whereToMerge)
{
    CheckNTErrors((smalls != NULL), "Invalid list!");
    CheckNTErrors((smalls->count > 0), "Empty list!");
    CheckNTErrors((big != NULL), "Invalid tensors!");

    int mergeNum = smalls->count;
    XTensor * s0 = (XTensor*)(smalls->GetItem(0));
    int itemSize = s0->unitNum * s0->unitSize;

    /* "uniform" means the small tensors lie back to back in memory on one device,
       so they can be read as a single big tensor without an extra copy */
    bool uniform = true;

    for (int i = 0; i < mergeNum; i++) {
        XTensor * item = (XTensor*)smalls->GetItem(i);
        CheckNTErrors((big->unitNum == item->unitNum * mergeNum), "Unmatched tensors!");

        if (i > 0) {
            XTensor * preItem = (XTensor*)smalls->GetItem(i - 1);
            if (item->devID != preItem->devID ||
                preItem->unitNum * preItem->unitSize != (char*)item->data - (char*)preItem->data)
                uniform = false;
        }
    }

    int whereToMergeRDI = s0->order - whereToMerge - 1;
    CheckNTErrors((whereToMergeRDI >= 0 && whereToMergeRDI < s0->order), "Invalid merging dimension!");

    /* blockSize: elements in the dimensions up to (and including) the merging one;
       blockNum: product of the remaining (outer) dimensions */
    int blockSize = 1;
    int blockNum = 1;
    for (int i = 0; i < s0->order; i++) {
        if (i <= whereToMergeRDI)
            blockSize *= s0->dimSizeRDI[i];
        else
            blockNum *= s0->dimSizeRDI[i];
    }

    CheckNTErrors((s0->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    int gridNum = s0->unitNum / (blockSize * blockNum);

    /* merging with fewer data copy operations */
    if (mergeNum * gridNum <= MIN_TENSOR_SPLIT_LIST_NUM) {
        int sPitch = blockSize * s0->unitSize;
        int tPitch = blockSize * mergeNum * big->unitSize;
        int mSize = blockSize * big->unitSize;
        int n = blockNum;
        int tStep = blockSize * big->unitSize;
        for (int g = 0; g < gridNum; g++) {
            char * tData = (char*)big->data + g * blockSize * blockNum * big->unitSize;
            for (int k = 0; k < mergeNum; k++) {
                XTensor * s = (XTensor*)smalls->GetItem(k);
                char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
                /* the k-th small tensor fills the k-th segment of every target row */
                XMemCopy2D(tData + k * tStep, tPitch, big->devID,
                    sData, sPitch, s->devID,
                    mSize, n);
            }
        }
    }
    /* merging with fewer kernel/api calls: view (or copy) the small tensors as one
       tensor with an extra outermost dimension of size mergeNum, then merge it */
    else {
        CheckNTErrors((s0->order < MAX_TENSOR_DIM_NUM), "Too many dimensions!");

        /* NOTE(review): the sizes are given negated and in RDI order, with mergeNum
           as the highest RDI entry; presumably the XTensor constructor reads
           all-negative sizes as reversed-order dimensions - confirm */
        int * dimSizeTMP = new int[MAX_TENSOR_DIM_NUM];
        for (int i = 0; i < MAX_TENSOR_DIM_NUM; i++)
            dimSizeTMP[i] = -s0->dimSizeRDI[i];
        dimSizeTMP[s0->order] = -mergeNum;

        XMem * mem = s0->mem;
        /* NOTE(review): tensorTMP->data is overwritten below; if the constructor
           allocates storage of its own, that storage is not released here - confirm */
        XTensor * tensorTMP = new XTensor(s0->order + 1, dimSizeTMP, s0->dataType, s0->denseRatio, mem);
        int devTMP = tensorTMP->devID;
        int size = mergeNum * itemSize;

        void * dataTMP = NULL;
        if (uniform)
            dataTMP = s0->data;
        else
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(devTMP, size);

        tensorTMP->data = dataTMP;

        /* copy from source to tmp */
        if (!uniform) {
            for (int i = 0; i < mergeNum; i++) {
                XTensor * item = (XTensor*)smalls->GetItem(i);
                XMemCopy((char*)dataTMP + (itemSize * i), devTMP, item->data, item->devID, itemSize);
            }
        }

        /* the extra dimension is dimension 0 of tensorTMP, so the merging dimension of
           the small tensors is dimension whereToMerge + 1 there, led by dimension 0 */
        Merge(tensorTMP, big, whereToMerge + 1, 0);

        delete[] dimSizeTMP;

        /* detach the buffer before deleting the temporary tensor: it either belongs
           to the small tensors (uniform) or is released explicitly below */
        tensorTMP->data = NULL;
        delete tensorTMP;

        /* only a buffer allocated here is released; in the uniform case the data
           belongs to the small tensors */
        if (!uniform) {
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(devTMP, dataTMP);
        }
    }
}
} // namespace nts(NiuTrans.Tensor)
