/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XTensor.h"
#include "../../XUtility.h"
#include "../../XName.h"
#include "ConcatenateSolely.h"
#include "MergeBlockLists.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
concatenate a list of tensors along a given dimension

Every tensor in "smalls" must have the same order as "big", the same size as
"big" on every dimension other than "dim", and the same unit size as "big".
The sizes of the small tensors on "dim" must sum up to the size of "big" on
"dim". Any violation is reported through CheckNTErrors.

>> smalls - a list of tensors for concatenation
>> big - the resulting tensor
>> dim - which dimension we perform the concatenation
*/
void _ConcatenateSolely(const TensorList * smalls, XTensor * big, int dim)
{
    CheckNTErrors(big->order > dim && dim >= 0, "Illegal dimension to concatenate!");

    /* total size of the small tensors along the concatenation dimension */
    int catDimSize = 0;

    for (int i = 0; i < smalls->count; i++) {
        XTensor * tensor = (XTensor*)smalls->GetItem(i);
        CheckNTErrors((big->order == tensor->order), "Unmatched tensor orders!");

        /*
        the byte sizes computed below use tensor->unitSize for the source and
        big->unitSize for the target, so differing unit sizes would make the
        copies land at wrong offsets in "big"
        */
        CheckNTErrors((big->unitSize == tensor->unitSize), "Unmatched tensor unit sizes!");

        for (int j = 0; j < big->order; j++) {
            if (j != dim) {
                CheckNTErrors((big->dimSize[j] == tensor->dimSize[j]), "Unmatched tensor sizes!");
            }
            else {
                catDimSize += tensor->dimSize[j];
            }
        }
    }

    CheckNTErrors((catDimSize == big->dimSize[dim]), "Unmatched tensor sizes!");

    /* stride: product of the dimension sizes after "dim" */
    int stride = 1;

    /* blockNum: product of the dimension sizes before "dim" */
    int blockNum = 1;

    for (int i = 0; i < dim; i++)
        blockNum *= big->dimSize[i];

    for (int i = dim + 1; i < big->order; i++)
        stride *= big->dimSize[i];

    /* 
    two strategies are used - we can either resort to memcpy2d for the case of
    concatenation of a few items, or use MergeBlockLists to merge a large number
    of data blocks 
    */
    if (smalls->count <= MIN_TENSOR_CAT_NUM) {
        /* byte offset in "big" at which the next small tensor is placed */
        int offset = 0;

        /* bytes covered by one block of "big" (the same for every small tensor) */
        int tPitch = stride * big->dimSize[dim] * big->unitSize;

        for (int i = 0; i < smalls->count; i++) {
            XTensor * tensor = (XTensor*)smalls->GetItem(i);

            /* bytes covered by one block of the current small tensor */
            int sPitch = stride * tensor->dimSize[dim] * tensor->unitSize;

            /* copy blockNum chunks of sPitch bytes each into "big" */
            XMemCopy2D((char*)big->data + offset, tPitch, big->devID,
                (char*)tensor->data, sPitch, tensor->devID,
                sPitch, blockNum);

            offset += sPitch;
        }
    }
    else {
        StrList* sourceArrays = new StrList(smalls->count);
        int * blockSizes = new int[smalls->count];

        /* collect the data pointer and the per-block byte size of each small tensor */
        for (int i = 0; i < smalls->count; i++) {
            XTensor * tensor = (XTensor*)smalls->GetItem(i);
            blockSizes[i] = stride * tensor->dimSize[dim] * tensor->unitSize;
            sourceArrays->Add((char*)tensor->data);
        }

        _MergeBlockLists(sourceArrays, blockSizes, blockNum, big->data, big->mem);

        delete[] blockSizes;
        delete sourceArrays;
    }
}
} // namespace nts(NiuTrans.Tensor)