Merge.cpp 15.2 KB
Newer Older
xiaotong committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

22 23 24
#include "../../XTensor.h"
#include "../../XUtility.h"
#include "../../XName.h"
xiaotong committed
25 26
#include "Merge.h"
#include "MakeMergeBlockIndex.h"
27
#include "../movement/CopyBlocksOnSite.h"
xiaotong committed
28 29 30 31

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
32 33 34 35
transform a tensor by merging it along with a dimension.

e.g., (N/3, M, 3) -> (N, M)

xiaotong committed
36 37 38
>> s - the source tensor
>> t - the target tensor (for return)
>> whereToMerge - the merging operation is along with which dimension
39 40 41
>> leadingDim - the leading dimension of merging, take (N/3, M, 3) -> (N, M) 
   for example, whereToMerge = 0 (i.e., the dimension for "N/3")
   leadingDim = 2 (i.e., the dimension for "3")
xiaotong committed
42
*/
43
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
xiaotong committed
44
{
45 46 47 48 49
    if(leadingDim < 0)
        leadingDim = 0;

    int whereToMergeRDI = s->order - whereToMerge - 1;
    int leadingDimRDI = s->order - leadingDim - 1;
xiaotong committed
50
    if (leadingDimRDI < 0)
51
        leadingDimRDI = s->order - 1;
xiaotong committed
52 53 54

    CheckNTErrors((s != NULL && t != NULL), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
55
                  "the data must be kept on the same device!");
xiaotong committed
56 57 58 59 60 61 62 63

    CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
    CheckNTErrors((s->order == t->order + 1), "Unmatched tensors!");
    CheckNTErrors((leadingDimRDI > whereToMergeRDI), "Invalid leading dimension!");

    for (int i = 0; i < s->order; i++) {
        if (i == whereToMergeRDI) {
            CheckNTErrors((t->dimSizeRDI[i] == s->dimSizeRDI[i] * s->dimSizeRDI[leadingDimRDI]),
64
                          "Unmatched tensor sizes!");
xiaotong committed
65
        }
66 67 68 69
        else if (i < leadingDimRDI){
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
                          "Unmatched tensor sizes!");
        }
xiaotong committed
70
        else if (i > leadingDimRDI) {
71
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i - 1]),
72
                          "Unmatched tensor sizes!");
xiaotong committed
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
        }
    }

    int blockSize = 1;
    int blockNum = 1;
    int gridSize = 1;
    int gridNum = 1;
    int mergedNum = s->dimSizeRDI[leadingDimRDI];

    for (int i = 0; i < s->order; i++) {
        if (i <= leadingDimRDI) {
            if (i <= whereToMergeRDI)
                blockSize *= s->dimSizeRDI[i];
            else
                blockNum *= s->dimSizeRDI[i];
        }
    }

    CheckNTErrors((s->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    gridSize = blockNum;
    gridNum = s->unitNum / (blockSize * blockNum);

xuchen committed
97
    if (mergedNum * gridNum <= MIN_TENSOR_MERGE_NUM) {
xiaotong committed
98 99 100 101 102 103 104 105 106 107 108
        int sPitch = blockSize * s->unitSize;
        int tPtich = blockSize * mergedNum * t->unitSize;
        int mSize = blockSize * t->unitSize;
        int n = blockNum / mergedNum;
        int sStep = n * sPitch;
        int tStep = blockSize * t->unitSize;
        for (int g = 0; g < gridNum; g++) {
            char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
            char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
            for (int k = 0; k < mergedNum; k++) {
                XMemCopy2D(tData + k * tStep, tPtich, t->devID,
109
                           sData + k * sStep, sPitch, s->devID, mSize, n);
xiaotong committed
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
            }
        }
    }
    else {
        XMem * mem = s->mem;
        int size = s->unitNum * s->unitSize;

        bool isOnSameDevice = (s->devID < 0 && t->devID < 0) || (s->devID == t->devID);

        void * dataTMP = t->data;

        if (!isOnSameDevice)
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(mem->devID, size);

        int blockNumInMerge = s->dimSizeRDI[leadingDimRDI];
        int splitSizeInGrid = gridSize / blockNumInMerge;
        int realBlockSize = blockSize * t->unitSize;

        int * blockIndex = (int*)(mem != NULL ?
129 130
                                  mem->AllocBuf(mem->devID, blockNum * gridNum * sizeof(int)) :
                                  XMemAlloc(s->devID, blockNum * gridNum * sizeof(int)));
xiaotong committed
131

132
        _MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, s->devID);
xiaotong committed
133

134
        _CopyBlocksOnSite(s->data, realBlockSize, blockNum * gridNum, dataTMP, blockIndex, s->devID);
xiaotong committed
135 136 137 138

        if (mem != NULL)
            mem->ReleaseBuf(mem->devID, blockNum * gridNum * sizeof(int));
        else
139
            XMemFree(s->devID, blockIndex);
xiaotong committed
140 141 142 143 144 145

        if (!isOnSameDevice) {
            XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
146
                XMemFree(s->devID, dataTMP);
xiaotong committed
147 148 149 150
        }
    }
}

