/* NiuTrans.Tensor - an open-source tensor library * Copyright (C) 2017, Natural Language Processing Lab, Northestern University. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */

/*
 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
 */

#include "../../XTensor.h"
#include "../../XUtility.h"
#include "Split.h"
#include "MakeSplitBlockIndex.h"
#include "../movement/CopyBlocksOnSite.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
transform a tensor by splitting it, e.g., (N, M) -> (N/3, M, 3)

>> s - the source tensor
>> t - the target tensor (for return); it has one more dimension than s,
       and that extra dimension holds the split count
>> whereToSplit - which dimension of the tensor is to split
>> splitNum - how many splits
*/
void Split(XTensor * s, XTensor * t, int whereToSplit, int splitNum)
{
    CheckNTErrors((s && t), "Invalid tensors!");
    CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
                  "the data must be kept on the same device!");
    CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize),
                  "Unmatched tensors!");
    CheckNTErrors((s->order == t->order - 1), "Unmatched tensors!");
    CheckNTErrors((t->dimSizeRDI[t->order - 1] == splitNum),
                  "Incorrect tensor sizes!");

    /* split dimension expressed in reverse dimension index (RDI) order */
    int whereToSplitRDI = s->order - whereToSplit - 1;

    /* every dimension of s must match t, except the split dimension,
       which must be exactly splitNum times larger in s */
    for (int i = 0; i < s->order; i++) {
        if (i == whereToSplitRDI) {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i] * splitNum),
                          "Unmatched tensor sizes!");
        }
        else {
            CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
                          "Unmatched tensor sizes!");
        }
    }

    /* for the case that we split the last dimension. Actually
       (N, M) and (N, M/3, 3) have the same memory layout, so a
       flat copy is enough */
    if (s->order - 1 == whereToSplitRDI) {
        XMemCopy(t->data, t->devID, s->data, s->devID, s->unitNum * s->unitSize);
        return;
    }

    /* blockSize: number of units in one contiguous chunk below (and
       including one slice of) the split dimension; blockNum: how many
       such chunks the tensor decomposes into */
    int blockSize = 1;
    int blockNum = 1;
    for (int i = 0; i < s->order; i++) {
        if (i == whereToSplitRDI) {
            blockSize *= s->dimSizeRDI[i] / splitNum;
            blockNum *= splitNum;
        }
        else if (i < whereToSplitRDI)
            blockSize *= s->dimSizeRDI[i];
        else
            blockNum *= s->dimSizeRDI[i];
    }

    CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");

    if (splitNum <= MIN_TENSOR_SPLIT_NUM) {
        /* few splits: one strided 2D copy per split */
        int sPitch = blockSize * splitNum * s->unitSize;
        int tPitch = blockSize * t->unitSize;
        int mSize = blockSize * t->unitSize;
        int n = blockNum / splitNum;
        int sStep = blockSize * s->unitSize;
        int tStep = n * tPitch;
        for (int k = 0; k < splitNum; k++) {
            XMemCopy2D((char*)t->data + k * tStep, tPitch, t->devID,
                       (char*)s->data + k * sStep, sPitch, s->devID,
                       mSize, n);
        }
    }
    else {
        /* many splits: permute the blocks in one pass via a block index */
        XMem * mem = s->mem;
        int size = s->unitNum * s->unitSize;
        bool isOnSameDevice = (s->devID < 0 && t->devID < 0) || (s->devID == t->devID);

        /* write straight into t when both tensors share a device;
           otherwise stage the result through a temporary buffer */
        void * dataTMP = t->data;
        if (!isOnSameDevice) {
            /* BUGFIX: the mem == NULL branch used mem->devID (null
               dereference); allocate on s's device instead */
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size)
                                  : XMemAlloc(s->devID, size);
        }

        int realBlockSize = blockSize * t->unitSize;
        int blockSplitSize = blockNum / splitNum;

        /* BUGFIX: same mem == NULL null-dereference fixed here */
        int * blockIndex = (int*)(mem != NULL ?
                                  mem->AllocBuf(mem->devID, blockNum * sizeof(int)) :
                                  XMemAlloc(s->devID, blockNum * sizeof(int)));

        /* build the target position of every block, then move them */
        MakeSplitBlockIndex(blockIndex, splitNum, blockSplitSize, blockNum, mem);
        CopyBlocksOnSite(s->data, realBlockSize, blockNum, dataTMP, blockIndex, mem);

        if (mem != NULL)
            mem->ReleaseBuf(mem->devID, blockNum * sizeof(int));
        else
            XMemFree(s->devID, blockIndex); /* BUGFIX: was mem->devID with mem == NULL */

        /* copy from tmp to target */
        if (!isOnSameDevice) {
            XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(s->devID, dataTMP); /* BUGFIX: was mem->devID with mem == NULL */
        }
    }
}

