/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-06-14
*/

#include "TRectify.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/* 
case 1: test rectify function
In this case, y = max(0, x) 
*/
bool TestRectify1()
{
    /* a tensor of size (2, 3) */
    int order = 2;
    int * dimSize = new int[order];
    dimSize[0] = 2;
    dimSize[1] = 3;

    int unitNum = 1;
    for (int i = 0; i < order; i++)
        unitNum *= dimSize[i];

    DTYPE xData[2][3] = { {0.0F, -1.0F, 2.0F},
                          {3.0F, -4.0F, -5.0F} };
    DTYPE answer[2][3] = { {0.0F, 0.0F, 2.0F},
                           {3.0F, 0.0F, 0.0F} };

    /* CPU test */
    bool cpuTest = true;

    /* create tensors */
    XTensor * x = NewTensor(order, dimSize);
    XTensor * y = NewTensor(order, dimSize);

    /* initialize variables */
    x->SetData(xData, unitNum);
    y->SetZeroAll();

    /* call Rectify function */
    Rectify(x, y);

    /* check results */
    cpuTest = y->CheckData(answer, unitNum);

#ifdef USE_CUDA
	/* GPU test */
	bool gpuTest = true;

	/* create tensor */
	XTensor * xGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
	XTensor * yGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);

	/* Initialize variables */
	xGPU->SetData(xData, unitNum);
	yGPU->SetZeroAll();

	/* call Rectify function */
	Rectify(xGPU, yGPU);

	/* check results */
	gpuTest = yGPU->CheckData(answer, unitNum);

	/* destroy variables */
	delete x;
    delete y;
    delete xGPU;
    delete yGPU;
	delete[] dimSize;

	return cpuTest && gpuTest;
#else
	/* destroy variables */
	delete x;
    delete y;
	delete[] dimSize;

	return cpuTest;
#endif // USE_CUDA
}

/* 
case 2: backward computation 
dE/dx = dE/dy * dy/dx 
rectified: y = max(0, x) 
In this case, lossName=CROSSENTROPY.
*/
bool TestRectify2()
{
	/* a tensor of size (2, 3) */
	int order = 2;
	int * dimSize = new int[order];
	dimSize[0] = 2;
	dimSize[1] = 3;

	int unitNum = 1;
	for (int i = 0; i < order; i++)
		unitNum *= dimSize[i];

	DTYPE xData[2][3] = { {1.0F, 1.0F, 2.0F},
	                      {2.0F, 4.0F, 5.0F} };
	DTYPE goldData[2][3] = { {1.0F, 1.0F, 1.0F},
	                         {1.0F, 1.0F, 1.0F} };
    DTYPE yAnswer[2][3] = { {1.0F, 1.0F, 2.0F},
	                        {2.0F, 4.0F, 5.0F} };
	DTYPE dedyAnswer[2][3] = { {-1.0F, -1.0F, -0.5F},
	                           {-0.5F, -0.25F, -0.2F} };
	DTYPE dedxAnswer[2][3] = { {-1.0F, -1.0F, -0.5F},
	                           {-0.5F, -0.25F, -0.2F} };

	/* CPU test */
	bool cpuTest = true;

	/* create tensors */
	XTensor * x = NewTensor(order, dimSize);
	XTensor * y = NewTensor(order, dimSize);
	XTensor * gold = NewTensor(order, dimSize);
	XTensor * dedy = NewTensor(order, dimSize);
	XTensor * dedx = NewTensor(order, dimSize);

	/* initialize variables */
	x->SetData(xData, unitNum);
	gold->SetData(goldData, unitNum);
	y->SetZeroAll();
	dedy->SetZeroAll();
	dedx->SetZeroAll();

    /* call Rectify function */
    Rectify(x, y);

	/* call RectifyBackward function */
	RectifyBackward(gold, y, x, dedy, dedx, CROSSENTROPY);

	/* check results */
    cpuTest = y->CheckData(yAnswer, unitNum, 1e-4F)
              && dedx->CheckData(dedxAnswer, unitNum, 1e-4F)
              && dedy->CheckData(dedyAnswer, unitNum, 1e-4F);

#ifdef USE_CUDA
	/* GPU test */
	bool gpuTest = true;

	/* create tensors */
	XTensor * xGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
	XTensor * yGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
	XTensor * goldGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
	XTensor * dedyGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
	XTensor * dedxGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);

	/* initialize variables */
	xGPU->SetData(xData, unitNum);
	goldGPU->SetData(goldData, unitNum);
	yGPU->SetZeroAll();
	dedyGPU->SetZeroAll();
	dedxGPU->SetZeroAll();
    
    /* call Rectify function */
    Rectify(xGPU, yGPU);

	/* call rectifybackward function */
	RectifyBackward(goldGPU, yGPU, xGPU, dedyGPU, dedxGPU, CROSSENTROPY);
    
	/* check results */
    gpuTest = yGPU->CheckData(yAnswer, unitNum, 1e-4F)
              && dedxGPU->CheckData(dedxAnswer, unitNum, 1e-4F)
              && dedyGPU->CheckData(dedyAnswer, unitNum, 1e-4F);

	/* destroy variables */
    delete x;
    delete y;
    delete dedy;
    delete dedx;
    delete gold;
    delete xGPU;
    delete yGPU;
    delete dedyGPU;
    delete dedxGPU;
    delete goldGPU;
	delete[] dimSize;

	return cpuTest && gpuTest;
#else
	/* destroy variables */
    delete x;
    delete y;
    delete dedy;
    delete dedx;
    delete gold;
	delete[] dimSize;

	return cpuTest;
#endif // USE_CUDA
}

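/* other cases */
/*
TODO: add further cases, e.g. a backward test whose inputs include
negative values so that the zero-gradient branch of the rectifier is checked.
*/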

/*
test driver for the rectify function:
runs every case in order, prints a pass/fail line per case and an overall
summary, and returns true only if every case passed.
*/
bool TestRectify()
{
    XPRINT(0, stdout, "[TEST RECTIFY] rectify function and its backward computation \n");

    bool allPassed = true;

    /* case 1: forward computation */
    if (TestRectify1()) {
        XPRINT(0, stdout, ">> case 1 passed!\n");
    }
    else {
        allPassed = false;
        XPRINT(0, stdout, ">> case 1 failed!\n");
    }

    /* case 2: backward computation */
    if (TestRectify2()) {
        XPRINT(0, stdout, ">> case 2 passed!\n");
    }
    else {
        allPassed = false;
        XPRINT(0, stdout, ">> case 2 failed!\n");
    }

    /* TODO: run further cases here once they exist */

    /* overall summary; XPRINT is always used as a braced statement */
    if (!allPassed) {
        XPRINT(0, stdout, ">> Failed!\n");
    }
    else {
        XPRINT(0, stdout, ">> All Passed!\n");
    }

    XPRINT(0, stdout, "\n");

    return allPassed;
}

} // namespace nts(NiuTrans.Tensor)