151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
bool CheckMergeSize(const XTensor * s, const XTensor * t, int whereToMerge, int leadingDim)
{
    if (!(s && t))
        return false;

    if (!(s->dataType == t->dataType))
        return false;

    if (leadingDim < 0)
        leadingDim = 0;
    int order = s->order - 1;
    int * dimSize = new int[order];

    for (int i = 0; i < s->order; i++) {
        if (i < leadingDim)
            dimSize[i] = s->dimSize[i];
        else if (i > leadingDim) {
            if (i != whereToMerge)
                dimSize[i - 1] = s->dimSize[i];
            else
                dimSize[i - 1] = s->dimSize[i] * s->dimSize[leadingDim];
        }
    }

    for (int i = 0; i < order; i++) {
        if (dimSize[i] != t->dimSize[i])
            return false;
    }

    return true;
}


xiaotong committed
184
/*
transform a tensor by merging it along with a dimension (return an XTensor structure)
make a new tensor to keep the result and return it

e.g., (3, N/3, M) -> (N, M) with leadingDim = 0 and whereToMerge = 1

>> s - the source tensor
>> whereToMerge - the dimension of s that absorbs the leading dimension
>> leadingDim - the dimension of s that is removed by the merging. it must be
   smaller than whereToMerge; a negative value is treated as 0 after that check
<< return - the transformed tensor by merging along with a dimension
*/
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
{
    CheckNTErrors(leadingDim < whereToMerge, "Invalid leading dimension!");

    if (leadingDim < 0)
        leadingDim = 0;

    /* the result has one dimension less than the source */
    int mergedOrder = s.order - 1;
    int * mergedDims = new int[mergedOrder];

    for (int d = 0; d < s.order; d++) {
        /* the leading dimension is dropped */
        if (d == leadingDim)
            continue;

        if (d < leadingDim) {
            mergedDims[d] = s.dimSize[d];
            continue;
        }

        /* dimensions behind the leading one move forward by one position */
        mergedDims[d - 1] = (d == whereToMerge) ?
                            s.dimSize[d] * s.dimSize[leadingDim] :
                            s.dimSize[d];
    }

    float dr = s.isSparse ? s.denseRatio : 1.0F;
    XTensor t(mergedOrder, mergedDims, s.dataType, dr, s.devID, s.mem);
    t.SetTMPFlag();

    /* do the actual merging */
    _Merge(&s, &t, whereToMerge, leadingDim);

    /* record the operation and its parameters on the result */
    XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
    XLink::AddParamToHeadInt(&t, whereToMerge);
    XLink::AddParamToHeadInt(&t, leadingDim);

    delete[] mergedDims;

    return t;
}

235
void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
{
    if (!t.isInit || !CheckMergeSize(&s, &t, whereToMerge, leadingDim)) {
        if (leadingDim < 0)
            leadingDim = 0;
        int order = s.order - 1;
        int * dimSize = new int[order];

        for (int i = 0; i < s.order; i++) {
            if (i < leadingDim)
                dimSize[i] = s.dimSize[i];
            else if (i > leadingDim) {
                if (i != whereToMerge)
                    dimSize[i - 1] = s.dimSize[i];
                else
                    dimSize[i - 1] = s.dimSize[i] * s.dimSize[leadingDim];
            }
        }

        float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
        InitTensor(&t, order, dimSize, s.dataType, dr, s.devID, s.mem);

        /* destroy variables */
        delete[] dimSize;
    }

    /* call _Merge function */
    _Merge(&s, &t, whereToMerge, leadingDim);

264
    if (t.enableGrad) {
265 266 267 268 269 270 271
        /* tensor connections */
        XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
        XLink::AddParamToHeadInt(&t, whereToMerge);
        XLink::AddParamToHeadInt(&t, leadingDim);
    }
}

