Commit f7ed3448 by xiaotong

clean the code of Split

parent 15f75d3a
@@ -21,6 +21,7 @@
 #include <stdio.h>
 #include "XNet.h"
+#include "../tensor/XUtility.h"
 #include "../tensor/function/FHeader.h"
 #include "../tensor/core/CHeader.h"
 #include "../sample/fnnlm/FNNLM.h"
...
@@ -47,10 +47,11 @@ struct XLink;
 /* define the maximum number of dimensions in a tensor */
 #define MAX_TENSOR_DIM_NUM 6
 #define USE_BATCHED_STRIDED_MAT_MUL
-#define MIN_TENSOR_SPLIT_NUM 10
+#define MIN_TENSOR_SPLIT_NUM 0
 #define MIN_TENSOR_SPLIT_LIST_NUM 1024
 #define MIN_TENSOR_CAT_NUM 8

 /* computation flags */
 #define UNSAFE_BUT_FAST_MEM
 #define FAST_MATRIX
...
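Note on the threshold change: _Split (further down in this commit) branches on if (splitNum <= MIN_TENSOR_SPLIT_NUM), so dropping the value from 10 to 0 means the pitched-copy branch is no longer taken for any positive split count and splitting always goes through the block-index path. A minimal sketch of that dispatch, with the two branch helpers as hypothetical placeholders rather than functions from the repository:

#include <cstdio>

#define MIN_TENSOR_SPLIT_NUM 0   /* value after this commit; it was 10 before */

/* hypothetical stand-ins for the two branches of _Split shown in the diff */
static void SplitByPitchedCopy() { std::printf("pitched 2D copy path\n"); }
static void SplitByBlockCopy()   { std::printf("block-index copy path\n"); }

/* sketch: how the threshold picks the copy strategy */
void SplitDispatchSketch(int splitNum)
{
    if (splitNum <= MIN_TENSOR_SPLIT_NUM)   /* never true for splitNum >= 1 now */
        SplitByPitchedCopy();
    else
        SplitByBlockCopy();
}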
@@ -90,12 +90,12 @@ void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target,
     int bSize = blockSize / sizeof(DTYPE);
     if (bSize % 4 == 0) {
-        GDevs.GetCudaThread2D(myMem->devID, bSize / 4, blockNum, MAX_INT, cudaGrids, cudaBlocks);
+        GDevs.GetCudaThread2D(devID, bSize / 4, blockNum, MAX_INT, cudaGrids, cudaBlocks);
         KernelCopyBlocks<4> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
                             ((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
     }
     else {
-        GDevs.GetCudaThread2D(myMem->devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
+        GDevs.GetCudaThread2D(devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
         KernelCopyBlocks<1> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
                             ((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
     }
...
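The only change in this hunk is that the launch configuration comes from the devID argument instead of myMem->devID, so _CudaCopyBlocks no longer depends on a memory pool. The surrounding selection logic is unchanged: the vectorized KernelCopyBlocks<4> is launched when each block holds a multiple of 4 DTYPE elements, otherwise the scalar <1> variant. A rough host-side sketch of that selection, with the launches replaced by hypothetical placeholders:

#include <cstdio>

/* hypothetical placeholders for the real kernel launches */
static void LaunchVectorizedCopy(int devID, int bSize, int blockNum)
{ std::printf("dev %d: KernelCopyBlocks<4>, %d x %d\n", devID, bSize / 4, blockNum); }
static void LaunchScalarCopy(int devID, int bSize, int blockNum)
{ std::printf("dev %d: KernelCopyBlocks<1>, %d x %d\n", devID, bSize, blockNum); }

/* sketch: prefer the vectorized kernel when each block holds a
   multiple of 4 elements, as in _CudaCopyBlocks above */
void CopyBlocksSketch(int devID, int blockSizeInBytes, int blockNum, int unitSize)
{
    int bSize = blockSizeInBytes / unitSize;   /* elements per block */
    if (bSize % 4 == 0)
        LaunchVectorizedCopy(devID, bSize, blockNum);
    else
        LaunchScalarCopy(devID, bSize, blockNum);
}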
@@ -33,14 +33,14 @@ set target data block index for the data movement in merge
 >> splitSizeInGrid - size of each data array to merge
 >> gridSize - number of blocks in a grid (here grid is a higher level orgnization upon blocks)
 >> gridNum - number of grids
->> mem - the memory pool
+>> devID - device id
 */
 void _MakeMergeBlockIndex(int * blockIndex, int blockNum, int blockNumInMerge,
-                          int splitSizeInGrid, int gridSize, int gridNum, XMem * mem)
+                          int splitSizeInGrid, int gridSize, int gridNum, int devID)
 {
-    if (mem != NULL && mem->devID >= 0) {
+    if (devID >= 0) {
 #ifdef USE_CUDA
-        _CudaMakeMergeBlockIndex(mem->devID, blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum);
+        _CudaMakeMergeBlockIndex(devID, blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum);
 #else
         ShowNTErrors("Please specify USE_CUDA and recompile the code!");
 #endif
...
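With the memory pool removed from the signature, the dispatch keys on the device id alone: a non-negative devID selects the CUDA implementation, anything else falls through to the CPU code. A minimal sketch of that pattern under the same convention (RunOnGPU and RunOnCPU are hypothetical placeholders):

#include <cstdio>

/* hypothetical placeholders for the two implementations */
static void RunOnGPU(int devID) { std::printf("CUDA path on device %d\n", devID); }
static void RunOnCPU()          { std::printf("CPU path\n"); }

/* sketch: devID >= 0 denotes a GPU, a negative id denotes the CPU,
   so no XMem pool is needed just to choose the implementation */
void DispatchByDeviceSketch(int devID)
{
    if (devID >= 0) {
#ifdef USE_CUDA
        RunOnGPU(devID);
#else
        std::fprintf(stderr, "Please specify USE_CUDA and recompile the code!\n");
#endif
    }
    else {
        RunOnCPU();
    }
}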
@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 /* set target data block index for the data movement in merge */
 void _MakeMergeBlockIndex(int * blockIndex, int blockNum, int blockNumInMerge,
-                          int splitSizeInGrid, int gridSize, int gridNum, XMem * mem);
+                          int splitSizeInGrid, int gridSize, int gridNum, int devID);

 } // namespace nts(NiuTrans.Tensor)
...
@@ -122,27 +122,23 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
         int * blockIndex = (int*)(mem != NULL ?
                                   mem->AllocBuf(mem->devID, blockNum * gridNum * sizeof(int)) :
-                                  XMemAlloc(mem->devID, blockNum * gridNum * sizeof(int)));
+                                  XMemAlloc(s->devID, blockNum * gridNum * sizeof(int)));

-        _MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, mem);
+        _MakeMergeBlockIndex(blockIndex, blockNum, blockNumInMerge, splitSizeInGrid, gridSize, gridNum, s->devID);
         _CopyBlocksOnSite(s->data, realBlockSize, blockNum, dataTMP, blockIndex, s->devID);

         if (mem != NULL)
             mem->ReleaseBuf(mem->devID, blockNum * gridNum * sizeof(int));
         else
-            XMemFree(mem->devID, blockIndex);
+            XMemFree(s->devID, blockIndex);

-        /* copy from tmp to target */
-        XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
-
         if (!isOnSameDevice) {
             XMemCopy(t->data, t->devID, dataTMP, s->devID, size);
             if (mem != NULL)
                 mem->ReleaseBuf(mem->devID, size);
             else
-                XMemFree(mem->devID, dataTMP);
+                XMemFree(s->devID, dataTMP);
         }
     }
 }
...
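The substantive fix in this hunk is the release path: when no memory pool is attached, the old code still read mem->devID inside XMemFree, i.e. it dereferenced a null pointer in exactly the branch where mem is NULL, and it also copied dataTMP into t unconditionally even though the copy is only needed when source and target sit on different devices. The new code keeps allocation and release symmetric on the tensor's own device id. A small host-only sketch of that pairing, with malloc/free standing in for the library calls:

#include <cstdlib>
#include <cstddef>

/* stand-ins for the library calls; the real code uses XMemAlloc/XMemFree
   with a device id, these only illustrate the pairing on the host */
static void * AllocOnDevice(int /*devID*/, size_t size) { return std::malloc(size); }
static void   FreeOnDevice (int /*devID*/, void * p)    { std::free(p); }

/* sketch of the fixed release path: when there is no memory pool, both the
   allocation and the free use the tensor's device id; the old code freed
   via mem->devID, which is null in exactly that branch */
void ScratchBufferSketch(int tensorDevID, size_t size)
{
    void * buf = AllocOnDevice(tensorDevID, size);   /* like XMemAlloc(s->devID, size) */

    /* ... build the block index / merged data in buf ... */

    FreeOnDevice(tensorDevID, buf);                  /* like XMemFree(s->devID, buf) */
}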
@@ -300,9 +296,8 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
         _Merge(tensorTMP, big, whereToMerge + 1);

         delete[] dimSizeTMP;

-        tensorTMP->data = NULL;
-        dataTMP = NULL;
+        tensorTMP->data = NULL;
         delete tensorTMP;

         if ((!uniform) && (mem != NULL))
...
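tensorTMP only borrows dataTMP (which belongs either to the memory pool or to the target tensor), so its data pointer is cleared before delete to keep the destructor from freeing memory it does not own; the extra dataTMP = NULL assignment is dropped because the local pointer simply goes out of scope. A toy sketch of that detach-before-delete pattern, with a hypothetical wrapper type in place of XTensor:

#include <cstdlib>

/* toy wrapper that frees its data in the destructor, like a tensor
   that owns its memory; purely illustrative */
struct OwningWrapper {
    void * data = nullptr;
    ~OwningWrapper() { std::free(data); }   /* frees whatever it still points to */
};

/* sketch of the pattern in _Merge/_Split: the temporary tensor borrows a
   buffer, so the pointer is detached before deletion and the buffer's
   real owner releases it later */
void BorrowThenDetach(void * dataTMP)
{
    OwningWrapper * tensorTMP = new OwningWrapper();
    tensorTMP->data = dataTMP;      /* borrow, do not take ownership */

    /* ... use tensorTMP ... */

    tensorTMP->data = nullptr;      /* detach so the destructor frees nothing */
    delete tensorTMP;
}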
@@ -83,7 +83,6 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
     CheckNTErrors((blockNum % splitNum == 0), "Incorrect split number!");

     if (splitNum <= MIN_TENSOR_SPLIT_NUM) {
-    //if (splitNum <= 0) {
         int sPitch = blockSize * splitNum * s->unitSize;
         int tPitch = blockSize * t->unitSize;
         int mSize = blockSize * t->unitSize;
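For reference, the three variables set up here describe a strided 2D copy: mSize bytes are moved per row, the source pointer advances by sPitch (stepping over the interleaved blocks of the other splits), and the target pointer advances by tPitch (one packed block). A host-only sketch of such a copy loop, assuming plain memcpy rather than the library's device-aware routine:

#include <cstring>
#include <cstddef>

/* sketch of a pitched 2D copy like the sPitch/tPitch/mSize branch above:
   copy n rows of mSize bytes, advancing the source by sPitch and the
   target by tPitch after each row */
void CopyPitched2D(void * target, size_t tPitch,
                   const void * source, size_t sPitch,
                   size_t mSize, int n)
{
    char * t = static_cast<char *>(target);
    const char * s = static_cast<const char *>(source);
    for (int i = 0; i < n; i++) {
        std::memcpy(t, s, mSize);
        s += sPitch;   /* skip the blocks that belong to the other splits */
        t += tPitch;   /* rows are packed back to back in the split tensor */
    }
}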
@@ -143,7 +142,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
         if (mem != NULL)
             mem->ReleaseBuf(mem->devID, blockNum * sizeof(int));
         else
-            XMemFree(mem->devID, blockIndex);
+            XMemFree(s->devID, blockIndex);

         /* copy from tmp to target */
         if (!isOnSameDevice) {
@@ -152,7 +151,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
             if (mem != NULL)
                 mem->ReleaseBuf(mem->devID, size);
             else
-                XMemFree(mem->devID, dataTMP);
+                XMemFree(s->devID, dataTMP);
         }
     }
 }
...
@@ -321,7 +320,6 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
         delete[] dimSizeTMP;

         tensorTMP->data = NULL;
-        dataTMP = NULL;
         delete tensorTMP;

         if ((!uniform) && (mem != NULL))
...