/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XTensor.h"
#include "../../XName.h"
#include "Unsqueeze.h"
#include "MergeBlockLists.h"
#include "Unsqueeze.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
insert a dimension by copying the blocks x times
(where x is the size of the inserted dimension)

>> a - input tensor
>> b - output tensor
>> dim - where to insert the dimension
>> dSize - size of the newly-inserted dimension
*/
void _Unsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
{
    CheckNTErrors((a && b), "Empty input tensors!");
    CheckNTErrors((a->order == b->order - 1), "Unmatched tensors!");
    CheckNTErrors((a->unitSize == b->unitSize), "Unmatched tensors!");

    for (int i = 0; i < b->order; i++) {
        if (i < dim) {
            CheckNTErrors((a->dimSize[i] == b->dimSize[i]), "Unmatched tensors!");
        }
        else if (i > dim) {
            CheckNTErrors((a->dimSize[i - 1] == b->dimSize[i]), "Unmatched tensors!");
        }
        else {
            CheckNTErrors((dSize == b->dimSize[i]), "Unmatched tensors!");
        }
    }

    /* a block is the chunk of elements covered by the dimensions at and
       after the insertion position; each block of a is copied dSize times
       into b */
    int blockSize = 1;
    int realBlockSize = 1;

    int blockNumA = 1;
    int blockNumB = 1;

    for (int i = dim; i < a->order; i++)
        blockSize *= a->dimSize[i];

    /* block size in bytes */
    realBlockSize = blockSize * a->unitSize;

    blockNumA = a->unitNum / blockSize;
    blockNumB = b->unitNum / blockSize;

    CheckNTErrors((blockNumA * dSize == blockNumB), "Unmatched tensors!");

    if (a->devID >= 0 || b->devID >= 0) {
#ifdef USE_CUDA
        _CudaUnsqueeze(a, b, dim, dSize);
#else
        ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
    }
    else {
        StrList * sourceArrays = new StrList(blockNumB);
        int * blockSizes = new int[blockNumB];

        /* each block of a is listed dSize times as a source block of b */
        for (int i = 0; i < blockNumA; i++) {
            char * ap = (char*)a->data + i * realBlockSize;
            for (int j = 0; j < dSize; j++) {
                sourceArrays->Add(ap);
                blockSizes[i * dSize + j] = realBlockSize;
            }
        }

        _MergeBlockLists(sourceArrays, blockSizes, 1, b->data, b->mem);

        delete sourceArrays;
        delete[] blockSizes;
    }
}
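
/*
worked example (illustrative, not part of the library): for a 2 * 3 tensor
    a = {{1, 2, 3},
         {4, 5, 6}}
calling _Unsqueeze(&a, &b, 1, 2) treats each row (blockSize = 3) as one block
and copies it twice, so b is a 2 * 2 * 3 tensor:
    b = {{{1, 2, 3}, {1, 2, 3}},
         {{4, 5, 6}, {4, 5, 6}}}
With dim = 0 the whole tensor (blockSize = 6) forms a single block that is
copied twice, so b repeats a itself along the new leading dimension.
*/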

/*
check whether the shape of b matches that of a with a dimension of size
dSize inserted at position dim

>> a - input tensor
>> b - output tensor
>> dim - where to insert the dimension
>> dSize - size of the newly-inserted dimension
<< return - whether the shapes match
*/
bool CheckUnsqueezeSize(const XTensor * a, const XTensor * b, int dim, int dSize)
{
    if (!(a && b))
        return false;

    if (!(a->dataType == b->dataType))
        return false;

    int order = a->order + 1;
    int * dimSize = new int[order];

    for (int i = 0; i < order; i++) {
        if (i < dim)
            dimSize[i] = a->dimSize[i];
        else if (i == dim)
            dimSize[i] = dSize;
        else
            dimSize[i] = a->dimSize[i - 1];
    }

    for (int i = 0; i < order; i++) {
        if (dimSize[i] != b->dimSize[i]) {
            delete[] dimSize;
            return false;
        }
    }

    delete[] dimSize;

    return true;
}

/*
insert a dimension by copying the blocks x times
(where x is the size of the inserted dimension) (return an XTensor structure)
make a new tensor to keep the result and return it

>> a - input tensor
>> dim - where to insert the dimension
>> dSize - size of the newly-inserted dimension
<< return - the output tensor obtained by inserting a dimension and copying the blocks x times
*/
XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
{
    int order = a.order + 1;
    int * dimSize = new int[order];

    for (int i = 0; i < order; i++) {
        if (i < dim)
            dimSize[i] = a.dimSize[i];
        else if (i == dim)
            dimSize[i] = dSize;
        else
            dimSize[i] = a.dimSize[i - 1];
    }

    float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
    XTensor b(order, dimSize, a.dataType, dr, a.devID, a.mem);
    b.SetTMPFlag();

    /* call _Unsqueeze function */
    _Unsqueeze(&a, &b, dim, dSize);

    /* tensor connections */
    if (a.enableGrad) {
        XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
        XLink::AddParamToHeadInt(&b, dim);
        XLink::AddParamToHeadInt(&b, dSize);
    }

    /* destroy variables */
    delete[] dimSize;

    return b;
}
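
/*
usage sketch (illustrative only; the initialization mirrors the InitTensorV2
signature used above, and the exact setup may differ in practice):

    int dims[2] = {2, 3};
    XTensor a;
    InitTensorV2(&a, 2, dims, X_FLOAT, 1.0F, -1, NULL);   // a 2 * 3 tensor on the CPU
    a.SetZeroAll();
    XTensor b = Unsqueeze(a, 0, 4);                       // b has shape 4 * 2 * 3
*/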

/*
insert a dimension by copying the blocks x times
(where x is the size of the inserted dimension)
keep the result in the output tensor b

>> a - the input tensor
>> b - the output tensor
>> dim - where to insert the dimension
>> dSize - size of the newly-inserted dimension
*/
void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
{
    if (!b.isInit || !CheckUnsqueezeSize(&a, &b, dim, dSize)) {
        int order = a.order + 1;
        int * dimSize = new int[order];

        for (int i = 0; i < order; i++) {
            if (i < dim)
                dimSize[i] = a.dimSize[i];
            else if (i == dim)
                dimSize[i] = dSize;
            else
                dimSize[i] = a.dimSize[i - 1];
        }

        float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
        InitTensorV2(&b, order, dimSize, a.dataType, dr, a.devID, a.mem);

        /* destroy variables */
        delete[] dimSize;
    }

    /* call _Unsqueeze function */
    _Unsqueeze(&a, &b, dim, dSize);

    if (a.enableGrad) {
        /* tensor connections */
        XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
        XLink::AddParamToHeadInt(&b, dim);
        XLink::AddParamToHeadInt(&b, dSize);
    }
}
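
/*
note: unlike the overload above, this version writes into a caller-provided
tensor b; b is (re)initialized only when it has not been initialized yet or
its shape does not match the expected output, so the same output tensor can
be reused across repeated calls.
*/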

} // namespace nts(NiuTrans.Tensor)