/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h"
#include "../shape/IsSameShaped.h"
#include "../math/ScaleAndShift.h"
#include "Sum.h"
#include "Sum.cuh"
#include "SumDim.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
tensor summation c = a + b * \beta

>> a - a tensor
>> b - another tensor
>> c - where we put a+b*\beta
>> beta - the scaling factor
*/
void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
    CheckNTErrors(a && b && c, "Empty tensor input!");
    CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
                  "Unmatched tensors in addition!");
    CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
                  "Unmatched tensors in addition!");

    CheckDev(a->devID, b->devID);

    if(beta == 0){
        _CopyValues(a, c);
        return;
    }

    if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
        if (a == c) {
            int P2PAccessible = 0;
#ifdef CUDA_UVA
            cudaDeviceCanAccessPeer(&P2PAccessible, a->devID, b->devID);
#endif
            if ((a->devID < 0 && b->devID >= 0) ||
                (a->devID >= 0 && b->devID < 0) ||
                (a->devID >= 0 && b->devID >= 0 && a->devID != b->devID && !P2PAccessible))
            {
                ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
            }
            else
                _CudaSum(a, b, c, beta);
        }
        else
            _CudaSum(a, b, c, beta);

#endif
    }
    else {
        if (!a->isSparse && !b->isSparse) {
            CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");

            if (a->dataType == DEFAULT_DTYPE &&
                b->dataType == DEFAULT_DTYPE &&
                c->dataType == DEFAULT_DTYPE)
            {
                DTYPE * ap = (DTYPE*)a->data;
                DTYPE * bp = (DTYPE*)b->data;
                DTYPE * cp = (DTYPE*)c->data;
                /* when c != a, OpenBLAS needs to copy a to c first. This operation
                 slows down the speed, so we only use OpenBLAS when c == a and fall
                 back to the unrolled loops below otherwise */
#if defined(USE_BLAS)
                if (c == a) {
                    AXPY(a->unitNum, beta, bp, 1, cp, 1);
                    return;
                }
#endif
                /* unrolling */
                int num = a->unitNum;
                if (num % 4 == 0) {
                    for (int i = 0; i < num; i += 4) {
                        cp[i] = ap[i] + bp[i] * beta;
                        cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                        cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
                        cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
                    }
                }
                else if (num % 2 == 0) {
                    for (int i = 0; i < num; i += 2) {
                        cp[i] = ap[i] + bp[i] * beta;
                        cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                    }
                }
                else {
                    for (int i = 0; i < num; i++) {
                        cp[i] = ap[i] + bp[i] * beta;
                    }
                }
            }
            else if (a->dataType == X_INT &&
                     b->dataType == X_INT &&
                     c->dataType == X_INT)
            {
                int * ap = (int*)a->data;
                int * bp = (int*)b->data;
                int * cp = (int*)c->data;

                /* unrolling */
                int num = a->unitNum;
                if (num % 4 == 0) {
                    for (int i = 0; i < num; i += 4) {
                        cp[i] = ap[i] + bp[i] * beta;
                        cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                        cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
                        cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
                    }
                }
                else if (num % 2 == 0) {
                    for (int i = 0; i < num; i += 2) {
                        cp[i] = ap[i] + bp[i] * beta;
                        cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                    }
                }
                else {
                    for (int i = 0; i < num; i++) {
                        cp[i] = ap[i] + bp[i] * beta;
                    }
                }
            }
            else {
                // TODO!!
                ShowNTErrors("TODO!");
            }
        }
        else {
            // TODO!!
            ShowNTErrors("TODO!");
        }
    }
}
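
/*
illustrative sketch (not part of the library): on the CPU path above, the
summation is computed element by element as

    c[i] = a[i] + b[i] * beta

and the 4-way/2-way unrolled loops are simply a hand-optimized form of this
plain loop over dense data:

    for (int i = 0; i < num; i++)
        cp[i] = ap[i] + bp[i] * beta;
*/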
    
/*
tensor summation a = a + b * \beta (do it on site)
keep the result in the tensor a and return nothing

>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
{
    _Sum(a, b, a, beta);
}

/*
tensor summation a = a + b * \beta (do it on site)
keep the result in the tensor a and return nothing

>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void SumMe(XTensor & a, const XTensor & b, DTYPE beta)
{
    if (b.order == 0){
        DTYPE shift = b.Get0D() * beta;
        _ScaleAndShift(&a, &a, 1.0F, shift);
    }
    else {
        int n = GetBroadcastDimIndex(a, b);

        if (n == -1)
            /* call _Sum function */
            _Sum(&a, &b, &a, beta);
        else if (n >= 0 && n < a.order)
            /* call _SumDim function */
            _SumDim(&a, &b, &a, n, beta);
        else
            ShowNTErrors("Something is wrong!");
    }
}
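
/*
usage sketch (a and b are assumed to be existing dense tensors of compatible
shapes; names are illustrative only):

    SumMe(a, b, 2.0F);    // in place: a = a + 2 * b
*/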

/* 
return the index of the dimension along which the operation is broadcast (e.g., by the SumDim function)
>> a - a tensor
>> b - another tensor for the operation
<< return - index of the broadcast dimension of a, or -1 if no single broadcast dimension is found
*/
int GetBroadcastDimIndex(const XTensor & a, const XTensor & b)
{
    if(a.order < b.order)
        return -1;
    if(IsSameShaped(a, b))
        return -1;

    int hitDim = -1;
    bool isHit = false;
    for(int i = 0; i < b.order; i++){
        if(b.dimSize[b.order - 1 - i] == 1)
            continue;
        else {
            if (isHit == true)
                return -1;
            else
                isHit = true;
            for (int j = 0; j < a.order; j++){
                if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - j]){
                    hitDim = a.order - 1 - j;
                    break;
                }
            }
        }
    }

    return hitDim;
}
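
/*
worked example (hypothetical shapes, for illustration only): with
a.order = 3, a.dimSize = {4, 3, 2} and b.order = 3, b.dimSize = {1, 3, 1},
the only dimension of b whose size is not 1 is b.dimSize[1] = 3; scanning a
from its last dimension finds the match at a.dimSize[1], so the function
returns 1 and the summation is broadcast along dimension 1 via _SumDim.
It returns -1 (plain _Sum is used) if a and b have the same shape, if
a.order < b.order, or if b has more than one dimension of size != 1.
*/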
    
/*
tensor summation c = a + b * \beta (return an XTensor structure)
make a new tensor c to keep the result and return it

>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
<< return - the result of tensor summation
*/
XTensor Sum(const XTensor & a, const XTensor & b, DTYPE beta)
{
    XTensor c(&a);
    c.SetTMPFlag();

    if (b.order == 0){
        DTYPE shift = b.Get0D() * beta;
        ScaleAndShift(a, c, 1.0F, shift);
    }
    else {
        int n = GetBroadcastDimIndex(a, b);

        if(n == -1){
            /* call _Sum function */
            _Sum(&a, &b, &c, beta);

            /* tensor connections */
            if (a.enableGrad && b.enableGrad) {
                XLink::MakeLink(&a, &b, &c, MATH_SUM);
                XLink::AddParamToHead(&c, beta);
            }
        }
        else if(n >= 0 && n < a.order){
            /* call _SumDim function */
            _SumDim(&a, &b, &c, n, beta);

            /* tensor connections */
            if (a.enableGrad && b.enableGrad) {
                XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
                XLink::AddParamToHeadInt(&c, n);
                XLink::AddParamToHead(&c, beta);
            }
        }
        else{
            ShowNTErrors("Something is wrong!");
        }
    }
    return c;
}
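
/*
usage sketch (a and b are assumed to be existing dense tensors of the same
shape; names are illustrative only):

    XTensor c = Sum(a, b, 0.5F);    // c = a + 0.5 * b

if b is a scalar (order 0) the call reduces to a ScaleAndShift of a, and if
b can be broadcast along a single dimension of a, _SumDim is used instead.
*/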

/*
tensor summation c = a + b * \beta

>> a - a tensor
>> b - another tensor
>> c - where we put a+b*\beta
>> beta - the scaling factor
*/
void Sum(const XTensor & a, const XTensor & b, XTensor & c, DTYPE beta)
{
    if (!c.isInit || !IsSameShaped(a, c)) {
        InitTensorV2(&c, &a);
    }

    if (b.order == 0){
        DTYPE shift = b.Get0D() * beta;
        ScaleAndShift(a, c, 1.0F, shift);
    }
    else {
        int n = GetBroadcastDimIndex(a, b);

        if (n == -1) {
            /* call _Sum function */
            _Sum(&a, &b, &c, beta);

            /* tensor connections */
            if (a.enableGrad && b.enableGrad) {
                XLink::MakeLink(&a, &b, &c, MATH_SUM);
                XLink::AddParamToHead(&c, beta);
            }
        }
        else if (n >= 0 && n < a.order) {
            /* call _SumDim function */
            _SumDim(&a, &b, &c, n, beta);

            /* tensor connections */
            if (a.enableGrad && b.enableGrad) {
                XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
                XLink::AddParamToHeadInt(&c, n);
                XLink::AddParamToHead(&c, beta);
            }
        }
        else {
            ShowNTErrors("Something is wrong!");
        }
    }
}
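
/*
usage sketch (a and b are assumed to be existing dense tensors of the same
shape; names are illustrative only): the output tensor may be passed in
uninitialized, in which case it is shaped after a by InitTensorV2 before
the summation is performed:

    XTensor c;
    Sum(a, b, c, 1.0F);    // c is initialized here and receives a + b
*/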

} // namespace nts(NiuTrans.Tensor)