Commit 5bf12b8c authored Jul 19, 2018 by xiaotong
back propagation for Merge
parent b8415485
Showing 10 changed files with 300 additions and 44 deletions
source/network/Main.cpp  +1 -0
source/network/XBackwardMath.cpp  +0 -0
source/network/XBackwardShape.cpp  +213 -0
source/network/XBackwardShape.h  +23 -0
source/tensor/XName.cpp  +4 -0
source/tensor/XName.h  +4 -2
source/tensor/core/shape/Concatenate.cpp  +15 -29
source/tensor/core/shape/Merge.cpp  +12 -1
source/tensor/core/shape/Merge.h  +5 -9
source/tensor/core/shape/Split.cpp  +23 -3
source/network/Main.cpp
@@ -65,6 +65,7 @@ int main( int argc, const char ** argv )
     a.Dump(stderr, "a:");
     b.Dump(stderr, "b:");
     c.Dump(stderr, "c:");
+    XLink::ShowNetwork(stderr, &c);
 
     net.Backward(c);
...
source/network/XBackwardMath.cpp
source/network/XBackwardShape.cpp
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
 * backward computation for shape operations
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-19
 * It was chilly when I came into the office this morning ...
 * because I forgot to turn the air conditioner off last night :(
 */

#include "XNoder.h"
#include "XBackwardShape.h"
#include "../tensor/XName.h"
#include "../tensor/core/CHeader.h"

namespace nts{

/* compute dE/dx of a node */
void XShapeGrad::MakeGrad(XTensor * node)
{
    CheckNTErrors(node->grad != NULL, "No gradient found!");

    XLink &income = node->income;
    int operID = income.typeID;

    if(operID == SHAPE_MERGE)
        GradMerge(node);
    else if(operID == SHAPE_MERGE_LIST)
        GradMergeList(node);
    else{
        ShowNTErrors("TODO!");
    }
}

/* indicates whether the node is for a shape operation */
bool XShapeGrad::IsShapeOP(XTensor * node)
{
    XLink &income = node->income;
    return (income.typeID & DATA_BASE) != 0;
}

/*
gradient for merge
for
c = merge(a_0, a_1, ...)
where a_i is the i-th block in a tensor a
we have
dE/da_0 = dE/dc_{split_0}
dE/da_1 = dE/dc_{split_1}
...
i.e.,
dE/da = split(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradMerge(XTensor * node)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");

    XTensor * input = income.tails[0];
    int whereToMerge = income.GetParamInt(0);
    int leadDim = income.GetParamInt(1);

    int blockSize = 1;
    int blockNum = 1;
    for(int i = 0; i < input->order; i++){
        if(i < leadDim)
            blockNum *= input->dimSize[i];
    }
    blockSize = input->GetDataSizeInChar() / blockNum;

    XNoder::MakeGrad(input);

    int * dims = new int[input->order];
    for(int i = 0, j = 0; i < input->order; i++){
        if(i >= leadDim){
            dims[j++] = input->dimSize[i];
        }
    }
    dims[0] = -dims[0];
    XTensor gradInputSmall(input->order - leadDim, dims,
                           input->dataType, input->denseRatio,
                           input->devID, input->mem);

    dims[whereToMerge - leadDim] *= dims[0];
    XTensor gradNodeSmall(node->order - leadDim, dims,
                          node->dataType, node->denseRatio,
                          node->devID, node->mem);

    /* we can simply split the gradient tensor
       if the input is used in merging only */
    if(input->outgo.tailNum == 1){
        for(int i = 0; i < blockNum; i++){
            gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
            gradInputSmall.data = (char*)input->grad->data + i * blockSize;
            _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim, input->dimSize[leadDim]);
        }
    }
    /* a more complicated case is that the input tensor is used for
       other operations somewhere else. So we have to do gradient
       accumulation after splitting, i.e., we need an additional
       SUM operation */
    else{
        XTensor gradInputSmallBuf(&gradInputSmall);
        for(int i = 0; i < blockNum; i++){
            gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
            gradInputSmall.data = (char*)input->grad->data + i * blockSize;
            _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim, input->dimSize[leadDim]);
            _Sum(&gradInputSmall, &gradInputSmallBuf, &gradInputSmall);
        }
    }

    gradNodeSmall.data = NULL;
    gradInputSmall.data = NULL;

    delete[] dims;
}

/*
gradient for merging a list of tensors
for
c = merge(list(a, b, ...))
where a, b ... are of the same size
we have
dE/da = dE/dc_{split_0}
dE/db = dE/dc_{split_1}
i.e.,
list(dE/da, dE/db, ...) = split(dE/dc)
*/
void XShapeGrad::GradMergeList(XTensor * node)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for MERGE!");

    XTensor * last = NULL;
    XList smalls(income.tailNum);
    XList smallsGrad(income.tailNum);
    bool mergeOnly = true;
    for(int i = 0; i < income.tailNum; i++){
        XTensor * tail = income.tails[i];
        XNoder::MakeGrad(tail);
        smalls.Add(tail);
        smallsGrad.Add(tail->grad);

        if(i > 1){
            CheckNTErrors(XTensor::IsIdentical(last, tail),
                          "Input tensors must be of the same size!");
        }

        if(tail->outgo.tailNum > 1)
            mergeOnly = false;

        last = tail;
    }

    int whereToMerge = income.GetParamInt(0);

    /* we can simply split the gradient tensor into the input tensors
       if the inputs are used in merging only */
    if(mergeOnly)
        _Split(node->grad, &smallsGrad, whereToMerge, smalls.count);

    /* a more complicated case is that the input tensors are used for
       other operations somewhere else. So we have to do gradient
       accumulation after splitting, i.e., we need an additional
       SUM operation */
    else{
        int * dims = new int[last->order + 1];
        dims[0] = smalls.count;
        for(int i = 0; i < last->order; i++)
            dims[i + 1] = last->dimSize[i];

        XTensor gradSplit(last->order + 1, dims,
                          last->dataType, last->denseRatio,
                          last->devID, last->mem);

        _Split(node->grad, &gradSplit, whereToMerge, smalls.count);

        memcpy(dims, last->dimSize, sizeof(int) * last->order);
        dims[0] = -dims[0];
        XTensor gradSmall(last->order, dims,
                          last->dataType, last->denseRatio,
                          last->devID, last->mem);

        /* gradient accumulation for each split */
        for(int i = 0; i < smalls.count; i++){
            XTensor * smallGrad = (XTensor*)smallsGrad.Get(i);
            gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize;
            _Sum(smallGrad, &gradSmall, smallGrad);
        }

        delete[] dims;
    }
}

}
\ No newline at end of file
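
Note on the two branches above: merge is pure data movement, so the backward pass is just the inverse movement, a split of dE/dc, plus a sum whenever the input also feeds other operations (otherwise one incoming gradient would overwrite the other). A standalone sketch of both cases, using plain C++ buffers in place of XTensor; all sizes here are made up for illustration:

    #include <cassert>
    #include <vector>

    int main()
    {
        /* c = merge(a0, a1): two blocks of 3 elements laid out back to back,
           so dE/da_i is simply the i-th slice of dE/dc */
        const size_t blockLen = 3;
        std::vector<float> dc = {.1f, .2f, .3f, .4f, .5f, .6f};

        /* case 1: the inputs feed the merge only -- split dE/dc directly */
        std::vector<float> da0(dc.begin(), dc.begin() + blockLen);
        std::vector<float> da1(dc.begin() + blockLen, dc.end());
        assert(da0.size() == blockLen && da1.size() == blockLen);

        /* case 2: a0 also feeds another operation -- its slice of dE/dc must
           be accumulated into the gradient that path already produced,
           mirroring the _Sum() calls in the code above */
        std::vector<float> grad0 = {1.f, 1.f, 1.f};   /* from the other op */
        for (size_t i = 0; i < blockLen; i++)
            grad0[i] += dc[i];                        /* now ~{1.1, 1.2, 1.3} */
        return 0;
    }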
source/network/XBackwardShape.h
@@ -28,6 +28,28 @@
 namespace nts{
 
+/* this class computes the gradient for tensor shaping and movement given a node */
+class XShapeGrad
+{
+public:
+    /* compute dE/dx of a node */
+    static
+    void MakeGrad(XTensor * node);
+
+    /* indicates whether the node is for a shaping operation */
+    static
+    bool IsShapeOP(XTensor * node);
+
+private:
+    /* gradient for merge: c = merge(a, b, ...) */
+    static
+    void GradMerge(XTensor * node);
+
+    /* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
+    static
+    void GradMergeList(XTensor * node);
+};
+
 }
 
 #endif
\ No newline at end of file
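
For orientation, a sketch of how a backward driver might dispatch to this class; the BackwardNode wrapper here is hypothetical, while IsShapeOP and MakeGrad are the interface declared above:

    #include "XBackwardShape.h"   /* XShapeGrad */
    using namespace nts;

    /* hypothetical dispatch step of a backward pass: if the node was
       produced by a shape operation, let XShapeGrad propagate dE/dx
       to the node's inputs */
    void BackwardNode(XTensor * node)
    {
        if (XShapeGrad::IsShapeOP(node))
            XShapeGrad::MakeGrad(node);
        /* ... otherwise hand the node to the math/function handlers ... */
    }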
source/tensor/XName.cpp
@@ -71,10 +71,14 @@ const char * GetOPName(int type)
         return "S_CONCATENATE";
     else if (type == SHAPE_MERGE)
         return "S_MERGE";
+    else if (type == SHAPE_MERGE_LIST)
+        return "S_MERGE_LIST";
     else if (type == SHAPE_PERMUTE)
         return "S_PERMUTE";
     else if (type == SHAPE_SPLIT)
         return "S_SPLIT";
+    else if (type == SHAPE_SPLIT_LIST)
+        return "S_SPLIT_LIST";
     else if (type == SHAPE_TRANSPOSE)
         return "S_TRANSPOSE";
     else if (type == SHAPE_UNSQUEEZE)
...
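
These two cases keep GetOPName in step with the IDs added in XName.h below; a trivial usage sketch (assuming the headers are on the include path and GetOPName lives in namespace nts, as the hunk header suggests):

    #include <cstdio>
    #include "XName.h"
    using namespace nts;

    int main()
    {
        /* with this commit applied, prints S_MERGE_LIST and S_SPLIT_LIST */
        printf("%s\n", GetOPName(SHAPE_MERGE_LIST));
        printf("%s\n", GetOPName(SHAPE_SPLIT_LIST));
        return 0;
    }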
source/tensor/XName.h
@@ -62,9 +62,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 #define SHAPE REDUCE_REDUCEVARIANCE + 1
 #define SHAPE_CONCATENATE SHAPE + 1
 #define SHAPE_MERGE SHAPE_CONCATENATE + 1
-#define SHAPE_PERMUTE SHAPE_MERGE + 1
+#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
+#define SHAPE_PERMUTE SHAPE_MERGE_LIST + 1
 #define SHAPE_SPLIT SHAPE_PERMUTE + 1
-#define SHAPE_TRANSPOSE SHAPE_SPLIT + 1
+#define SHAPE_SPLIT_LIST SHAPE_SPLIT + 1
+#define SHAPE_TRANSPOSE SHAPE_SPLIT_LIST + 1
 #define SHAPE_UNSQUEEZE SHAPE_TRANSPOSE + 1
 
 /* activation functions */
...
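
Because each constant is defined as its predecessor plus one, inserting SHAPE_MERGE_LIST and SHAPE_SPLIT_LIST renumbers everything from SHAPE_PERMUTE onward, so raw operation IDs are not stable across this commit. A compile-time sketch of the intended layout (C++11 or later, assuming the defines above):

    #include "XName.h"

    /* the new IDs slot in directly after the ops they extend */
    static_assert(SHAPE_MERGE_LIST == SHAPE_MERGE + 1, "unexpected ID layout");
    static_assert(SHAPE_SPLIT_LIST == SHAPE_SPLIT + 1, "unexpected ID layout");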
source/tensor/core/shape/Concatenate.cpp
@@ -80,40 +80,19 @@ XTensor Concatenate(const XList &smalls, int dim)
         if(!XTensor::IsIdentical(a, b))
             uniform = false;
     }
 
-    int * dimSize;
-    if(uniform){
-        XTensor * tensor = (XTensor*)smalls.GetItem(0);
-        int order = tensor->order;
-        dimSize = new int[order];
+    XTensor * tensor = (XTensor*)smalls.GetItem(0);
+    int order = tensor->order;
+    int * dimSize = new int[order];
+
+    if(uniform){
         for(int i = 0; i < tensor->order; i++){
             if(i != dim)
                 dimSize[i] = tensor->dimSize[i];
             else
                 dimSize[i] = tensor->dimSize[dim] * smalls.count;
         }
-
-        XTensor big = XTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
-        big.SetZeroAll();
-        big.SetTMP();
-
-        /* call _Merge function */
-        _Merge(&smalls, &big, dim);
-
-        ///* tensor connection */
-        //XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
-        //XLink::AddParamToHead(&big, dim);
-
-        /* destroy variables */
-        delete dimSize;
-
-        return big;
     }
     else{
-        XTensor * tensor = (XTensor*)smalls.GetItem(0);
-        int order = tensor->order;
-        dimSize = new int[order];
         for(int i = 0; i < tensor->order; i++)
             if(i != dim)
                 dimSize[i] = tensor->dimSize[i];
...
@@ -124,19 +103,24 @@ XTensor Concatenate(const XList &smalls, int dim)
             catDimSize += tensor->dimSize[dim];
         }
         dimSize[dim] = catDimSize;
+    }
 
-    XTensor big = XTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
+    XTensor big = NewTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
     big.SetZeroAll();
     big.SetTMP();
 
-    /* call _ConcatenateSolely function */
-    _ConcatenateSolely(&smalls, &big, dim);
+    /* call _Merge function */
+    _Merge(&smalls, &big, dim);
+
+    /* tensor connection */
+    XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
+    XLink::AddParamToHeadInt(&big, dim);
 
     /* destroy variables */
-    delete dimSize;
+    delete[] dimSize;
 
     return big;
-    }
 }
...
@@ -168,6 +152,8 @@ make a new tensor to keep the result and return it.
 */
 XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
 {
+    ShowNTErrors("rewrite this!!!!!!!!");
+
     XList smalls(2);
     smalls.Add(&smallA);
     smalls.Add(&smallB);
...
...
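
The reworked uniform branch leans on the fact that concatenating k same-shaped tensors along dim is exactly a merge: the output copies the input shape except at dim, which grows by a factor of k. A quick standalone check of that shape rule (plain C++, made-up sizes):

    #include <cassert>

    int main()
    {
        /* three (2, 5) tensors concatenated along dim 1 -> (2, 15) */
        const int order = 2, dim = 1, count = 3;
        int inSize[order] = {2, 5};
        int outSize[order];
        for (int i = 0; i < order; i++)
            outSize[i] = (i != dim) ? inSize[i] : inSize[dim] * count;
        assert(outSize[0] == 2 && outSize[1] == 15);
        return 0;
    }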
source/tensor/core/shape/Merge.cpp
@@ -187,6 +187,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
     /* call _Merge function */
     _Merge(&s, &t, whereToMerge, leadingDim);
 
+    /* tensor connections */
+    XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
+    XLink::AddParamToHeadInt(&t, whereToMerge);
+    XLink::AddParamToHeadInt(&t, leadingDim);
+
     /* destroy variables */
     delete[] dimSize;
...
@@ -327,13 +332,19 @@ XTensor Merge(const XList &smalls, int whereToMerge)
         dimSize[i] = tensor->dimSize[whereToMerge] * smalls.count;
     }
 
     XTensor big = NewTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
     big.SetZeroAll();
     big.SetTMP();
 
     /* call _Merge function */
     _Merge(&smalls, &big, whereToMerge);
 
+    /* tensor connections */
+    XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
+    XLink::AddParamToHeadInt(&big, whereToMerge);
+
     /* destroy variables */
     delete[] dimSize;
...
...
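
These links are what the new backward code consumes: AddParamToHeadInt pushes (whereToMerge, leadingDim) in that order, and GradMerge in XBackwardShape.cpp reads them back positionally as GetParamInt(0) and GetParamInt(1). A sketch of the read side (the include paths and wrapper function are assumptions; only the XLink calls come from this commit):

    #include "../XTensor.h"   /* header paths are an assumption */
    #include "../XLink.h"
    using namespace nts;

    /* read back the parameters recorded by Merge(); the indices must
       match the order of the AddParamToHeadInt calls above */
    void ReadMergeParams(XTensor * t, int &whereToMerge, int &leadingDim)
    {
        XLink &income = t->income;
        whereToMerge = income.GetParamInt(0);
        leadingDim = income.GetParamInt(1);
    }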
source/tensor/core/shape/Merge.h
@@ -29,20 +29,16 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 /* transform a tensor by merging it along a dimension, e.g., (M, N/3, 3) -> (M, N) */
 void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1);
 
-/*
-transform a tensor by merging it along a dimension (return a XTensor structure).
-make a new tensor to keep the result and return it.
-e.g., (M, N/3, 3) -> (M, N)
-*/
+/* transform a tensor by merging it along a dimension (return a XTensor structure).
+make a new tensor to keep the result and return it.
+e.g., (M, N/3, 3) -> (M, N) */
 XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
 
 /* merge small tensors into a big tensor */
 void _Merge(const XList * smalls, XTensor * big, int whereToMerge);
 
-/*
-merge small tensors into a big tensor (return a XTensor structure).
-make a new tensor to keep the result and return it.
-*/
+/* merge small tensors into a big tensor (return a XTensor structure).
+make a new tensor to keep the result and return it. */
 XTensor Merge(const XList &smalls, int whereToMerge);
 
 } // namespace nts(NiuTrans.Tensor)
...
source/tensor/core/shape/Split.cpp
@@ -19,10 +19,12 @@
 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
 */
 
-#include "../../XTensor.h"
-#include "../../XUtility.h"
 #include "Split.h"
 #include "MakeSplitBlockIndex.h"
+#include "../../XName.h"
+#include "../../XTensor.h"
+#include "../../XUtility.h"
 #include "../movement/CopyBlocksOnSite.h"
 
 namespace nts{ // namespace nts(NiuTrans.Tensor)
...
@@ -161,6 +163,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
     /* call _Split function */
     _Split(&s, &t, whereToSplit, splitNum);
 
+    /* tensor connections */
+    XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
+    XLink::AddParamToHeadInt(&t, whereToSplit);
+    XLink::AddParamToHeadInt(&t, splitNum);
+
     /* destroy variables */
     delete[] dimSize;
...
@@ -298,7 +305,9 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum)
     }
 
     for(int i = 0; i < splitNum; i++){
         XTensor tensor = NewTensor(order, dimSize, big.dataType, big.denseRatio, big.devID, big.mem);
         tensor.SetZeroAll();
         tensor.SetTMP();
         smalls.Add(&tensor);
...
@@ -307,6 +316,17 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum)
     /* call _Split function */
     _Split(&big, &smalls, whereToSplit, splitNum);
 
+    /* tensor connections */
+    for(int i = 0; i < smalls.count; i++){
+        XTensor * s = (XTensor*)smalls.Get(i);
+        XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
+        XLink::AddParamToHeadInt(s, whereToSplit);
+
+        /* it is tricky here that we keep the id of each
+           block, rather than the total number of splits */
+        XLink::AddParamToHeadInt(s, i);
+    }
+
     /* destroy variables */
     delete[] dimSize;
...
...
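
The "tricky" comment deserves a gloss: each small tensor's link stores its own block index i rather than the total splitNum, so a future backward pass for SHAPE_SPLIT_LIST can locate that tensor's slice of big's gradient straight from the recorded id. A sketch of the offset arithmetic this enables (plain C++, illustrative sizes):

    #include <cassert>

    int main()
    {
        /* a buffer of 12 units split into 3 equal blocks */
        const int unitNum = 12, splitNum = 3;
        const int blockSize = unitNum / splitNum;

        /* the id stored on each small tensor pins down its slice */
        for (int id = 0; id < splitNum; id++) {
            int offset = id * blockSize;
            assert(offset + blockSize <= unitNum);
        }
        return 0;
    }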