Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
3f23f074
Commit
3f23f074
authored
Jul 27, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split with stream
parent
70e478c4
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
119 行增加
和
23 行删除
+119
-23
source/network/Main.cpp
+0
-1
source/tensor/XLink.cpp
+3
-1
source/tensor/XUtility.cpp
+38
-0
source/tensor/XUtility.h
+2
-0
source/tensor/core/shape/MakeSplitBlockIndex.cpp
+4
-4
source/tensor/core/shape/MakeSplitBlockIndex.h
+1
-1
source/tensor/core/shape/Merge.cpp
+9
-7
source/tensor/core/shape/Split.cpp
+60
-9
source/tensor/core/shape/Split.h
+2
-0
没有找到文件。
source/network/Main.cpp
查看文件 @
3f23f074
...
@@ -32,7 +32,6 @@
...
@@ -32,7 +32,6 @@
using
namespace
nts
;
using
namespace
nts
;
using
namespace
samplefnnlm
;
using
namespace
samplefnnlm
;
int
main
(
int
argc
,
const
char
**
argv
)
int
main
(
int
argc
,
const
char
**
argv
)
{
{
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
...
...
source/tensor/XLink.cpp
查看文件 @
3f23f074
...
@@ -167,7 +167,9 @@ void XLink::SetType(int id)
...
@@ -167,7 +167,9 @@ void XLink::SetType(int id)
type
[
0
]
=
0
;
type
[
0
]
=
0
;
strcpy
(
type
,
GetOPName
(
id
));
strcpy
(
type
,
GetOPName
(
id
));
typeID
=
id
;
typeID
=
id
;
if
(
id
!=
0
){
CheckNTErrors
(
strcmp
(
type
,
"NULL"
),
"illegal edge type name!"
);
CheckNTErrors
(
strcmp
(
type
,
"NULL"
),
"illegal edge type name!"
);
}
}
}
/*
/*
...
@@ -515,7 +517,7 @@ void XLink::CopyIncoming(const XTensor * reference, XTensor * target)
...
@@ -515,7 +517,7 @@ void XLink::CopyIncoming(const XTensor * reference, XTensor * target)
tails
.
Add
(
tail
);
tails
.
Add
(
tail
);
}
}
MakeLink
(
&
tails
,
target
,
reference
->
i
d
);
MakeLink
(
&
tails
,
target
,
reference
->
i
ncome
.
typeID
);
int
paraNum
=
reference
->
income
.
paramNum
;
int
paraNum
=
reference
->
income
.
paramNum
;
target
->
income
.
paramNum
=
paraNum
;
target
->
income
.
paramNum
=
paraNum
;
...
...
source/tensor/XUtility.cpp
查看文件 @
3f23f074
...
@@ -284,6 +284,44 @@ void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPit
...
@@ -284,6 +284,44 @@ void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPit
#endif
#endif
}
}
void
XMemCopy2DAsync
(
void
*
t
,
size_t
tPitch
,
int
devIDT
,
const
void
*
s
,
size_t
sPitch
,
int
devIDS
,
size_t
mSize
,
int
n
,
XStream
*
stream
)
{
if
(
t
==
s
)
return
;
if
(
devIDT
<
0
&&
devIDS
<
0
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
memcpy
((
char
*
)
t
+
tPitch
*
i
,
(
char
*
)
s
+
sPitch
*
i
,
mSize
);
return
;
}
#ifdef USE_CUDA
else
{
CheckNTErrors
(
stream
!=
NULL
,
"No stream found!"
);
cudaStream_t
&
cstream
=
stream
->
stream
;
if
(
devIDT
>=
0
&&
devIDS
<
0
)
{
cudaError_t
error
=
cudaMemcpy2DAsync
(
t
,
tPitch
,
s
,
sPitch
,
mSize
,
n
,
cudaMemcpyHostToDevice
,
cstream
);
if
(
error
!=
cudaSuccess
){
ShowNTErrors
(
"cudaMemcpy2D error (cudaMemcpyHostToDevice)"
);
}
}
else
if
(
devIDT
<
0
&&
devIDS
>=
0
)
{
cudaError_t
error
=
cudaMemcpy2DAsync
(
t
,
tPitch
,
s
,
sPitch
,
mSize
,
n
,
cudaMemcpyDeviceToHost
,
cstream
);
if
(
error
!=
cudaSuccess
){
ShowNTErrors
(
"cudaMemcpy error (cudaMemcpyDeviceToHost)"
);
}
}
else
{
cudaError_t
error
=
cudaMemcpy2DAsync
(
t
,
tPitch
,
s
,
sPitch
,
mSize
,
n
,
cudaMemcpyDeviceToDevice
,
cstream
);
if
(
error
!=
cudaSuccess
)
{
ShowNTErrors
(
"cudaMemcpy error (cudaMemcpyDeviceToDevice)"
);
}
}
}
#else
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
#endif
}
void
*
XMemAlloc
(
int
devID
,
size_t
size
)
void
*
XMemAlloc
(
int
devID
,
size_t
size
)
{
{
void
*
p
=
NULL
;
void
*
p
=
NULL
;
...
...
source/tensor/XUtility.h
查看文件 @
3f23f074
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include <stdio.h>
#include <stdio.h>
#include "XGlobal.h"
#include "XGlobal.h"
#include "XDevice.h"
#ifndef __XUTILITY_H__
#ifndef __XUTILITY_H__
#define __XUTILITY_H__
#define __XUTILITY_H__
...
@@ -41,6 +42,7 @@ extern void XMemSet(void * p, int value, size_t size);
...
@@ -41,6 +42,7 @@ extern void XMemSet(void * p, int value, size_t size);
extern
void
XMemSet
(
int
devID
,
void
*
p
,
int
value
,
size_t
size
);
extern
void
XMemSet
(
int
devID
,
void
*
p
,
int
value
,
size_t
size
);
extern
void
XMemCopy
(
void
*
t
,
int
devIDT
,
const
void
*
s
,
int
devIDS
,
size_t
size
);
extern
void
XMemCopy
(
void
*
t
,
int
devIDT
,
const
void
*
s
,
int
devIDS
,
size_t
size
);
extern
void
XMemCopy2D
(
void
*
t
,
size_t
tPitch
,
int
devIDT
,
const
void
*
s
,
size_t
sPitch
,
int
devIDS
,
size_t
mSize
,
int
n
);
extern
void
XMemCopy2D
(
void
*
t
,
size_t
tPitch
,
int
devIDT
,
const
void
*
s
,
size_t
sPitch
,
int
devIDS
,
size_t
mSize
,
int
n
);
extern
void
XMemCopy2DAsync
(
void
*
t
,
size_t
tPitch
,
int
devIDT
,
const
void
*
s
,
size_t
sPitch
,
int
devIDS
,
size_t
mSize
,
int
n
,
XStream
*
stream
);
extern
void
*
XMemAlloc
(
int
devID
,
size_t
size
);
extern
void
*
XMemAlloc
(
int
devID
,
size_t
size
);
extern
void
*
XMemAllocOnDev
(
int
devID
,
size_t
size
);
extern
void
*
XMemAllocOnDev
(
int
devID
,
size_t
size
);
extern
void
XMemFree
(
int
devID
,
void
*
p
);
extern
void
XMemFree
(
int
devID
,
void
*
p
);
...
...
source/tensor/core/shape/MakeSplitBlockIndex.cpp
查看文件 @
3f23f074
...
@@ -31,13 +31,13 @@ set target data block index for the data movement in split
...
@@ -31,13 +31,13 @@ set target data block index for the data movement in split
>> splitNum - number of splits
>> splitNum - number of splits
>> blockSplitSize - size of the splitted block
>> blockSplitSize - size of the splitted block
>> blockNum - number of data blocks
>> blockNum - number of data blocks
>>
mem - the memory pool
>>
devID - device id
*/
*/
void
_MakeSplitBlockIndex
(
int
*
blockIndex
,
int
splitNum
,
int
blockSplitSize
,
int
blockNum
,
XMem
*
mem
)
void
_MakeSplitBlockIndex
(
int
*
blockIndex
,
int
splitNum
,
int
blockSplitSize
,
int
blockNum
,
int
devID
)
{
{
if
(
mem
!=
NULL
&&
mem
->
devID
>=
0
)
{
if
(
devID
>=
0
)
{
#ifdef USE_CUDA
#ifdef USE_CUDA
_CudaMakeSplitBlockIndex
(
mem
->
devID
,
blockIndex
,
splitNum
,
blockSplitSize
,
blockNum
);
_CudaMakeSplitBlockIndex
(
devID
,
blockIndex
,
splitNum
,
blockSplitSize
,
blockNum
);
#else
#else
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
#endif
#endif
...
...
source/tensor/core/shape/MakeSplitBlockIndex.h
查看文件 @
3f23f074
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* set target data block index for the data movement in split */
/* set target data block index for the data movement in split */
void
_MakeSplitBlockIndex
(
int
*
blockIndex
,
int
splitNum
,
int
blockSplitSize
,
int
blockNum
,
XMem
*
mem
);
void
_MakeSplitBlockIndex
(
int
*
blockIndex
,
int
splitNum
,
int
blockSplitSize
,
int
blockNum
,
int
devID
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/shape/Merge.cpp
查看文件 @
3f23f074
...
@@ -42,6 +42,8 @@ e.g., (N/3, M, 3) -> (N, M)
...
@@ -42,6 +42,8 @@ e.g., (N/3, M, 3) -> (N, M)
*/
*/
void
_Merge
(
const
XTensor
*
s
,
XTensor
*
t
,
int
whereToMerge
,
int
leadingDim
)
void
_Merge
(
const
XTensor
*
s
,
XTensor
*
t
,
int
whereToMerge
,
int
leadingDim
)
{
{
if
(
leadingDim
<
0
)
leadingDim
=
0
;
int
whereToMergeRDI
=
s
->
order
-
whereToMerge
-
1
;
int
whereToMergeRDI
=
s
->
order
-
whereToMerge
-
1
;
int
leadingDimRDI
=
s
->
order
-
leadingDim
-
1
;
int
leadingDimRDI
=
s
->
order
-
leadingDim
-
1
;
if
(
leadingDimRDI
<
0
)
if
(
leadingDimRDI
<
0
)
...
@@ -268,10 +270,10 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
...
@@ -268,10 +270,10 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
}
}
/* merging with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
/* merging with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
else
{
else
{
int
*
dimSizeTMP
=
new
int
[
MAX_TENSOR_DIM_NUM
];
int
*
dimSizeTMP
=
new
int
[
smallsItem0
->
order
+
1
];
for
(
int
i
=
0
;
i
<
MAX_TENSOR_DIM_NUM
;
i
++
)
for
(
int
i
=
0
;
i
<
smallsItem0
->
order
;
i
++
)
dimSizeTMP
[
i
]
=
-
smallsItem0
->
dimSizeRDI
[
i
];
dimSizeTMP
[
i
+
1
]
=
-
smallsItem0
->
dimSize
[
i
];
dimSizeTMP
[
smallsItem0
->
order
]
=
-
mergeNum
;
dimSizeTMP
[
0
]
=
-
mergeNum
;
XMem
*
mem
=
smallsItem0
->
mem
;
XMem
*
mem
=
smallsItem0
->
mem
;
XTensor
*
tensorTMP
=
new
XTensor
(
smallsItem0
->
order
+
1
,
dimSizeTMP
,
XTensor
*
tensorTMP
=
new
XTensor
(
smallsItem0
->
order
+
1
,
dimSizeTMP
,
...
@@ -283,7 +285,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
...
@@ -283,7 +285,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
if
(
uniform
)
if
(
uniform
)
dataTMP
=
smallsItem0
->
data
;
dataTMP
=
smallsItem0
->
data
;
else
else
dataTMP
=
mem
!=
NULL
?
mem
->
AllocBuf
(
mem
->
devID
,
size
)
:
XMemAlloc
(
mem
->
devID
,
size
);
dataTMP
=
mem
!=
NULL
?
mem
->
AllocBuf
(
mem
->
devID
,
size
)
:
XMemAlloc
(
big
->
devID
,
size
);
tensorTMP
->
data
=
dataTMP
;
tensorTMP
->
data
=
dataTMP
;
...
@@ -295,7 +297,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
...
@@ -295,7 +297,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
}
}
}
}
_Merge
(
tensorTMP
,
big
,
whereToMerge
);
_Merge
(
tensorTMP
,
big
,
whereToMerge
+
1
);
delete
[]
dimSizeTMP
;
delete
[]
dimSizeTMP
;
tensorTMP
->
data
=
NULL
;
tensorTMP
->
data
=
NULL
;
...
@@ -306,7 +308,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
...
@@ -306,7 +308,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
if
((
!
uniform
)
&&
(
mem
!=
NULL
))
if
((
!
uniform
)
&&
(
mem
!=
NULL
))
mem
->
ReleaseBuf
(
mem
->
devID
,
size
);
mem
->
ReleaseBuf
(
mem
->
devID
,
size
);
else
else
XMemFree
(
mem
->
devID
,
dataTMP
);
XMemFree
(
big
->
devID
,
dataTMP
);
}
}
}
}
...
...
source/tensor/core/shape/Split.cpp
查看文件 @
3f23f074
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "MakeSplitBlockIndex.h"
#include "MakeSplitBlockIndex.h"
#include "../../XName.h"
#include "../../XName.h"
#include "../../XTensor.h"
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "../../XUtility.h"
#include "../movement/CopyBlocksOnSite.h"
#include "../movement/CopyBlocksOnSite.h"
...
@@ -82,18 +83,42 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
...
@@ -82,18 +83,42 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
CheckNTErrors
((
blockNum
%
splitNum
==
0
),
"Incorrect split number!"
);
CheckNTErrors
((
blockNum
%
splitNum
==
0
),
"Incorrect split number!"
);
if
(
splitNum
<=
MIN_TENSOR_SPLIT_NUM
)
{
if
(
splitNum
<=
MIN_TENSOR_SPLIT_NUM
)
{
//if (splitNum <= 0) {
int
sPitch
=
blockSize
*
splitNum
*
s
->
unitSize
;
int
sPitch
=
blockSize
*
splitNum
*
s
->
unitSize
;
int
tPitch
=
blockSize
*
t
->
unitSize
;
int
tPitch
=
blockSize
*
t
->
unitSize
;
int
mSize
=
blockSize
*
t
->
unitSize
;
int
mSize
=
blockSize
*
t
->
unitSize
;
int
n
=
blockNum
/
splitNum
;
int
n
=
blockNum
/
splitNum
;
int
sStep
=
blockSize
*
s
->
unitSize
;
int
sStep
=
blockSize
*
s
->
unitSize
;
int
tStep
=
n
*
tPitch
;
int
tStep
=
n
*
tPitch
;
if
(
t
->
devID
<
0
){
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
(
char
*
)
s
->
data
+
k
*
sStep
,
sPitch
,
s
->
devID
,
(
char
*
)
s
->
data
+
k
*
sStep
,
sPitch
,
s
->
devID
,
mSize
,
n
);
mSize
,
n
);
}
}
}
}
else
{
#ifdef USE_CUDA
#ifdef STREAMED_MEMCPOPY
XStream
*
stream
=
GDevs
.
GPUs
[
t
->
devID
].
stream
;
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XMemCopy2DAsync
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
(
char
*
)
s
->
data
+
k
*
sStep
,
sPitch
,
s
->
devID
,
mSize
,
n
,
stream
);
}
stream
->
StreamSynchronize
();
#else
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
(
char
*
)
s
->
data
+
k
*
sStep
,
sPitch
,
s
->
devID
,
mSize
,
n
);
}
#endif
#else
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
#endif
}
}
else
{
else
{
XMem
*
mem
=
s
->
mem
;
XMem
*
mem
=
s
->
mem
;
int
size
=
s
->
unitNum
*
s
->
unitSize
;
int
size
=
s
->
unitNum
*
s
->
unitSize
;
...
@@ -109,9 +134,9 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
...
@@ -109,9 +134,9 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
int
*
blockIndex
=
(
int
*
)(
mem
!=
NULL
?
int
*
blockIndex
=
(
int
*
)(
mem
!=
NULL
?
mem
->
AllocBuf
(
mem
->
devID
,
blockNum
*
sizeof
(
int
))
:
mem
->
AllocBuf
(
mem
->
devID
,
blockNum
*
sizeof
(
int
))
:
XMemAlloc
(
mem
->
devID
,
blockNum
*
sizeof
(
int
)));
XMemAlloc
(
s
->
devID
,
blockNum
*
sizeof
(
int
)));
_MakeSplitBlockIndex
(
blockIndex
,
splitNum
,
blockSplitSize
,
blockNum
,
mem
);
_MakeSplitBlockIndex
(
blockIndex
,
splitNum
,
blockSplitSize
,
blockNum
,
s
->
devID
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
mem
);
_CopyBlocksOnSite
(
s
->
data
,
realBlockSize
,
blockNum
,
dataTMP
,
blockIndex
,
mem
);
...
@@ -226,6 +251,8 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
...
@@ -226,6 +251,8 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
int
n
=
blockNum
/
splitNum
;
int
n
=
blockNum
/
splitNum
;
int
sStep
=
blockSize
*
big
->
unitSize
;
int
sStep
=
blockSize
*
big
->
unitSize
;
int
tStep
=
0
;
int
tStep
=
0
;
if
(
big
->
devID
<
0
){
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XTensor
*
t
=
(
XTensor
*
)
smalls
->
GetItem
(
k
);
XTensor
*
t
=
(
XTensor
*
)
smalls
->
GetItem
(
k
);
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
...
@@ -233,13 +260,37 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
...
@@ -233,13 +260,37 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
mSize
,
n
);
mSize
,
n
);
}
}
}
}
else
{
#ifdef USE_CUDA
#ifdef STREAMED_MEMCPOPY
XStream
*
stream
=
GDevs
.
GPUs
[
big
->
devID
].
stream
;
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XTensor
*
t
=
(
XTensor
*
)
smalls
->
GetItem
(
k
);
XMemCopy2DAsync
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
(
char
*
)
big
->
data
+
k
*
sStep
,
sPitch
,
big
->
devID
,
mSize
,
n
,
stream
);
}
stream
->
StreamSynchronize
();
#else
for
(
int
k
=
0
;
k
<
splitNum
;
k
++
)
{
XTensor
*
t
=
(
XTensor
*
)
smalls
->
GetItem
(
k
);
XMemCopy2D
((
char
*
)
t
->
data
+
k
*
tStep
,
tPitch
,
t
->
devID
,
(
char
*
)
big
->
data
+
k
*
sStep
,
sPitch
,
big
->
devID
,
mSize
,
n
);
}
#endif
#else
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
#endif
}
}
/* splitting with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
/* splitting with fewer kernel/api calls??? (i'm not sure about it!! may remove this later) */
else
{
else
{
int
*
dimSizeTMP
=
new
int
[
MAX_TENSOR_DIM_NUM
];
int
*
dimSizeTMP
=
new
int
[
big
->
order
+
1
];
for
(
int
i
=
0
;
i
<
MAX_TENSOR_DIM_NUM
;
i
++
)
for
(
int
i
=
0
;
i
<
big
->
order
;
i
++
)
dimSizeTMP
[
i
]
=
-
big
->
dimSize
[
i
];
dimSizeTMP
[
i
+
1
]
=
-
big
->
dimSize
[
i
];
dimSizeTMP
[
whereToSplit
]
/=
splitNum
;
dimSizeTMP
[
whereToSplit
+
1
]
/=
splitNum
;
dimSizeTMP
[
big
->
order
]
=
-
splitNum
;
dimSizeTMP
[
0
]
=
-
splitNum
;
XMem
*
mem
=
big
->
mem
;
XMem
*
mem
=
big
->
mem
;
XTensor
*
tensorTMP
=
new
XTensor
(
big
->
order
+
1
,
dimSizeTMP
,
big
->
dataType
,
big
->
denseRatio
,
big
->
devID
,
mem
);
XTensor
*
tensorTMP
=
new
XTensor
(
big
->
order
+
1
,
dimSizeTMP
,
big
->
dataType
,
big
->
denseRatio
,
big
->
devID
,
mem
);
...
@@ -251,7 +302,7 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
...
@@ -251,7 +302,7 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
dataTMP
=
first
->
data
;
dataTMP
=
first
->
data
;
}
}
else
{
else
{
dataTMP
=
mem
!=
NULL
?
mem
->
AllocBuf
(
mem
->
devID
,
size
)
:
XMemAlloc
(
mem
->
devID
,
size
);
dataTMP
=
mem
!=
NULL
?
mem
->
AllocBuf
(
mem
->
devID
,
size
)
:
XMemAlloc
(
big
->
devID
,
size
);
}
}
tensorTMP
->
data
=
dataTMP
;
tensorTMP
->
data
=
dataTMP
;
...
@@ -276,7 +327,7 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
...
@@ -276,7 +327,7 @@ void _Split(const XTensor * big, XList * smalls, int whereToSplit, int splitNum)
if
((
!
uniform
)
&&
(
mem
!=
NULL
))
if
((
!
uniform
)
&&
(
mem
!=
NULL
))
mem
->
ReleaseBuf
(
mem
->
devID
,
size
);
mem
->
ReleaseBuf
(
mem
->
devID
,
size
);
else
else
XMemFree
(
mem
->
devID
,
dataTMP
);
XMemFree
(
big
->
devID
,
dataTMP
);
}
}
}
}
...
...
source/tensor/core/shape/Split.h
查看文件 @
3f23f074
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
#define STREAMED_MEMCPOPY
/*
/*
transform a tensor by splitting it
transform a tensor by splitting it
e.g., (M, N) -> (M, N/3, 3)
e.g., (M, N) -> (M, N/3, 3)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论