272
/*
xiaotong committed
273
merge small tensors into a big tensor
274

xiaotong committed
275
>> smalls - the list of the small tensors
276
>> t - the merged tensor (for return)
xiaotong committed
277 278
>> whereToMerge - the merging operation is along with which dimension
*/
279
void _Merge(const TensorList * smalls, XTensor * t, int whereToMerge)
xiaotong committed
280
{
281
    whereToMerge = (whereToMerge < 0 ? t->order - 1 : whereToMerge);
282

283
    CheckNTErrors((smalls != NULL), "Invalid list!");
xiaotong committed
284
    CheckNTErrors((smalls->count > 0), "Empty list!");
285
    CheckNTErrors((whereToMerge >= 0 && whereToMerge < t->order), "Wrong range of  whereToMerge");
xiaotong committed
286 287 288 289

    bool uniform = true;

    int mergeNum = smalls->count;
290
    XTensor* smallsItem0 = smalls->GetItem(0);
xiaotong committed
291 292 293
    int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;

    for (int i = 0; i < smalls->count; i++) {
294
        XTensor* smallsItem = smalls->GetItem(i);
295
        CheckNTErrors((t->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
xiaotong committed
296 297

        if (i > 0) {
298
            XTensor * preItem = smalls->GetItem(i - 1);
xiaotong committed
299 300 301 302 303 304 305 306 307 308 309
            if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data)
                uniform = false;
        }
    }

    int blockSize = 1;
    int blockNum = 1;
    int gridSize = 1;
    int gridNum = 1;
    int mergedNum = smalls->count;

310
    XTensor * s0 = smalls->GetItem(0);
311
    int whereToMergeRDI = s0->order - whereToMerge - 1;
xiaotong committed
312 313 314 315 316 317 318 319 320 321 322 323 324 325
    for (int i = 0; i < s0->order; i++) {
        if (i <= whereToMergeRDI)
            blockSize *= s0->dimSizeRDI[i];
        else
            blockNum *= s0->dimSizeRDI[i];
    }

    CheckNTErrors((s0->unitNum % (blockSize * blockNum) == 0), "Incorrect size!");

    /* a grid has a number of blocks. there might be several grids */
    gridSize = blockNum;
    gridNum = s0->unitNum / (blockSize * blockNum);

    /* merging with fewer data copy operations */
xuchen committed
326
    if (mergedNum * gridNum <= MIN_TENSOR_MERGE_LIST_NUM) {
xiaotong committed
327
        int sPitch = blockSize * s0->unitSize;
328 329
        int tPtich = blockSize * mergedNum * t->unitSize;
        int mSize = blockSize * t->unitSize;
xiaotong committed
330 331
        int n = blockNum;
        int sStep = 0;
332
        int tStep = blockSize * t->unitSize;
xiaotong committed
333
        for (int g = 0; g < gridNum; g++) {
334
            char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
xiaotong committed
335
            for (int k = 0; k < mergedNum; k++) {
336
                XTensor * s = smalls->GetItem(k);
xiaotong committed
337
                char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
338
                XMemCopy2D(tData + k * tStep, tPtich, t->devID,
xiaotong committed
339 340 341 342 343 344 345
                    sData + k * sStep, sPitch, s->devID,
                    mSize, n);
            }
        }
    }
    /* merging with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
    else {
346 347 348 349
        int* dimSizeTMP = new int[smallsItem0->order + 1];
        for (int i = 0; i < smallsItem0->order; i++)
            dimSizeTMP[i + 1] = -smallsItem0->dimSize[i];
        dimSizeTMP[0] = -mergeNum;
xiaotong committed
350 351

        XMem * mem = smallsItem0->mem;
352 353 354
        XTensor * tensorTMP = new XTensor(smallsItem0->order + 1, dimSizeTMP,
                                          smallsItem0->dataType, smallsItem0->denseRatio,
                                          smallsItem0->devID, mem);
xiaotong committed
355 356 357 358 359 360
        int size = mergeNum * itemSize;

        void * dataTMP = NULL;
        if (uniform)
            dataTMP = smallsItem0->data;
        else
361
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(t->devID, size);
xiaotong committed
362 363 364 365 366 367

        tensorTMP->data = dataTMP;

        /* copy from source to tmp */
        if (!uniform) {
            for (int i = 0; i < mergeNum; i++) {
368
                XTensor* smallsItem = smalls->GetItem(i);
xiaotong committed
369 370 371 372
                XMemCopy((char*)(tensorTMP->data) + (itemSize * i), tensorTMP->devID, smallsItem->data, smallsItem->devID, itemSize);
            }
        }

373
        _Merge(tensorTMP, t, whereToMerge + 1);
xiaotong committed
374 375 376

        delete[] dimSizeTMP;

377
        tensorTMP->data = NULL;
xiaotong committed
378 379 380 381 382
        delete tensorTMP;

        if ((!uniform) && (mem != NULL))
            mem->ReleaseBuf(mem->devID, size);
        else
383
            XMemFree(t->devID, dataTMP);
xiaotong committed
384 385
    }
}
386 387

