Main.cpp 2.79 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 *
 * This is the entry point of the low-level tensor library: NiuTrans.Tensor
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2015-12-14
 *
 */

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include "XTensor.h"
#include "XDevice.h"
#include "./test/Test.h"
#include "./core/CHeader.h"
33 34 35 36 37 38 39 40

//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>  
//#include <crtdbg.h> 

using namespace nts;

/* forward declarations of the local demo routines defined below main() */

/* small 2x2 tensor arithmetic demo; called unconditionally from main() */
void SmallTest();

/* 4D transpose demo; defined in this file but not called from main() */
void TransposeTest();
42 43 44

int main( int argc, const char ** argv )
{
xiaotong committed
45
    //_CrtSetBreakAlloc(123);
46 47

    /* a tiny test */
48
    SmallTest();
xiaotong committed
49 50

    //_CrtDumpMemoryLeaks();
51
    //return 0;
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69

    if(argc > 1 && !strcmp(argv[1], "-test"))
        Test();
    else{
        fprintf(stderr, "Thanks for using NiuTrans.Tensor! This is a library that eases the\n");
        fprintf(stderr, "use of tensors. All you need is to ... \n\n");
        fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
    }

    //_CrtDumpMemoryLeaks();

    return 0;
}

void SmallTest()
{
    XTensor a;
    XTensor b;
70 71
    XTensor c;
    XTensor d;
72 73

    InitTensor2D(&a, 2, 2);
74
    InitTensor2D(&b, 2, 2);
75
    a.SetZeroAll();
76
    b.SetZeroAll();
77 78 79
    a.Set2D(1.0F, 0, 0);
    a.Set2D(2.0F, 1, 1);

80
    b = Sum(a, Multiply(a, a));
xiaotong committed
81

82
    /* this is prohibited !!!!!!!!!!!!! */
83 84 85 86 87
    //XTensor c = a * b + a;
    //XTensor d = a + b + c.Lin(0.5F);
    
    c = a * b + a;
    d = a + b + c.Lin(0.5F);
xiaotong committed
88 89

    XLink::CheckNetwork(&d);
90
    XLink::ShowNetwork(stderr, &d);
91
        
92 93 94 95
    a.Dump(stderr, "a:");
    b.Dump(stderr, "b:");
    c.Dump(stderr, "c:");
    d.Dump(stderr, "d:");
96
}
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

/*
 * A small demo of _Transpose on a 4D tensor.
 *
 * Builds a tensor "a" with dimensions (2, 3, 4, 5), fills it with
 * 0, 1, 2, ... and transposes dimensions I and J into "b", whose
 * dimension sizes are those of "a" with entries I and J swapped.
 * The result is dumped to stderr.
 */
void TransposeTest()
{
    XTensor a;
    XTensor b;

    /* the two dimensions to be swapped */
    const int I = 2;
    const int J = 3;

    InitTensor4D(&a, 2, 3, 4, 5);

    /* dimension sizes of b: a copy of a's sizes with entries I and J exchanged */
    int * dims = new int[a.order];
    memcpy(dims, a.dimSize, sizeof(int) * a.order);
    dims[I] = a.dimSize[J];
    dims[J] = a.dimSize[I];

    /* use a.order (the length of dims) rather than a hard-coded 4 so the
       order passed always matches the array that was allocated */
    InitTensor(&b, a.order, dims);

    a.SetZeroAll();
    b.SetZeroAll();

    /* fill a with the sequence 0, 1, 2, ..., unitNum - 1 */
    float * data = new float[a.unitNum];
    for(int i = 0; i < a.unitNum; i++)
        data[i] = (float)i;

    a.SetData(data, a.unitNum, 0);

    _Transpose(&a, &b, I, J);
    b.Dump(stderr, "b:");

    /* release both scratch buffers; dims was previously leaked.
       NOTE(review): freed at the very end in case InitTensor keeps
       referring to the caller's array - confirm against InitTensor */
    delete[] data;
    delete[] dims;
}