/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northestern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-6-24
 */

#include "TXMem.h"
#include "../XGlobal.h"
#include "../XUtility.h"
#include "../XMem.h"

/* the nts (NiuTrans.Tensor) namespace */
namespace nts{

/*
case 1: randomized stress test of XMem::AllocStandard / XMem::ReleaseStandard.
The test runs three rounds with pool scale factors 1, 2 and 4. In every step a
random slot j is picked:
  - an empty slot gets a freshly allocated block of size[j] ints, filled with j;
  - an occupied slot is filled with -1 and then released.
After every step, every live block is checked to still contain only its own slot
id, so two live blocks that overlap (or a block handed out twice) are detected.
>> return - true if no live block was ever found with unexpected contents
*/
bool TestXMemCase1()
{
    bool ok = true;
    int caseNum = 1000;           /* number of slots (potentially live blocks) */
    int blockSize = 16;           /* base block size, measured in ints */
    int testNum = caseNum * 10;   /* alloc/release steps per round before scaling */

    for(int round = 0, scalar = 1; round < 3; round++, scalar *= 2){
        XMem mem;
        /* NOTE(review): assumes the arguments are (devID, mode, block size in bytes,
           block number, buffer size) -- confirm against XMem.h */
        mem.Initialize(-1, FREE_ON_THE_FLY, blockSize * sizeof(int) * scalar * scalar, 1000, 0);
        mem.SetIndex(10000, blockSize * sizeof(int) / 2);

        /* fixed seed so that every round (and every run) is reproducible */
        srand(907);

        int ** p = new int*[caseNum];
        int * size = new int[caseNum];

        for(int i = 0; i < caseNum; i++){
            p[i] = NULL;
            size[i] = rand() % (2 * blockSize);
        }

        for(int step = 0; step < testNum * scalar; step++){
            int j = rand() % caseNum;

            if(p[j] == NULL){
                p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));
                for(int k = 0; k < size[j]; k++)
                    p[j][k] = j;
            }
            else{
                /* poison the block BEFORE handing it back: writing into it after
                   ReleaseStandard() would touch memory the pool owns again. If any
                   live block overlaps this one, the poison makes the check below fail. */
                for(int k = 0; k < size[j]; k++)
                    p[j][k] = -1;
                mem.ReleaseStandard(mem.devID, p[j]);
                p[j] = NULL;
            }

            /* every live block must still hold nothing but its own slot id */
            for(int k = 0; k < caseNum; k++){
                if(p[k] == NULL)
                    continue;
                for(int o = 0; o < size[k]; o++){
                    if(p[k][o] != k)
                        ok = false;
                }
            }
        }

        /* give the blocks that are still live back to the pool before it goes away */
        for(int k = 0; k < caseNum; k++){
            if(p[k] != NULL){
                mem.ReleaseStandard(mem.devID, p[k]);
                p[k] = NULL;
            }
        }

        delete[] p;
        delete[] size;
    }

    return ok;
}

/*
entry point of the memory-pool tests: runs every case, prints a per-case
verdict and an overall verdict, and reports the elapsed time.
>> return - true only if all cases passed
*/
bool TestXMem()
{
    XPRINT(0, stdout, "[Test] Memory pool ... Began\n");

    bool allPassed = true;
    double beginTime = GetClock();

    /* case 1: randomized alloc/release stress test */
    bool case1Passed = TestXMemCase1();
    if (case1Passed) {
        XPRINT(0, stdout, ">> case 1 passed!\n");
    }
    else {
        allPassed = false;
        XPRINT(0, stdout, ">> case 1 failed!\n");
    }

    /* overall verdict */
    if (!allPassed) {
        XPRINT(0, stdout, ">> Failed!\n");
    }
    else {
        XPRINT(0, stdout, ">> All Passed!\n");
    }

    double finishTime = GetClock();
    double elapsed = finishTime - beginTime;
    XPRINT1(0, stdout, "[Test] Finished (took %.3lfms)\n\n", elapsed);

    return allPassed;
}

} /* end of the nts (NiuTrans.Tensor) namespace */