Commit f7ed3448 by xiaotong

clean the code of Split

parent 15f75d3a
......@@ -21,6 +21,7 @@
#include <stdio.h>
#include "XNet.h"
#include "../tensor/XUtility.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../sample/fnnlm/FNNLM.h"
......
......@@ -47,10 +47,11 @@ struct XLink;
/* define the maximum number of dimensions in a tensor */
#define MAX_TENSOR_DIM_NUM 6
#define USE_BATCHED_STRIDED_MAT_MUL
#define MIN_TENSOR_SPLIT_NUM 10
#define MIN_TENSOR_SPLIT_NUM 0
#define MIN_TENSOR_SPLIT_LIST_NUM 1024
#define MIN_TENSOR_CAT_NUM 8
/* computation flags */
#define UNSAFE_BUT_FAST_MEM
#define FAST_MATRIX
......
......@@ -90,12 +90,12 @@ void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target,
int bSize = blockSize / sizeof(DTYPE);
if (bSize % 4 == 0) {
GDevs.GetCudaThread2D(myMem->devID, bSize / 4, blockNum, MAX_INT, cudaGrids, cudaBlocks);
GDevs.GetCudaThread2D(devID, bSize / 4, blockNum, MAX_INT, cudaGrids, cudaBlocks);
KernelCopyBlocks<4> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
}
else {
GDevs.GetCudaThread2D(myMem->devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
GDevs.GetCudaThread2D(devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
KernelCopyBlocks<1> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
}
......
......@@ -33,14 +33,14 @@ set target data block index for the data movement in merge
>> splitSizeInGrid - size of each data array to merge
>> gridSize - number of blocks in a grid (here grid is a higher level organization upon blocks)
>> gridNum - number of grids
>> mem - the memory pool
>> devID - device id
*/
void _MakeMergeBlockIndex(int * blockIndex, int blockNum, int blockNumInMerge,
int splitSizeInGrid, int gridSize, int gridNum, XMem * mem)
int splitSizeInGrid, int gridSize, int gridNum, int devID)
{
if (mem != NULL && mem->devID >= 0) {
if (devID >= 0) {
#ifdef USE_CUDA
_CudaMakeMergeBlockIndex(mem->devID, blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum);
_CudaMakeMergeBlockIndex(devID, blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set target data block index for the data movement in merge */
void _MakeMergeBlockIndex(int * blockIndex, int blockNum, int blockNumInMerge,
int splitSizeInGrid, int gridSize, int gridNum, XMem * mem);
int splitSizeInGrid, int gridSize, int gridNum, int devID);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -122,27 +122,23 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
int * blockIndex = (int*)(mem != NULL ?
mem->AllocBuf(mem->devID, blockNum * gridNum * sizeof(int)) :
XMemAlloc(mem->devID, blockNum * gridNum * sizeof(int)));
XMemAlloc(s->devID, blockNum * gridNum * sizeof(int)));
_MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, mem);
_MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, s->devID);
_CopyBlocksOnSite(s->data, realBlockSize, blockNum, dataTMP, blockIndex, s->devID);
if (mem != NULL)
mem->ReleaseBuf(mem->devID, blockNum * gridNum * sizeof(int));
else
XMemFree(mem->devID, blockIndex);
/* copy from tmp to target */
XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
XMemFree(s->devID, blockIndex);
if (!isOnSameDevice) {
XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
if (mem != NULL)
mem->ReleaseBuf(mem->devID, size);
else
XMemFree(mem->devID, dataTMP);
XMemFree(s->devID, dataTMP);
}
}
}
......@@ -300,9 +296,8 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
_Merge(tensorTMP, big, whereToMerge + 1);
delete[] dimSizeTMP;
tensorTMP->data = NULL;
dataTMP = NULL;
tensorTMP->data = NULL;
delete tensorTMP;
if ((!uniform) && (mem != NULL))
......
......@@ -83,7 +83,6 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");
if (splitNum <= MIN_TENSOR_SPLIT_NUM) {
//if (splitNum <= 0) {
int sPitch = blockSize * splitNum * s->unitSize;
int tPitch = blockSize * t->unitSize;
int mSize = blockSize * t->unitSize;
......@@ -143,7 +142,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
if (mem != NULL)
mem->ReleaseBuf(mem->devID, blockNum * sizeof(int));
else
XMemFree(mem->devID, blockIndex);
XMemFree(s->devID, blockIndex);
/* copy from tmp to target */
if (!isOnSameDevice) {
......@@ -152,7 +151,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
if (mem != NULL)
mem->ReleaseBuf(mem->devID, size);
else
XMemFree(mem->devID, dataTMP);
XMemFree(s->devID, dataTMP);
}
}
}
......@@ -321,7 +320,6 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
delete[] dimSizeTMP;
tensorTMP->data = NULL;
dataTMP = NULL;
delete tensorTMP;
if ((!uniform) && (mem != NULL))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论