/* TCompare.cpp */
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
 */

#include "../XTensor.h"
#include "../core/math/Compare.h"
#include "TCompare.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
case 1: test Equal function.
Compare whether every entry is equal to the specified value.
All three entry points are checked against the same expected mask:
_Equal (writes into a separate output tensor), _EqualMe (in place) and
Equal (returns a new tensor). When USE_CUDA is defined the same checks
are repeated on device 0.
>> return - true if every checked result matches the expected mask
*/
bool TestCompare1()
{
    /* a tensor of size (3, 2) */
    int aOrder = 2;
    int * aDimSize = new int[aOrder];
    aDimSize[0] = 3;
    aDimSize[1] = 2;

    int aUnitNum = 1;
    for (int i = 0; i < aOrder; i++)
        aUnitNum *= aDimSize[i];

    DTYPE aData[3][2] = { {1.0F, -2.0F},
                          {0.0F, 4.0F},
                          {5.0F, 1.0F} };
    /* expected mask: 1 where the entry of aData equals 1.0, 0 elsewhere */
    DTYPE answer[3][2] = { {1.0F, 0.0F},
                           {0.0F, 0.0F},
                           {0.0F, 1.0F} };

    /* CPU test */
    bool cpuTest = true;

    /* create tensors */
    XTensor * a = NewTensor(aOrder, aDimSize);
    XTensor * b = NewTensor(aOrder, aDimSize);
    XTensor * aMe = NewTensor(aOrder, aDimSize);
    XTensor bUser;

    /* initialize variables */
    a->SetData(aData, aUnitNum);
    aMe->SetData(aData, aUnitNum);

    /* call Equal function */
    _Equal(a, b, 1.0);
    _EqualMe(aMe, 1.0);
    bUser = Equal(*a, 1.0);

    /* check results */
    cpuTest = b->CheckData(answer, aUnitNum, 1e-4F) &&
              aMe->CheckData(answer, aUnitNum, 1e-4F) &&
              bUser.CheckData(answer, aUnitNum, 1e-4F);

    /* stays true when the GPU part is compiled out */
    bool gpuTest = true;

#ifdef USE_CUDA
    /* GPU test */

    /* create tensors on device 0 */
    XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
    XTensor * bGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
    XTensor * aMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
    XTensor bUserGPU;

    /* Initialize variables */
    aGPU->SetData(aData, aUnitNum);
    aMeGPU->SetData(aData, aUnitNum);

    /* call Equal function */
    _Equal(aGPU, bGPU, 1.0);
    _EqualMe(aMeGPU, 1.0);
    bUserGPU = Equal(*aGPU, 1.0);

    /* check results */
    gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F) &&
              aMeGPU->CheckData(answer, aUnitNum, 1e-4F) &&
              bUserGPU.CheckData(answer, aUnitNum, 1e-4F);

    /* destroy GPU variables */
    delete aGPU;
    delete bGPU;
    delete aMeGPU;
#endif // USE_CUDA

    /* destroy variables (shared by both build configurations) */
    delete a;
    delete b;
    delete aMe;
    delete[] aDimSize;

    return cpuTest && gpuTest;
}

/* other cases */
/*
TODO!!
*/

/* test for Compare Function */
bool TestCompare()
{
127 128
    XPRINT(0, stdout, "[TEST Compare] compare every entry with specified value \n");
    bool returnFlag = true, caseFlag = true;
linye committed
129

130 131
    /* case 1 test */
    caseFlag = TestCompare1();
linye committed
132

133 134 135 136 137 138
    if (!caseFlag) {
        returnFlag = false;
        XPRINT(0, stdout, ">> case 1 failed!\n");
    }
    else
        XPRINT(0, stdout, ">> case 1 passed!\n");
linye committed
139

140 141 142 143
    /* other cases test */
    /*
    TODO!!
    */
linye committed
144

145 146 147 148 149
    if (returnFlag) {
        XPRINT(0, stdout, ">> All Passed!\n");
    }
    else
        XPRINT(0, stdout, ">> Failed!\n");
linye committed
150

151
    XPRINT(0, stdout, "\n");
linye committed
152

153
    return returnFlag;
linye committed
154 155 156
}

} // namespace nts(NiuTrans.Tensor)