/*
xiaotong committed
388
merge small tensors into a big tensor (return an XTensor structure)
389 390 391 392 393 394
make a new tensor to keep the result and return it

>> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension
<< return - the big tensor merged by small tensors
*/
395
XTensor Merge(const TensorList &smalls, int whereToMerge)
396
{
397
    XTensor * tensor = smalls.GetItem(0);
398 399 400 401 402 403 404 405 406
    int order = tensor->order;
    int * dimSize = new int[order];
    for (int i = 0; i < tensor->order; i++) {
        if (i != whereToMerge)
            dimSize[i] = tensor->dimSize[i];
        else
            dimSize[i] = tensor->dimSize[whereToMerge] * smalls.count;
    }

407 408
    float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
    XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
409
    big.SetTMPFlag();
410 411 412 413 414 415 416 417 418 419 420 421 422 423 424

    /* call _Merge function */
    _Merge(&smalls, &big, whereToMerge);
    
    /* tensor connections */
    XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
    XLink::AddParamToHeadInt(&big, whereToMerge);

    /* destroy variables */
    delete[] dimSize;

    return big;
}

/*
merge two tensors into a big tensor (return an XTensor structure)

>> smallA - the first small tensor
>> smallB - the second small tensor (must be of the same shape as smallA)
>> whereToMerge - the merging operation is along with which dimension. the size
   of this dimension is doubled in the result
<< return - the big tensor merged by small tensors
*/
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
{
    CheckNTErrors(XTensor::IsSameShaped(&smallA, &smallB),
                 "The two tensors must be of the same size!");

    /* same shape as smallA, except for the merged dimension */
    int order = smallA.order;
    int * bigDims = new int[order];
    for (int d = 0; d < smallA.order; d++) {
        if (d == whereToMerge)
            bigDims[d] = smallA.dimSize[whereToMerge] * 2;
        else
            bigDims[d] = smallA.dimSize[d];
    }

    float dr = smallA.isSparse ? smallA.denseRatio : 1.0F;
    XTensor big(order, bigDims, smallA.dataType, dr, smallA.devID, smallA.mem);
    big.SetTMPFlag();

    /* put both inputs into a list so that the list version of _Merge can be used */
    TensorList pair(2);
    pair.Add(const_cast<XTensor*>(&smallA));
    pair.Add(const_cast<XTensor*>(&smallB));

    /* do the actual merging */
    _Merge(&pair, &big, whereToMerge);

    /* record the operation and its parameter on the result */
    XLink::MakeLink(&pair, &big, SHAPE_MERGE_LIST);
    XLink::AddParamToHeadInt(&big, whereToMerge);

    delete[] bigDims;

    return big;
}

xiaotong committed
465
} // namespace nts(NiuTrans.Tensor)