/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/

#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h"
#include "MulAndShift.h"
#include "MatrixMul.h"
#include "Sum.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
return the dimension along which the sum can be performed as SumDim (see SumDim.h for more details)
>> a - a tensor
>> b - another tensor for sum
<< return - the dimension index, or -1 if the broadcasting sum does not apply
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
    if (a.order < b.order)
        return -1;
    if (XTensor::IsSameShaped(&a, &b))
        return -1;

    int hitCount = 0;
    int hitDim = -1;
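    /* scan the trailing dimensions of b against those of a;
       dimensions of size 1 in b are broadcast and thus skipped */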
    for (int i = 0; i < b.order; i++) {
        if (b.dimSize[b.order - 1 - i] == 1)
            continue;
        else if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]) {
            hitCount++;
            hitDim = a.order - b.order + i;
        }
    }

    if (hitCount == 1)
        return hitDim;
    else
        return -1;
}
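
/*
a worked example (illustrative only): for a of shape (8, 32, 50) and b of
shape (50), the single non-1 dimension of b matches the last dimension of a,
so GetSumIndex returns 2 and the shift can run as _SumDim along dimension 2;
if a and b share the same shape, -1 is returned and the element-wise _Sum
applies instead.
*/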

/*
operation c = x * w + b  MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> alpha - scalar factor for the matrix product x * w
>> parallelRunner - parallel processing module
<< return - the result of x * w + b
*/
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
                    DTYPE alpha, XPRunner * parallelRunner)
{
    CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
    CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have an order >= 2!");

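    /* dimSizeRDI indexes the shape in reverse order: index 0 is the last
       (column) dimension and index 1 the second-to-last (row) dimension */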
    int xn = x.dimSizeRDI[1];
    int xm = x.dimSizeRDI[0];
    int wn = w.dimSizeRDI[1];
    int wm = w.dimSizeRDI[0];

    CheckNTErrors(xm == wn, "Unmatched tensors in multiplication!");

    int order = x.order + w.order - 2;
    int sub = 0;
    int * dimSize = new int[order];
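    /* the output keeps the batch (leading) dimensions of x and w,
       followed by the (xn, wm) dimensions of the matrix product */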
    for (int i = 2; i < x.order; i++)
        dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
    for (int i = 2; i < w.order; i++)
        dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
    dimSize[sub++] = xn;
    dimSize[sub++] = wm;

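    /* the result is treated as fully dense unless both inputs are sparse */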
    float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);

    XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);

    /* call _MatrixMul function */
    _MatrixMul(&x, X_NOTRANS, &w, X_NOTRANS, tmp, alpha, 0, parallelRunner);

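    /* c is created with the same shape as tmp; the TMP flag marks it
       as a temporary tensor */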
    XTensor c(tmp);
    c.SetTMPFlag();

    int n = GetSumIndex(*tmp, b);

    if (n == -1) {
        /* call _Sum function */
        _Sum(tmp, &b, &c);

        /* TODO: this branch is not supported yet */
        ShowNTErrors("TODO!");

    }
    else if (n >= 0 && n < tmp->order) {
        /* call _SumDim function */
        _SumDim(tmp, &b, &c, n);

    }
    else {
        ShowNTErrors("Something is wrong!");
    }

    /* tensor connections */
    XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
    XLink::AddParamToHeadInt(&c, n);
    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
    //XLink::AddParamToHead(&c, beta);

    /* destroy variables */
    delete[] dimSize;
    DelTensorBuf(tmp);

    return c;
}
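
/*
a minimal usage sketch (illustrative only; it assumes the InitTensor2D/1D
and SetDataRand helpers declared in XTensor.h):

    XTensor x, w, b;
    InitTensor2D(&x, 32, 512, X_FLOAT);     // x: (32, 512)
    InitTensor2D(&w, 512, 1024, X_FLOAT);   // w: (512, 1024)
    InitTensor1D(&b, 1024, X_FLOAT);        // b: (1024)
    x.SetDataRand();
    w.SetDataRand();
    b.SetDataRand();

    XTensor c = MulAndShift(x, w, b, 1.0F, NULL);  // c = x * w + b, shape: (32, 1024)
*/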

} // namespace nts(NiuTrans.Tensor)