Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
Emmay
NiuTrans.Tensor
Commits
15f75d3a
Commit
15f75d3a
authored
Jul 27, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add devID to the function as an argument so that it does not require XMem as a necessary input
parent
3f23f074
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
38 行增加
和
34 行删除
+38
-34
source/tensor/core/movement/CopyBlocks.cpp
+17
-7
source/tensor/core/movement/CopyBlocks.h
+1
-1
source/tensor/core/movement/CopyBlocksOnSite.cpp
+6
-9
source/tensor/core/movement/CopyBlocksOnSite.cu
+9
-11
source/tensor/core/movement/CopyBlocksOnSite.cuh
+2
-3
source/tensor/core/movement/CopyBlocksOnSite.h
+1
-1
source/tensor/core/shape/Merge.cpp
+1
-1
source/tensor/core/shape/Split.cpp
+1
-1
没有找到文件。
source/tensor/core/movement/CopyBlocks.cpp
查看文件 @
15f75d3a
...
...
@@ -35,24 +35,33 @@ copy a number of blocks to target positions
>> target - target data array
>> targetBlocks - target positions of the copy
>> myMem - the memory pool
>> devID - device id
*/
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
)
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
,
int
devID
)
{
if
(
myMem
!=
NULL
&&
myMem
->
devID
>=
0
)
{
if
(
myMem
!=
NULL
)
devID
=
myMem
->
devID
;
if
(
devID
>=
0
)
{
#ifdef USE_CUDA
/* copy the index from host to device */
int
*
targetBlocksTMP
=
(
int
*
)
myMem
->
AllocBuf
(
myMem
->
devID
,
blockNum
*
sizeof
(
int
));
int
*
targetBlocksTMP
=
myMem
!=
NULL
?
(
int
*
)
myMem
->
AllocBuf
(
myMem
->
devID
,
blockNum
*
sizeof
(
int
))
:
(
int
*
)
XMemAlloc
(
devID
,
blockNum
*
sizeof
(
int
));
XMemCopy
(
targetBlocksTMP
,
myMem
->
devID
,
targetBlocks
,
-
1
,
blockNum
*
sizeof
(
int
));
_CopyBlocksOnSite
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocksTMP
,
myMem
);
_CopyBlocksOnSite
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocksTMP
,
devID
);
myMem
->
ReleaseBuf
(
myMem
->
devID
,
blockNum
*
sizeof
(
int
));
if
(
myMem
!=
NULL
)
myMem
->
ReleaseBuf
(
myMem
->
devID
,
blockNum
*
sizeof
(
int
));
else
XMemFree
(
devID
,
targetBlocksTMP
);
#else
ShowNTErrors
(
"Plesae specify USE_CUDA and recompile the code!"
);
#endif
}
else
{
_CopyBlocksOnSite
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocks
,
myMem
);
_CopyBlocksOnSite
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocks
,
devID
);
}
}
...
...
@@ -65,11 +74,12 @@ copy a number of blocks from source positions to target positions
>> target - target data array
>> targetBlocks - target positions of the copy
>> myMem - the memory pool
>> devID - device id
*/
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
*
sourceBlocks
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
,
int
devID
)
{
if
(
myMem
!=
NULL
)
CheckNTErrors
((
myMem
->
devID
==
devID
),
"DevIDs are different between memory pool and input devID!"
)
;
devID
=
myMem
->
devID
;
if
(
devID
>=
0
)
{
#ifdef USE_CUDA
...
...
source/tensor/core/movement/CopyBlocks.h
查看文件 @
15f75d3a
...
...
@@ -27,7 +27,7 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* copy a number of blocks to target positions */
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
);
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
,
int
devID
);
/* copy a number of blocks from source positions to target positions */
void
_CopyBlocks
(
void
*
source
,
int
blockSize
,
int
*
sourceBlocks
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
,
int
devID
);
...
...
source/tensor/core/movement/CopyBlocksOnSite.cpp
查看文件 @
15f75d3a
...
...
@@ -34,20 +34,18 @@ all the data has been on the device (CPU/GPU) already.
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy
>>
myMem - the memory pool
>>
devID - device id
*/
void
_CopyBlocksOnSite
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
)
void
_CopyBlocksOnSite
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
int
devID
)
{
if
(
myMem
!=
NULL
&&
myMem
->
devID
>=
0
)
{
if
(
devID
>=
0
)
{
#ifdef USE_CUDA
_CudaCopyBlocks
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocks
,
myMem
);
_CudaCopyBlocks
(
source
,
blockSize
,
blockNum
,
target
,
targetBlocks
,
devID
);
#else
ShowNTErrors
(
"Plesae specify USE_CUDA and recompile the code!"
);
#endif
}
else
{
int
devID
=
myMem
!=
NULL
?
myMem
->
devID
:
-
1
;
/*
The following code should be fine with GPUs, but too many
kernel calls would slow down the system. We prefer to use
...
...
@@ -55,8 +53,8 @@ void _CopyBlocksOnSite(void * source, int blockSize, int blockNum, void * target
*/
for
(
int
i
=
0
,
b
=
0
;
i
<
blockNum
;
i
++
,
b
+=
blockSize
)
{
XMemCopy
((
char
*
)
target
+
targetBlocks
[
i
]
*
blockSize
,
devID
,
(
char
*
)
source
+
b
,
devID
,
blockSize
);
(
char
*
)
source
+
b
,
devID
,
blockSize
);
}
}
}
}
//
namespace
nts
(
NiuTrans
.
Tensor
)
\ No newline at end of file
}
// namespace nts(NiuTrans.Tensor)
source/tensor/core/movement/CopyBlocksOnSite.cu
查看文件 @
15f75d3a
...
...
@@ -78,13 +78,12 @@ copy a number of blocks to target positions (cuda version)
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy (on the device)
>>
myMem - memory pool
>>
devID - device id
*/
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks,
XMem * myMem
)
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks,
int devID
)
{
CheckNTErrors((myMem != NULL), "No memory pool!");
CheckNTErrors((myMem->devID >= 0), "Wrong device to run!");
CheckNTErrors((blockSize % sizeof(DTYPE) == 0), "Unsupported block size!");
CheckNTErrors(devID >= 0, "Wrong device to run!");
CheckNTErrors(blockSize % sizeof(DTYPE) == 0, "Unsupported block size!");
int cudaGrids[3];
int cudaBlocks[3];
...
...
@@ -92,15 +91,15 @@ void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target,
if (bSize % 4 == 0) {
GDevs.GetCudaThread2D(myMem->devID, bSize / 4, blockNum, MAX_INT, cudaGrids, cudaBlocks);
KernelCopyBlocks<4> <<
<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>
>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
KernelCopyBlocks<4> <<
<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>
>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
}
else {
GDevs.GetCudaThread2D(myMem->devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
KernelCopyBlocks<1> <<
<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>
>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
KernelCopyBlocks<1> <<
<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>
>
((DTYPE*)source, bSize, blockNum, (DTYPE*)target, targetBlocks);
}
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
source/tensor/core/movement/CopyBlocksOnSite.cuh
查看文件 @
15f75d3a
...
...
@@ -33,10 +33,10 @@ __global__
void KernelCopyBlocks(DTYPE * source, int blockSize, int blockNum, DTYPE * target, int * targetBlocks);
/* copy a number of blocks to target positions (cuda version) */
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks,
XMem * myMem
);
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks,
int devID
);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __COPYBLOCKS_CUH__
\ No newline at end of file
#endif // __COPYBLOCKS_CUH__
source/tensor/core/movement/CopyBlocksOnSite.h
查看文件 @
15f75d3a
...
...
@@ -27,7 +27,7 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* copy a number of blocks to target positions (on site) */
void
_CopyBlocksOnSite
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
XMem
*
myMem
);
void
_CopyBlocksOnSite
(
void
*
source
,
int
blockSize
,
int
blockNum
,
void
*
target
,
int
*
targetBlocks
,
int
devID
);
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/shape/Merge.cpp
查看文件 @
15f75d3a
...
...
@@ -126,7 +126,7 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
_MakeMergeBlockIndex
(
blockIndex
,
blockNum
,
blockNumInMerge
,
splitSizeInGrid
,
gridSize
,
gridNum
,
mem
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
mem
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
s
->
devID
);
if
(
mem
!=
NULL
)
mem
->
ReleaseBuf
(
mem
->
devID
,
blockNum
*
gridNum
*
sizeof
(
int
));
...
...
source/tensor/core/shape/Split.cpp
查看文件 @
15f75d3a
...
...
@@ -138,7 +138,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
_MakeSplitBlockIndex
(
blockIndex
,
splitNum
,
blockSplitSize
,
blockNum
,
s
->
devID
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
mem
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
s
->
devID
);
if
(
mem
!=
NULL
)
mem
->
ReleaseBuf
(
mem
->
devID
,
blockNum
*
sizeof
(
int
));
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论