/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-15
*/

#include "../XTensor.h"
#include "TMatrixMULBatchedCPU.h"

namespace nts { // namespace nts(NiuTrans.Tensor)
/* case 1: matrix multiplication in batch mode (CPU code). 
* In this case, aList=2*(2, 3), bList=2*(3, 2) -> c=2*(2, 2), 
  transposedA=X_NOTRANS, transposedB=X_NOTRANS.
*
* Two (2, 3) x (3, 2) products are computed with MatrixMULBatchedCPU and
* each result is compared against a precomputed answer. When USE_CUDA is
* defined the same check is repeated on a second set of tensors.
*
* >> return - true if every checked result matches its answer
*/
bool TestMatrixMulBatchedCPU1()
{
    /* create list */
    XList * aList = new XList();
    XList * bList = new XList();
    XList * cList = new XList();

    /* a source tensor of size (2, 3) */
    int aOrder = 2;
    int aDimSize[2] = { 2, 3 };

    int aUnitNum = 1;
    for (int i = 0; i < aOrder; i++)
        aUnitNum *= aDimSize[i];

    /* a source tensor of size (3, 2) */
    int bOrder = 2;
    int bDimSize[2] = { 3, 2 };

    int bUnitNum = 1;
    for (int i = 0; i < bOrder; i++)
        bUnitNum *= bDimSize[i];

    /* a target tensor of size (2, 2) */
    int cOrder = 2;
    int cDimSize[2] = { 2, 2 };

    int cUnitNum = 1;
    for (int i = 0; i < cOrder; i++)
        cUnitNum *= cDimSize[i];

    DTYPE aData1[2][3] = { {1.0, 2.0, 3.0},
                           {-4.0, 5.0, 6.0} };
    DTYPE aData2[2][3] = { {1.0, -2.0, -3.0},
                           {-4.0, 3.0, 2.0} };
    DTYPE bData1[3][2] = { {0.0, -1.0},
                           {1.0, 2.0}, 
                           {2.0, 1.0} };
    DTYPE bData2[3][2] = { {0.0, 1.0},
                           {3.0, 2.0}, 
                           {2.0, 1.0} };
    DTYPE answer1[2][2] = { {8.0, 6.0}, 
                            {17.0, 20.0} };
    DTYPE answer2[2][2] = { {-12.0, -6.0}, 
                            {13.0, 4.0} };

    /* CPU test */
    bool cpuTest = true;

    /* create tensors */
    XTensor * a1 = NewTensor(aOrder, aDimSize);
    XTensor * a2 = NewTensor(aOrder, aDimSize);
    XTensor * b1 = NewTensor(bOrder, bDimSize);
    XTensor * b2 = NewTensor(bOrder, bDimSize);
    XTensor * c1 = NewTensor(cOrder, cDimSize);
    XTensor * c2 = NewTensor(cOrder, cDimSize);

    /* initialize variables (each tensor is filled with its own unit count) */
    a1->SetData(aData1, aUnitNum);
    a2->SetData(aData2, aUnitNum);
    b1->SetData(bData1, bUnitNum);
    b2->SetData(bData2, bUnitNum);
    c1->SetZeroAll();
    c2->SetZeroAll();

    /* add tensors to list */
    aList->Add(a1);
    aList->Add(a2);
    bList->Add(b1);
    bList->Add(b2);
    cList->Add(c1);
    cList->Add(c2);

    /* call MatrixMULBatchedCPU function */
    MatrixMULBatchedCPU(aList, X_NOTRANS, bList, X_NOTRANS, cList);

    /* check results */
    cpuTest = c1->CheckData(answer1, cUnitNum) && cpuTest;
    cpuTest = c2->CheckData(answer2, cUnitNum) && cpuTest;

    bool result = cpuTest;

#ifdef USE_CUDA
    /* GPU test */
    bool gpuTest = true;

    /* clear list */
    aList->Clear();
    bList->Clear();
    cList->Clear();

    /* create tensors
       (presumably the trailing 0 selects GPU device 0 - confirm against NewTensor) */
    XTensor * aGPU1 = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
    XTensor * aGPU2 = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
    XTensor * bGPU1 = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
    XTensor * bGPU2 = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
    XTensor * cGPU1 = NewTensor(cOrder, cDimSize, X_FLOAT, 1.0F, 0);
    XTensor * cGPU2 = NewTensor(cOrder, cDimSize, X_FLOAT, 1.0F, 0);

    /* initialize variables */
    aGPU1->SetData(aData1, aUnitNum);
    aGPU2->SetData(aData2, aUnitNum);
    bGPU1->SetData(bData1, bUnitNum);
    bGPU2->SetData(bData2, bUnitNum);
    cGPU1->SetZeroAll();
    cGPU2->SetZeroAll();

    /* add the GPU tensors (not the CPU ones) to the list, so that this
       section really exercises the tensors created above */
    aList->Add(aGPU1);
    aList->Add(aGPU2);
    bList->Add(bGPU1);
    bList->Add(bGPU2);
    cList->Add(cGPU1);
    cList->Add(cGPU2);

    /* call MatrixMULBatchedCPU function
       NOTE(review): this assumes MatrixMULBatchedCPU accepts tensors created
       with the device argument above - confirm against its implementation */
    MatrixMULBatchedCPU(aList, X_NOTRANS, bList, X_NOTRANS, cList);

    /* check results */
    gpuTest = cGPU1->CheckData(answer1, cUnitNum) && gpuTest;
    gpuTest = cGPU2->CheckData(answer2, cUnitNum) && gpuTest;

    result = cpuTest && gpuTest;

    /* destroy GPU tensors; one delete per pointer, because
       "delete p, q;" parses as "(delete p), q;" and frees only p */
    aList->Clear();
    bList->Clear();
    cList->Clear();
    delete aGPU1;
    delete aGPU2;
    delete bGPU1;
    delete bGPU2;
    delete cGPU1;
    delete cGPU2;
#endif // USE_CUDA

    /* destroy variables: the lists are emptied before they are destroyed so
       that they hold no tensor pointers; the tensors are released separately */
    aList->Clear();
    bList->Clear();
    cList->Clear();
    delete aList;
    delete bList;
    delete cList;

    delete a1;
    delete a2;
    delete b1;
    delete b2;
    delete c1;
    delete c2;

    return result;
}

/* other cases */
/*
    TODO!!
*/

/* entry point of the MatrixMulBatchedCPU test suite:
   runs every enabled case, prints the outcome of each one via XPRINT
   and returns true only if all of them passed */
extern "C"
bool TestMatrixMulBatchedCPU()
{
    XPRINT(0, stdout, "[TEST MATRIXMULBATCHEDCPU] -------------\n");

    /* stays true until some case fails */
    bool allPassed = true;

    /* case 1 */
    const bool case1Passed = TestMatrixMulBatchedCPU1();
    if (case1Passed) {
        XPRINT(0, stdout, ">> case 1 passed!\n");
    }
    else {
        allPassed = false;
        XPRINT(0, stdout, ">> case 1 failed!\n");
    }

    /* case 2 is currently disabled:
       it would call TestMatrixMulBatchedCPU2() and print
       ">> case 2 passed!" / ">> case 2 failed!" in the same way */

    /* other cases test */
    /*
    TODO!!
    */

    /* overall summary */
    if (!allPassed) {
        XPRINT(0, stdout, ">> Failed!\n");
    }
    else {
        XPRINT(0, stdout, ">> All Passed!\n");
    }

    XPRINT(0, stdout, "\n");

    return allPassed;
}

} // namespace nts(NiuTrans.Tensor)
