/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "CopyIndexed.h"
#include "CopyBlocks.h"
#include "../../XName.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
copy indexed sub-tensors

>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
         e.g., for a tensor of size (3, 2, 4) and dim = 2, 
         we have 4 sub-tensors of size (3,2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors
>> copyNum - number of the sub-tensors we copy for each source index, 
             e.g., for srcIndex = [1,4] and copyNum = 2,
             we actually copy the source sub-tensors 1, 2, 4, 5;
             they are written to the copyNum consecutive target
             sub-tensors starting at tgtIndex[j]

All indices are validated before any memory is touched:
- srcIndex[j] + copyNum must not exceed the size of dimension "dim" of s;
- tgtIndex[j] + copyNum must not exceed the size of dimension "dim" of t.
The copy is repeated for every combination of the dimensions in front of
"dim", so s and t must have the same number of such combinations. The
inner block size (product of the dimensions after "dim") must also match.
*/
void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
{
    CheckNTErrors((s && t), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
                  "the data must be kept on the same device!");
    CheckNTErrors((dim >= 0 && dim < s->order && dim < t->order), "A too larget dimension specified!");
    CheckNTErrors((s->unitSize == t->unitSize), "Unmatched tensors!");
    CheckNTErrors((indexSize >= 0 && copyNum >= 0), "Invalid index size or copy number!");
    CheckNTErrors((indexSize == 0 || (srcIndex && tgtIndex)), "Invalid index arrays!");

    /* dimSizeRDI lists the dimensions in reverse order (dimSizeRDI[0] is the
       last dimension), so "dim" sits at position dimRDI */
    int dimRDI = s->order - dim - 1;
    int blockSizeSrc = 1;
    int blockSizeTgt = 1;
    int blockNumSrc = 1;
    int blockNumTgt = 1;
    int leadDimSizeSrc = s->dimSizeRDI[dimRDI];
    int leadDimSizeTgt = t->dimSizeRDI[dimRDI];

    /* a block is the contiguous run of elements behind "dim" */
    for (int i = 0; i < dimRDI; i++) {
        blockSizeSrc *= s->dimSizeRDI[i];
        blockSizeTgt *= t->dimSizeRDI[i];
    }

    /* number of blocks = product of "dim" and all dimensions in front of it */
    for (int i = dimRDI; i < s->order; i++)
        blockNumSrc *= s->dimSizeRDI[i];
    for (int i = dimRDI; i < t->order; i++)
        blockNumTgt *= t->dimSizeRDI[i];

    CheckNTErrors((blockSizeSrc == blockSizeTgt), "Unmatched tensors!");

    /* how many index combinations precede "dim"; the target offsets below are
       generated from the source count, so the target must agree or we would
       write past (or short of) the end of t */
    int indexOffsetNum = blockNumSrc / leadDimSizeSrc;
    CheckNTErrors((indexOffsetNum * leadDimSizeTgt == blockNumTgt), "Unmatched tensors!");

    /* every copied sub-tensor must stay inside dimension "dim"; otherwise the
       last outer block would read/write beyond the tensor data */
    for (int i = 0; i < indexSize; i++) {
        CheckNTErrors((srcIndex[i] >= 0 && srcIndex[i] + copyNum <= leadDimSizeSrc), "Index is out of scope!");
        CheckNTErrors((tgtIndex[i] >= 0 && tgtIndex[i] + copyNum <= leadDimSizeTgt), "Index is out of scope!");
    }

    /* expand (outer combination, index, copy) into absolute block indices */
    int realIndexSize = indexOffsetNum * indexSize * copyNum;
    int * realSrcIndex = new int[realIndexSize];
    int * realTgtIndex = new int[realIndexSize];
    for (int i = 0; i < indexOffsetNum; i++) {
        int base = i * indexSize * copyNum;
        int baseSrc = i * leadDimSizeSrc;
        int baseTgt = i * leadDimSizeTgt;
        for (int j = 0; j < indexSize; j++) {
            int offset = base + j * copyNum;
            int * rsi = realSrcIndex + offset;
            int * rti = realTgtIndex + offset;
            for (int k = 0; k < copyNum; k++) {
                rsi[k] = baseSrc + srcIndex[j] + k;
                rti[k] = baseTgt + tgtIndex[j] + k;
            }
        }
    }

    _CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);

    delete[] realSrcIndex;
    delete[] realTgtIndex;
}

/*
build a new tensor holding the indexed sub-tensors of s and return it

>> s - the source tensor
>> dim - the dimension along which "sub-tensors" are indexed
         e.g., for a tensor of size (3, 2, 4) and dim = 2,
         there are 4 sub-tensors of size (3, 2)
>> srcIndex - positions of the source sub-tensors along dim
>> indexSize - number of entries in srcIndex (and in tgtIndex)
>> tgtIndex - positions in the result where the sub-tensors are written
>> copyNum - how many consecutive sub-tensors are copied per index entry,
             e.g., srcIndex = [1,4] with copyNum = 2 copies the source
             sub-tensors 1, 2, 4 and 5
<< return - a tensor (flagged as temporary) shaped like s, except that
            dimension dim has size indexSize * copyNum

NOTE(review): srcIndex and tgtIndex are recorded in the result's link by
pointer (AddParamToHeadPointer), not copied. Presumably the caller must keep
both arrays alive for as long as that link may be read; confirm against XLink.
*/
XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
{
    CheckNTErrors(dim >= 0 && dim < s.order, "A too larget dimension specified!");

    const int order = s.order;
    int * shape = new int[order];

    /* same shape as s, except that dim holds exactly the copied sub-tensors */
    for (int d = 0; d < order; d++)
        shape[d] = (d == dim) ? indexSize * copyNum : s.dimSize[d];

    const float denseRatio = s.isSparse ? s.denseRatio : 1.0F;
    XTensor result(order, shape, s.dataType, denseRatio, s.devID, s.mem);
    result.SetTMPFlag();

    /* fill the freshly created tensor */
    _CopyIndexed(&s, &result, dim, srcIndex, indexSize, tgtIndex, copyNum);

    /* record how result was produced; the parameters are pushed in the
       order dim, srcIndex, indexSize, tgtIndex, copyNum */
    XLink::MakeLink(&s, NULL, &result, MOVEMENT_COPYINDEXED);
    XLink::AddParamToHeadInt(&result, dim);
    XLink::AddParamToHeadPointer(&result, srcIndex);
    XLink::AddParamToHeadInt(&result, indexSize);
    XLink::AddParamToHeadPointer(&result, tgtIndex);
    XLink::AddParamToHeadInt(&result, copyNum);

    /* the shape array is no longer needed */
    delete[] shape;

    return result;
}

} // namespace nts(NiuTrans.Tensor)
