/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-6-24
 */

#include "TXMem.h"

#include <cstdio>
#include <cstdlib>
#include <vector>

#include "../XGlobal.h"
#include "../XUtility.h"
#include "../XMem.h"

/* the nts (NiuTrans.Tensor) namespace */
namespace nts{

/*
case 1: random allocate/release stress test of the memory pool (XMem).

A fixed set of caseNum slots is toggled at random between "allocated" and
"released". When slot j is allocated, its block is filled with the value j.
Before it is released, the block is filled with -1. After every step, each
live slot is checked to still hold its own index. A mismatch means the block
was overwritten, e.g. because two live allocations overlap.

The whole experiment runs three rounds with scalar = 1, 2, 4. scalar
multiplies the number of steps, and scalar * scalar scales the block-size
argument passed to XMem::Initialize.

>> return - true if no corruption (or failed allocation) was detected
*/
bool TestXMemCase1()
{
    const int caseNum = 1000;
    const int blockSize = 16;
    const int testNum = caseNum * 10;

    for (int round = 0, scalar = 1; round < 3; round++, scalar *= 2) {
        XMem mem;
        mem.Initialize(-1, FREE_ON_THE_FLY, blockSize * sizeof(int) * scalar * scalar, 1000, 0);
        mem.SetIndex(10000, blockSize * sizeof(int) / 2);

        /* fixed seed so every run replays the same allocation sequence */
        srand(907);

        /* p[i] is the live block of slot i (NULL when released); size[i] is
           its length in ints, in [0, 2 * blockSize) */
        std::vector<int*> p(caseNum);
        std::vector<int> size(caseNum);
        for (int i = 0; i < caseNum; i++)
            size[i] = rand() % (2 * blockSize);

        for (int step = 0; step < testNum * scalar; step++) {
            int j = rand() % caseNum;

            if (p[j] == NULL) {
                /* allocate slot j and stamp it with its own index */
                p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));

                /* a NULL result for a non-empty request cannot be written to;
                   report it as a failure instead of crashing */
                if (p[j] == NULL && size[j] > 0)
                    return false;

                for (int k = 0; k < size[j]; k++)
                    p[j][k] = j;
            }
            else {
                /* poison the block BEFORE handing it back: after release the
                   memory belongs to the pool and must not be written. If it
                   aliases another live block, the poison will be caught by the
                   check below. */
                for (int k = 0; k < size[j]; k++)
                    p[j][k] = -1;
                mem.ReleaseStandard(mem.devID, p[j]);
                p[j] = NULL;
            }

            /* every live slot must still hold its own index */
            for (int k = 0; k < caseNum; k++) {
                if (p[k] == NULL)
                    continue;
                for (int o = 0; o < size[k]; o++) {
                    if (p[k][o] != k)
                        return false;
                }
            }
        }
    }

    return true;
}

/*
test for the memory pool (XMem): runs every case, prints per-case and overall
results, and reports the elapsed time measured with GetClock().
>> return - true if all cases passed
*/
bool TestXMem()
{
    XPRINT(0, stdout, "[Test] Memory pool ... Began\n");
    bool returnFlag = true;
    bool caseFlag = true;
    double startT = GetClock();

    /* case 1 test */
    caseFlag = TestXMemCase1();
    if (!caseFlag) {
        returnFlag = false;
        XPRINT(0, stdout, ">> case 1 failed!\n");
    }
    else {
        XPRINT(0, stdout, ">> case 1 passed!\n");
    }

    if (returnFlag) {
        XPRINT(0, stdout, ">> All Passed!\n");
    }
    else {
        XPRINT(0, stdout, ">> Failed!\n");
    }

    double endT = GetClock();
    XPRINT1(0, stdout, "[Test] Finished (took %.3lfms)\n\n", endT - startT);

    return returnFlag;
}

} /* end of the nts (NiuTrans.Tensor) namespace */