Merge.cpp 15.1 KB
Newer Older
xiaotong committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

22 23 24
#include "../../XTensor.h"
#include "../../XUtility.h"
#include "../../XName.h"
liyinqiao committed
25
#include "../shape/IsSameShaped.h"
xiaotong committed
26 27
#include "Merge.h"
#include "MakeMergeBlockIndex.h"
28
#include "../movement/CopyBlocksOnSite.h"
xiaotong committed
29 30 31 32

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
33 34 35 36
transform a tensor by merging it along with a dimension.

e.g., (N/3, M, 3) -> (N, M)

xiaotong committed
37 38 39
>> s - the source tensor
>> t - the target tensor (for return)
>> whereToMerge - the merging operation is along with which dimension
40 41 42
>> leadingDim - the leading dimension of merging, take (N/3, M, 3) -> (N, M) 
   for example, whereToMerge = 0 (i.e., the dimension for "N/3")
   leadingDim = 2 (i.e., the dimension for "3")
xiaotong committed
43
*/
44
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
xiaotong committed
45
{
46 47 48
    if(leadingDim < 0)
        leadingDim = 0;

49 50
    if (leadingDim >= s->order)
        leadingDim = leadingDim - s->order;
xiaotong committed
51 52 53

    CheckNTErrors((s != NULL && t != NULL), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
54
                  "the data must be kept on the same device!");
xiaotong committed
55 56 57

    CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
    CheckNTErrors((s->order == t->order + 1), "Unmatched tensors!");
58
    CheckNTErrors((leadingDim < whereToMerge), "Invalid leading dimension!");
xiaotong committed
59 60

    for (int i = 0; i < s->order; i++) {
61 62 63
        if (i == whereToMerge) {
            
            CheckNTErrors((t->dimSize[i - 1] == s->dimSize[i] * s->dimSize[leadingDim]),
64
                          "Unmatched tensor sizes!");
xiaotong committed
65
        }
66 67
        else if (i < leadingDim){
            CheckNTErrors((s->dimSize[i] == t->dimSize[i]),
68 69
                          "Unmatched tensor sizes!");
        }
70 71
        else if (i > leadingDim) {
            CheckNTErrors((s->dimSize[i] == t->dimSize[i - 1]),
72
                          "Unmatched tensor sizes!");
xiaotong committed
73 74 75 76 77 78 79
        }
    }

    int blockSize = 1;
    int blockNum = 1;
    int gridSize = 1;
    int gridNum = 1;
80
    int mergedNum = s->dimSize[leadingDim];
xiaotong committed
81 82

    for (int i = 0; i < s->order; i++) {
83 84 85
        if (i >= leadingDim) {
            if (i >= whereToMerge)
                blockSize *= s->dimSize[i];
xiaotong committed
86
            else
87
                blockNum *= s->dimSize[i];
xiaotong committed
88 89 90 91 92 93 94 95 96
        }
    }

    CheckNTErrors((s->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    gridSize = blockNum;
    gridNum = s->unitNum / (blockSize * blockNum);

97
    if (mergedNum * gridNum <= MIN_TENSOR_MERGE_NUM) {
xiaotong committed
98 99 100 101 102 103 104 105 106 107 108
        int sPitch = blockSize * s->unitSize;
        int tPtich = blockSize * mergedNum * t->unitSize;
        int mSize = blockSize * t->unitSize;
        int n = blockNum / mergedNum;
        int sStep = n * sPitch;
        int tStep = blockSize * t->unitSize;
        for (int g = 0; g < gridNum; g++) {
            char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
            char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
            for (int k = 0; k < mergedNum; k++) {
                XMemCopy2D(tData + k * tStep, tPtich, t->devID,
109
                           sData + k * sStep, sPitch, s->devID, mSize, n);
xiaotong committed
110 111 112 113 114 115 116 117 118 119 120 121 122 123
            }
        }
    }
    else {
        XMem * mem = s->mem;
        int size = s->unitNum * s->unitSize;

        bool isOnSameDevice = (s->devID < 0 && t->devID < 0) || (s->devID == t->devID);

        void * dataTMP = t->data;

        if (!isOnSameDevice)
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(mem->devID, size);

124
        int blockNumInMerge = s->dimSize[leadingDim];
xiaotong committed
125 126 127 128
        int splitSizeInGrid = gridSize / blockNumInMerge;
        int realBlockSize = blockSize * t->unitSize;

        int * blockIndex = (int*)(mem != NULL ?
129 130
                                  mem->AllocBuf(mem->devID, blockNum * gridNum * sizeof(int)) :
                                  XMemAlloc(s->devID, blockNum * gridNum * sizeof(int)));
xiaotong committed
131

132
        _MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, s->devID);
xiaotong committed
133

134
        _CopyBlocksOnSite(s->data, realBlockSize, blockNum * gridNum, dataTMP, blockIndex, s->devID);
xiaotong committed
135 136 137 138

        if (mem != NULL)
            mem->ReleaseBuf(mem->devID, blockNum * gridNum * sizeof(int));
        else
139
            XMemFree(s->devID, blockIndex);
xiaotong committed
140 141 142 143 144 145

        if (!isOnSameDevice) {
            XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
146
                XMemFree(s->devID, dataTMP);
xiaotong committed
147 148 149 150
        }
    }
}

151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
bool CheckMergeSize(const XTensor * s, const XTensor * t, int whereToMerge, int leadingDim)
{
    if (!(s && t))
        return false;

    if (!(s->dataType == t->dataType))
        return false;

    if (leadingDim < 0)
        leadingDim = 0;
    int order = s->order - 1;
    int * dimSize = new int[order];

    for (int i = 0; i < s->order; i++) {
        if (i < leadingDim)
            dimSize[i] = s->dimSize[i];
        else if (i > leadingDim) {
            if (i != whereToMerge)
                dimSize[i - 1] = s->dimSize[i];
            else
                dimSize[i - 1] = s->dimSize[i] * s->dimSize[leadingDim];
        }
    }

    for (int i = 0; i < order; i++) {
        if (dimSize[i] != t->dimSize[i])
            return false;
    }

    return true;
}


xiaotong committed
184
/*
transform a tensor by merging it along with a dimension (return an XTensor structure)
make a new tensor to keep the result and return it

e.g., (3, M, N/3) -> (M, N)

>> s - the source tensor
>> whereToMerge - the merging operation is along with which dimension
>> leadingDim - the leading dimension of merging, take (3, M, N/3) -> (M, N)
   for example, whereToMerge = 2 (i.e., the dimension for "N/3")
   leadingDim = 0 (i.e., the dimension for "3").
   leadingDim must be smaller than whereToMerge; a negative value is treated as 0.
<< return - the transformed tensor by merging along with a dimension
*/
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
{
    CheckNTErrors(leadingDim < whereToMerge, "Invalid leading dimension!");

    if (leadingDim < 0)
        leadingDim = 0;

    /* the result drops the leading dimension and folds its size into whereToMerge */
    int order = s.order - 1;
    int * dimSize = new int[order];

    for (int i = 0; i < s.order; i++) {
        if (i < leadingDim)
            dimSize[i] = s.dimSize[i];
        else if (i > leadingDim) {
            if (i != whereToMerge)
                dimSize[i - 1] = s.dimSize[i];
            else
                dimSize[i - 1] = s.dimSize[i] * s.dimSize[leadingDim];
        }
    }

    float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
    XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
    t.SetTMPFlag();

    /* call _Merge function */
    _Merge(&s, &t, whereToMerge, leadingDim);

    /* tensor connections */
    if (s.enableGrad) {
        XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
        XLink::AddParamToHeadInt(&t, whereToMerge);
        XLink::AddParamToHeadInt(&t, leadingDim);
    }

    /* destroy variables */
    delete[] dimSize;

    return t;
}

liyinqiao committed
237
void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
{
    if (!t.isInit || !CheckMergeSize(&s, &t, whereToMerge, leadingDim)) {
        if (leadingDim < 0)
            leadingDim = 0;
        int order = s.order - 1;
        int * dimSize = new int[order];

        for (int i = 0; i < s.order; i++) {
            if (i < leadingDim)
                dimSize[i] = s.dimSize[i];
            else if (i > leadingDim) {
                if (i != whereToMerge)
                    dimSize[i - 1] = s.dimSize[i];
                else
                    dimSize[i - 1] = s.dimSize[i] * s.dimSize[leadingDim];
            }
        }

        float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
257
        InitTensorV2(&t, order, dimSize, s.dataType, dr, s.devID, s.mem);
258 259 260 261 262 263 264 265

        /* destroy variables */
        delete[] dimSize;
    }

    /* call _Merge function */
    _Merge(&s, &t, whereToMerge, leadingDim);

xuchen committed
266
    if (s.enableGrad) {
267 268 269 270 271 272 273
        /* tensor connections */
        XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
        XLink::AddParamToHeadInt(&t, whereToMerge);
        XLink::AddParamToHeadInt(&t, leadingDim);
    }
}

274
/*
xiaotong committed
275
merge small tensors into a big tensor
276

xiaotong committed
277
>> smalls - the list of the small tensors
liyinqiao committed
278
>> t - the merged tensor (for return)
xiaotong committed
279 280
>> whereToMerge - the merging operation is along with which dimension
*/
liyinqiao committed
281
void _Merge(const TensorList * smalls, XTensor * t, int whereToMerge)
xiaotong committed
282
{
liyinqiao committed
283
    whereToMerge = (whereToMerge < 0 ? t->order - 1 : whereToMerge);
284

285
    CheckNTErrors((smalls != NULL), "Invalid list!");
xiaotong committed
286
    CheckNTErrors((smalls->count > 0), "Empty list!");
liyinqiao committed
287
    CheckNTErrors((whereToMerge >= 0 && whereToMerge < t->order), "Wrong range of  whereToMerge");
xiaotong committed
288 289 290 291

    bool uniform = true;

    int mergeNum = smalls->count;
292
    XTensor* smallsItem0 = smalls->GetItem(0);
xiaotong committed
293 294 295
    int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;

    for (int i = 0; i < smalls->count; i++) {
296
        XTensor* smallsItem = smalls->GetItem(i);
liyinqiao committed
297
        CheckNTErrors((t->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
xiaotong committed
298 299

        if (i > 0) {
300
            XTensor * preItem = smalls->GetItem(i - 1);
xiaotong committed
301 302 303 304 305 306 307 308 309 310 311
            if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data)
                uniform = false;
        }
    }

    int blockSize = 1;
    int blockNum = 1;
    int gridSize = 1;
    int gridNum = 1;
    int mergedNum = smalls->count;

312
    XTensor * s0 = smalls->GetItem(0);
xiaotong committed
313
    for (int i = 0; i < s0->order; i++) {
314 315
        if (i >= whereToMerge)
            blockSize *= s0->dimSize[i];
xiaotong committed
316
        else
317
            blockNum *= s0->dimSize[i];
xiaotong committed
318 319 320 321 322 323 324 325 326
    }

    CheckNTErrors((s0->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    gridSize = blockNum;
    gridNum = s0->unitNum / (blockSize * blockNum);

    /* merging with fewer data copy operations */
327
    if (mergedNum * gridNum <= MIN_TENSOR_MERGE_LIST_NUM) {
xiaotong committed
328
        int sPitch = blockSize * s0->unitSize;
liyinqiao committed
329 330
        int tPtich = blockSize * mergedNum * t->unitSize;
        int mSize = blockSize * t->unitSize;
xiaotong committed
331 332
        int n = blockNum;
        int sStep = 0;
liyinqiao committed
333
        int tStep = blockSize * t->unitSize;
xiaotong committed
334
        for (int g = 0; g < gridNum; g++) {
liyinqiao committed
335
            char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
xiaotong committed
336
            for (int k = 0; k < mergedNum; k++) {
337
                XTensor * s = smalls->GetItem(k);
xiaotong committed
338
                char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
liyinqiao committed
339
                XMemCopy2D(tData + k * tStep, tPtich, t->devID,
xiaotong committed
340 341 342 343 344 345 346
                    sData + k * sStep, sPitch, s->devID,
                    mSize, n);
            }
        }
    }
    /* merging with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
    else {
347 348 349 350
        int* dimSizeTMP = new int[smallsItem0->order + 1];
        for (int i = 0; i < smallsItem0->order; i++)
            dimSizeTMP[i + 1] = -smallsItem0->dimSize[i];
        dimSizeTMP[0] = -mergeNum;
xiaotong committed
351 352

        XMem * mem = smallsItem0->mem;
353 354 355
        XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP,
                                          smallsItem0->dataType, smallsItem0->denseRatio,
                                          smallsItem0->devID, mem);
xiaotong committed
356 357 358 359 360 361
        int size = mergeNum * itemSize;

        void * dataTMP = NULL;
        if (uniform)
            dataTMP = smallsItem0->data;
        else
liyinqiao committed
362
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(t->devID, size);
xiaotong committed
363 364 365 366 367 368

        tensorTMP->data = dataTMP;

        /* copy from source to tmp */
        if (!uniform) {
            for (int i = 0; i < mergeNum; i++) {
369
                XTensor* smallsItem = smalls->GetItem(i);
xiaotong committed
370 371 372 373
                XMemCopy((char*)(tensorTMP->data) + (itemSize * i), tensorTMP->devID, smallsItem->data, smallsItem->devID, itemSize);
            }
        }

liyinqiao committed
374
        _Merge(tensorTMP, t, whereToMerge + 1);
xiaotong committed
375 376 377

        delete[] dimSizeTMP;

378
        tensorTMP->data = NULL;
xiaotong committed
379 380 381 382 383
        delete tensorTMP;

        if ((!uniform) && (mem != NULL))
            mem->ReleaseBuf(mem->devID, size);
        else
liyinqiao committed
384
            XMemFree(t->devID, dataTMP);
xiaotong committed
385 386
    }
}
387 388

/*
xiaotong committed
389
merge small tensors into a big tensor (return an XTensor structure)
390 391 392 393 394 395
make a new tensor to keep the result and return it

>> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension
<< return - the big tensor merged by small tensors
*/
396
XTensor Merge(const TensorList &smalls, int whereToMerge)
397
{
398
    XTensor * tensor = smalls.GetItem(0);
399 400 401 402 403 404 405 406 407
    int order = tensor->order;
    int * dimSize = new int[order];
    for (int i = 0; i < tensor->order; i++) {
        if (i != whereToMerge)
            dimSize[i] = tensor->dimSize[i];
        else
            dimSize[i] = tensor->dimSize[whereToMerge] * smalls.count;
    }

408 409
    float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
    XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
410
    big.SetTMPFlag();
411 412 413 414 415

    /* call _Merge function */
    _Merge(&smalls, &big, whereToMerge);
    
    /* tensor connections */
xuchen committed
416 417 418 419
    if (tensor->enableGrad) {
        XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
        XLink::AddParamToHeadInt(&big, whereToMerge);
    }
420 421 422 423 424 425 426 427

    /* destroy variables */
    delete[] dimSize;

    return big;
}

/*
merge two tensors into a big tensor (return an XTensor structure)

>> smallA - the first small tensor
>> smallB - the second small tensor (must have the same shape as smallA)
>> whereToMerge - the merging operation is along with which dimension
   (a negative value selects the last dimension)
<< return - the big tensor merged by the two small tensors
*/
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
{
    CheckNTErrors(IsSameShaped(smallA, smallB),
                 "The two tensors must be of the same size!");

    int order = smallA.order;

    /* resolve a negative dimension the same way _Merge does (last dimension),
       so that the shape allocated here agrees with the one _Merge checks */
    if (whereToMerge < 0)
        whereToMerge = order - 1;

    int * dimSize = new int[order];
    for (int i = 0; i < order; i++) {
        if (i != whereToMerge)
            dimSize[i] = smallA.dimSize[i];
        else
            dimSize[i] = smallA.dimSize[whereToMerge] * 2;
    }

    float dr = (!smallA.isSparse) ? 1.0F : smallA.denseRatio;
    XTensor big(order, dimSize, smallA.dataType, dr, smallA.devID, smallA.mem);
    big.SetTMPFlag();

    TensorList smalls(2);
    smalls.Add((XTensor*)&smallA);
    smalls.Add((XTensor*)&smallB);

    /* call _Merge function */
    _Merge(&smalls, &big, whereToMerge);

    /* tensor connections */
    if (smallA.enableGrad) {
        XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
        XLink::AddParamToHeadInt(&big, whereToMerge);
    }

    /* destroy variables */
    delete[] dimSize;

    return big;
}

xiaotong committed
470
} // namespace nts(NiuTrans.Tensor)