/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../XTensor.h"
#include "../XDevice.h"
#include "MatrixMulBatched.h"
#include "MatrixMULBatchedCPU.h"
#include "XTensorBLAS.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
matrix multiplication of the two tensors
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = trans(ai) * trans(bi) * alpha + cm * beta
where trans() returns the transposed matrix if the flag is fired
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> c - where we keep a*b
>> alpha - a coefficient
>> beta - another coefficient
*/
void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
    XTensor * b, MATRIX_TRANS_TYPE transposedB,
    XTensor * c, DTYPE alpha, DTYPE beta,
    XPRunner * parallelRunner)
{
    CheckNTErrors((a && b && c), "Empty input tensors!");
    CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
        "Input tensors should have the same data type!");
    CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
        "Input tensors must have a order > 2!");

    int an = transposedA == X_TRANS ? a->dimSize[1] : a->dimSize[0];
    int am = transposedA == X_TRANS ? a->dimSize[0] : a->dimSize[1];
    int bn = transposedB == X_TRANS ? b->dimSize[1] : b->dimSize[0];
    int bm = transposedB == X_TRANS ? b->dimSize[0] : b->dimSize[1];
    int cn = c->dimSize[0];
    int cm = c->dimSize[1];

    CheckNTErrors((am == bn && an == cn && bm == cm),
        "Unmatched tensors in multiplication!");

    int aBlockSize = a->dimSizeRDI[0] * a->dimSizeRDI[1];
    int bBlockSize = b->dimSizeRDI[0] * b->dimSizeRDI[1];
    int cBlockSize = c->dimSizeRDI[0] * c->dimSizeRDI[1];
    int aRealBlockSize = aBlockSize * a->unitSize;
    int bRealBlockSize = bBlockSize * b->unitSize;
    int cRealBlockSize = cBlockSize * c->unitSize;
    int blockNum = 1;

    for (int i = 2; i < a->order; i++) {
        CheckNTErrors((a->dimSizeRDI[i] == c->dimSizeRDI[i]), "Incorrect tensor sizes!");
        CheckNTErrors((b->dimSizeRDI[i] == c->dimSizeRDI[i]), "Incorrect tensor sizes!");
        blockNum *= a->dimSizeRDI[i];
    }

    XList * aList = new XList(10);
    XList * bList = new XList(10);
    XList * cList = new XList(10);
    int aDimSize[2] = { -a->dimSizeRDI[0], a->dimSizeRDI[1] };
    int bDimSize[2] = { -b->dimSizeRDI[0], b->dimSizeRDI[1] };
    int cDimSize[2] = { -c->dimSizeRDI[0], c->dimSizeRDI[1] };

    for (int p = 0; p < blockNum; p++) {
        void * ap = (char*)a->data + aRealBlockSize * p;
        void * bp = (char*)b->data + bRealBlockSize * p;
        void * cp = (char*)c->data + cRealBlockSize * p;
        XTensor * ai = new XTensor(2, aDimSize, a->dataType, a->denseRatio, a->mem);
        XTensor * bi = new XTensor(2, bDimSize, b->dataType, b->denseRatio, b->mem);
        XTensor * ci = new XTensor(2, cDimSize, c->dataType, c->denseRatio, c->mem);
        ai->data = ap;
        bi->data = bp;
        ci->data = cp;
        aList->Add(ai);
        bList->Add(bi);
        cList->Add(ci);
    }

    if (a->devID >= 0 && b->devID >= 0 && c->devID >= 0) {
#ifdef USE_CUDA
        CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
                      "The code must be run on the same GPU!");
        
        int devIDBackup;
        ProtectCudaDev(a->devID, devIDBackup);

        CudaBLASMatrixMULList(a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID),
                              aList, transposedA,
                              bList, transposedB,
                              cList, aList->count,
                              alpha, beta);

        BacktoCudaDev(a->devID, devIDBackup);
#else
        ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
    }
    else {
        CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
        MatrixMULBatchedCPU(aList, transposedA,
            bList, transposedB,
            cList, alpha, beta);
    }

    for (int i = 0; i < aList->count; i++) {
        XTensor * ai = (XTensor*)aList->GetItem(i);
        ai->data = NULL;
        delete ai;
    }

    for (int i = 0; i < bList->count; i++) {
        XTensor * bi = (XTensor*)bList->GetItem(i);
        bi->data = NULL;
        delete bi;
    }

    for (int i = 0; i < cList->count; i++) {
        XTensor * ci = (XTensor*)cList->GetItem(i);
        ci->data = NULL;
        delete ci;
    }

    delete aList;
    delete bList;
    delete cList;
}

} // namespace nts(NiuTrans.Tensor)