/*
split a big tensor into small tensors

>> big - the source tensor
>> smalls - the list that keeps the resulting tensors (for return)
   NOTE that all the "small" tensors have already been placed in
   the list in advance.
>> whereToSplit - which dimension of the tensor is to split
>> splitNum - how many splits
*/
void Split(XTensor * big, XList * smalls, int whereToSplit, int splitNum)
{
    CheckNTErrors((smalls != NULL), "Invalid list!");
    CheckNTErrors((smalls->count == splitNum), "Unmatched tensors!");
    CheckNTErrors((smalls->count > 0), "Wrong input!");

    int whereToSplitRDI = big->order - whereToSplit - 1;

    /* uniform == true iff the small tensors sit back-to-back in memory,
       so the split can land directly in their shared storage */
    bool uniform = true;

    for (int i = 0; i < smalls->count; i++) {
        XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
        CheckNTErrors((big->unitNum == smallsItem->unitNum * splitNum),
                      "Unmatched tensors!");

        if (i > 0) {
            XTensor * preItem = (XTensor*)smalls->GetItem(i - 1);
            if (smallsItem->unitNum * smallsItem->unitSize !=
                (char*)smallsItem->data - (char*)preItem->data)
                uniform = false;
        }
    }

    /* same block decomposition as the tensor-to-tensor Split above */
    int blockSize = 1;
    int blockNum = 1;
    for (int i = 0; i < big->order; i++) {
        if (i == whereToSplitRDI) {
            blockSize *= big->dimSizeRDI[i] / splitNum;
            blockNum *= splitNum;
        }
        else if (i < whereToSplitRDI)
            blockSize *= big->dimSizeRDI[i];
        else
            blockNum *= big->dimSizeRDI[i];
    }

    CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");

    /* splitting with fewer data copy operations */
    if (splitNum <= MIN_TENSOR_SPLIT_LIST_NUM) {
        XTensor * t0 = (XTensor*)smalls->GetItem(0);
        int sPitch = blockSize * splitNum * big->unitSize;
        int tPitch = blockSize * t0->unitSize;
        int mSize = blockSize * t0->unitSize;
        int n = blockNum / splitNum;
        int sStep = blockSize * big->unitSize;
        int tStep = 0; /* each split lands in its own tensor, so no target offset */
        for (int k = 0; k < splitNum; k++) {
            XTensor * t = (XTensor*)smalls->GetItem(k);
            XMemCopy2D((char*)t->data + k * tStep, tPitch, t->devID,
                       (char*)big->data + k * sStep, sPitch, big->devID,
                       mSize, n);
        }
    }
    /* splitting with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
    else {
        /* NOTE(review): the negated sizes presumably tell the XTensor
           constructor not to allocate data memory — confirm against
           the XTensor constructor */
        int* dimSizeTMP = new int[MAX_TENSOR_DIM_NUM];
        for (int i = 0; i < MAX_TENSOR_DIM_NUM; i++)
            dimSizeTMP[i] = -big->dimSize[i];
        dimSizeTMP[whereToSplit] /= splitNum;
        dimSizeTMP[big->order] = -splitNum;

        XMem * mem = big->mem;
        XTensor* tensorTMP = new XTensor(big->order + 1, dimSizeTMP,
                                         big->dataType, big->denseRatio, mem);
        int size = big->unitNum * big->unitSize;

        /* when the smalls are contiguous, split straight into their
           storage; otherwise stage through a temporary buffer */
        void * dataTMP = NULL;
        if (uniform) {
            XTensor* first = (XTensor*)smalls->GetItem(0);
            dataTMP = first->data;
        }
        else {
            /* BUGFIX: the mem == NULL branch used mem->devID (null
               dereference); allocate on big's device instead */
            dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size)
                                  : XMemAlloc(big->devID, size);
        }

        tensorTMP->data = dataTMP;

        /* delegate to the tensor-to-tensor Split */
        Split(big, tensorTMP, whereToSplit, splitNum);

        /* copy from tmp to target */
        if (!uniform) {
            int splitSize = big->unitNum * big->unitSize / splitNum;
            for (int i = 0; i < splitNum; i++) {
                XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
                XMemCopy(smallsItem->data, smallsItem->devID,
                         (char*)(tensorTMP->data) + (splitSize * i),
                         tensorTMP->devID, splitSize);
            }
        }

        delete[] dimSizeTMP;

        /* detach the borrowed/temporary buffer so ~XTensor does not free it */
        tensorTMP->data = NULL;
        delete tensorTMP;

        /* BUGFIX: the original nulled dataTMP before freeing it, so the
           non-uniform CPU path leaked the buffer (freed NULL), and the
           uniform path could still dereference a null mem. Only free the
           buffer we actually allocated, i.e. in the non-uniform case. */
        if (!uniform) {
            if (mem != NULL)
                mem->ReleaseBuf(mem->devID, size);
            else
                XMemFree(big->devID, dataTMP);
        }
    }
}

} // namespace nts(NiuTrans.Tensor)