/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
 */

#include "Gather.h"
#include "CopyIndexed.h"
#include "../../XUtility.h"
#include "../shape/Reshape.h"

namespace nts{ // namespace nts(NiuTrans.Tensor)

/*
gather indexed sub-tensors of "s" into "t"

The k-th sub-tensor selected by srcIndex is written to the k-th
sub-tensor slot of "t" (the target positions are 0, 1, ..., indexSize - 1).

>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
         e.g., for a tensor of size (3, 2, 4) and dim = 2, 
         we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - number of entries in srcIndex
*/
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
{
    /* identity mapping on the target side: slot k receives the k-th selection */
    int * positions = new int[indexSize];
    int k = 0;
    while(k < indexSize) {
        positions[k] = k;
        ++k;
    }

    /* copy one sub-tensor per index entry */
    _CopyIndexed(s, t, dim, srcIndex, indexSize, positions, 1);

    delete[] positions;
}

/*
gather indexed sub-tensors (return an XTensor structure)
make a new tensor to keep the result and return it

The k-th sub-tensor selected by srcIndex becomes the k-th sub-tensor
of the result (target positions are 0, 1, ..., indexSize - 1).

>> s - the source tensor
>> dim - the leading dimension to define "sub-tensors"
         e.g., for a tensor of size (3, 2, 4) and dim = 2, 
         we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - number of entries in srcIndex
<< return - the result of copying indexed sub-tensors

Notice: the index must be on the CPU!!!
*/
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize)
{
    /* identity mapping on the target side: slot pos receives the pos-th selection */
    int * order = new int[indexSize];
    for(int pos = 0; pos < indexSize; ++pos)
        order[pos] = pos;

    /* declared first and then assigned, exactly as before, so XTensor's
       assignment semantics are preserved */
    XTensor gathered;
    gathered = CopyIndexed(s, dim, srcIndex, indexSize, order, 1);

    delete[] order;
    return gathered;
}

/*
copy "num" index values of element type T from device "devID" into a
host buffer and convert each of them to int (the (int) cast truncates
any fractional part)

>> dst - host buffer with room for at least "num" ints
>> src - the raw index data
>> devID - device holding "src" (passed straight to XMemCopy)
>> num - number of index values to convert
*/
template<class T>
static void CopyIndexToHostAsInt(int * dst, void * src, int devID, int num)
{
    /* stage the values with their real element type so the byte count
       matches the buffer size exactly */
    T * buf = new T[num];
    XMemCopy(buf, -1, src, devID, num * sizeof(T));
    for(int i = 0; i < num; i++)
        dst[i] = (int)buf[i];
    delete[] buf;
}

/*
gather indexed sub-tensors (return an XTensor structure)
make a new tensor to keep the result and return it

Selects sub-tensors of "s" along dimension 0 (i.e., rows of the 2D input)
using the values stored in "index". If "index" has more than one dimension,
the result is reshaped to index's dimensions followed by the last
dimension of the gathered tensor.

>> s - the source tensor (2D)
>> index - the index tensor; supported data types are X_INT, X_FLOAT and
           X_DOUBLE (floating-point indices are truncated to int)
<< return - the result of copying indexed sub-tensors
*/
XTensor Gather(const XTensor &s, const XTensor &index)
{
    CheckNTErrors(s.order == 2, "The order of the input tensor must be 2!");

    int indexSize = index.unitNum;
    int * srcIndex = new int[indexSize];

    if(index.dataType == X_INT) {
        XMemCopy(srcIndex, -1, index.data, index.devID, indexSize * sizeof(int));
    }
    else if(index.dataType == X_FLOAT) {
        /* read as float regardless of how DTYPE is configured */
        CopyIndexToHostAsInt<float>(srcIndex, index.data, index.devID, indexSize);
    }
    else if(index.dataType == X_DOUBLE) {
        /* read as double: a float-sized staging buffer would overflow here */
        CopyIndexToHostAsInt<double>(srcIndex, index.data, index.devID, indexSize);
    }
    else {
        delete[] srcIndex;
        ShowNTErrors("Unsupported data type!");
    }

    XTensor tensor;
    tensor = Gather(s, 0, srcIndex, indexSize);
    delete[] srcIndex;

    if(index.order > 1) {
        /* result shape = index.dimSize followed by the gathered row length */
        int * dims = new int[index.order + 1];
        memcpy(dims, index.dimSize, index.order * sizeof(int));
        dims[index.order] = tensor.GetDim(-1);

        XTensor t;
        t = Reshape(tensor, index.order + 1, dims);
        delete[] dims;

        return t;
    }
    else {
        return tensor;
    }
}

} // namespace nts(NiuTrans.Tensor)