/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XTensor.h"
#include "../../XUtility.h"
#include "Split.h"
#include "MakeSplitBlockIndex.h"
#include "../movement/CopyBlocksOnSite.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
transform a tensor by splitting it, e.g., (N, M) -> (N/3, M, 3)
>> s - the source tensor
>> t - the target tensor (for return); it has one more dimension than s
>> whereToSplit - which dimension of the tensor is to split
>> splitNum - how many splits
*/
void Split(XTensor * s, XTensor * t, int whereToSplit, int splitNum)
{
    CheckNTErrors((s && t), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
        "the data must be kept on the same device!");

    CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
    CheckNTErrors((s->order == t->order - 1), "Unmatched tensors!");
    CheckNTErrors((t->dimSizeRDI[t->order - 1] == splitNum), "Incorrect tensor sizes!");

    /* dimension index in reversed (right-to-left) order */
    int whereToSplitRDI = s->order - whereToSplit - 1;
    for (int i = 0; i < s->order; i++) {
        if (i == whereToSplitRDI) {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i] * splitNum),
                "Unmatched tensor sizes!");
        }
        else {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
                "Unmatched tensor sizes!");
        }
    }

    /* for the case that we split the last dimension. Actually
    (N, M) and (N, M/3, 3) have the same memory layout */
    if (s->order - 1 == whereToSplitRDI) {
        XMemCopy(t->data, t->devID, s->data, s->devID, s->unitNum * s->unitSize);
        return;
    }

    /* blockSize = number of units below (and including a slice of) the split
       dimension; blockNum = number of such blocks in the whole tensor */
    int blockSize = 1;
    int blockNum = 1;
    for (int i = 0; i < s->order; i++) {
        if (i == whereToSplitRDI) {
            blockSize *= s->dimSizeRDI[i] / splitNum;
            blockNum *= splitNum;
        }
        else if (i < whereToSplitRDI)
            blockSize *= s->dimSizeRDI[i];
        else
            blockNum *= s->dimSizeRDI[i];
    }

    CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");

    if (splitNum <= MIN_TENSOR_SPLIT_NUM) {
        /* few splits: one strided 2D copy per split is cheaper than
           building a block-index table */
        int sPitch = blockSize * splitNum * s->unitSize;
        int tPitch = blockSize * t->unitSize;
        int mSize = blockSize * t->unitSize;
        int n = blockNum / splitNum;
        int sStep = blockSize * s->unitSize;
        int tStep = n * tPitch;
        for (int k = 0; k < splitNum; k++) {
            XMemCopy2D((char*)t->data + k * tStep, tPitch, t->devID,
                (char*)s->data + k * sStep, sPitch, s->devID,
                mSize, n);
        }
    }
    else {
        /* many splits: permute all blocks at once via an index table */
        XMem * mem = s->mem;
        int size = s->unitNum * s->unitSize;
        bool isOnSameDevice = (s->devID < 0 && t->devID < 0) || (s->devID == t->devID);

        void * dataTMP = t->data;

        /* NOTE: when there is no memory pool (mem == NULL), allocate on the
           source device directly; the original code dereferenced the NULL
           "mem" pointer here (mem->devID) */
        if (!isOnSameDevice)
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(s->devID, size);

        int realBlockSize = blockSize * t->unitSize;
        int blockSplitSize = blockNum / splitNum;

        int * blockIndex = (int*)(mem != NULL ?
            mem->AllocBuf(mem->devID, blockNum * sizeof(int)) :
            XMemAlloc(s->devID, blockNum * sizeof(int)));

        MakeSplitBlockIndex(blockIndex, splitNum, blockSplitSize, blockNum, mem);

        CopyBlocksOnSite(s->data, realBlockSize, blockNum, dataTMP, blockIndex, mem);

        if (mem != NULL)
            mem->ReleaseBuf(mem->devID, blockNum * sizeof(int));
        else
            XMemFree(s->devID, blockIndex);

        /* copy from tmp to target */
        if (!isOnSameDevice) {
            XMemCopy(t->data, t->devID, dataTMP, s->devID, size);

            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(s->devID, dataTMP);
        }
    }
}

/*
split a big tensor into small tensors
>> big - the source tensor
>> smalls - the list that keeps the resulting tensors (for return)
NOTE that all the "small" tensors have already been
placed in the list in advance.
>> whereToSplit - which dimension of the tensor is to split
>> splitNum - how many splits
*/
void Split(XTensor * big, XList * smalls, int whereToSplit, int splitNum)
{
    CheckNTErrors((smalls != NULL), "Invalid list!");
    CheckNTErrors((smalls->count == splitNum), "Unmatched tensors!");
    CheckNTErrors((smalls->count > 0), "Wrong input!");

    int whereToSplitRDI = big->order - whereToSplit - 1;

    /* "uniform" means the small tensors occupy one contiguous chunk of
       memory in list order, so we can write the split result into them
       directly with no extra copy */
    bool uniform = true;

    for (int i = 0; i < smalls->count; i++) {
        XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
        CheckNTErrors((big->unitNum == smallsItem->unitNum * splitNum), "Unmatched tensors!");
        if (i > 0) {
            XTensor * preItem = (XTensor*)smalls->GetItem(i - 1);
            if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data)
                uniform = false;
        }
    }

    int blockSize = 1;
    int blockNum = 1;
    for (int i = 0; i < big->order; i++) {
        if (i == whereToSplitRDI) {
            blockSize *= big->dimSizeRDI[i] / splitNum;
            blockNum *= splitNum;
        }
        else if (i < whereToSplitRDI)
            blockSize *= big->dimSizeRDI[i];
        else
            blockNum *= big->dimSizeRDI[i];
    }

    CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");

    /* splitting with fewer data copy operations */
    if (splitNum <= MIN_TENSOR_SPLIT_LIST_NUM) {
        XTensor * t0 = (XTensor*)smalls->GetItem(0);
        int sPitch = blockSize * splitNum * big->unitSize;
        int tPitch = blockSize * t0->unitSize;
        int mSize = blockSize * t0->unitSize;
        int n = blockNum / splitNum;
        int sStep = blockSize * big->unitSize;
        int tStep = 0;
        for (int k = 0; k < splitNum; k++) {
            XTensor * t = (XTensor*)smalls->GetItem(k);
            XMemCopy2D((char*)t->data + k * tStep, tPitch, t->devID,
                (char*)big->data + k * sStep, sPitch, big->devID,
                mSize, n);
        }
    }
    /* splitting with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
    else {
        /* negative dimension sizes tell the XTensor constructor not to
           allocate data memory; we attach a buffer ourselves below */
        int* dimSizeTMP = new int[MAX_TENSOR_DIM_NUM];
        for (int i = 0; i < MAX_TENSOR_DIM_NUM; i++)
            dimSizeTMP[i] = -big->dimSize[i];
        dimSizeTMP[whereToSplit] /= splitNum;
        dimSizeTMP[big->order] = -splitNum;

        XMem * mem = big->mem;
        XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP, big->dataType, big->denseRatio, mem);
        int size = big->unitNum * big->unitSize;
        void * dataTMP = NULL;

        if (uniform) {
            /* write straight into the contiguous memory of the small tensors */
            XTensor* first = (XTensor*)smalls->GetItem(0);
            dataTMP = first->data;
        }
        else {
            /* NOTE: allocate on the tensor's device when there is no memory
               pool; the original code dereferenced the NULL "mem" here */
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(big->devID, size);
        }

        tensorTMP->data = dataTMP;

        Split(big, tensorTMP, whereToSplit, splitNum);

        /* copy from tmp to target */
        if (!uniform) {
            int splitSize = big->unitNum * big->unitSize / splitNum;
            for (int i = 0; i < splitNum; i++) {
                XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
                XMemCopy(smallsItem->data, smallsItem->devID, (char*)(tensorTMP->data) + (splitSize * i), tensorTMP->devID, splitSize);
            }
        }

        delete[] dimSizeTMP;

        /* detach the borrowed buffer so ~XTensor does not free memory it
           does not own */
        tensorTMP->data = NULL;
        delete tensorTMP;

        /* release the temporary buffer; it exists only in the non-uniform
           case. (The original code nulled dataTMP before freeing it -- a
           leak -- and dereferenced "mem" even when it was NULL.) */
        if (!uniform) {
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(big->devID, dataTMP);
        }
    }
}
} // namespace nts(NiuTrans.Tensor)