Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
d221ef9d
Commit
d221ef9d
authored
Oct 12, 2019
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
merge with zhangyuhao branch
parent
e0d86e5b
隐藏空白字符变更
内嵌
并排
正在显示
49 个修改的文件
包含
555 行增加
和
243 行删除
+555
-243
source/tensor/XLink.cpp
+9
-0
source/tensor/XList.cpp
+50
-12
source/tensor/XList.h
+22
-5
source/tensor/XTensor.cpp
+44
-0
source/tensor/XTensor.h
+6
-0
source/tensor/core/arithmetic/Div.cpp
+12
-8
source/tensor/core/arithmetic/DivDim.cpp
+7
-5
source/tensor/core/arithmetic/Mask.cpp
+5
-3
source/tensor/core/arithmetic/MatrixMul.cpp
+14
-10
source/tensor/core/arithmetic/MatrixMulBatched.cpp
+12
-8
source/tensor/core/arithmetic/MulAndShift.cpp
+81
-5
source/tensor/core/arithmetic/MulAndShift.h
+3
-0
source/tensor/core/arithmetic/Multiply.cpp
+12
-8
source/tensor/core/arithmetic/MultiplyDim.cpp
+11
-7
source/tensor/core/arithmetic/Sub.cpp
+11
-7
source/tensor/core/arithmetic/SubDim.cpp
+6
-4
source/tensor/core/arithmetic/Sum.cpp
+15
-11
source/tensor/core/arithmetic/SumDim.cpp
+12
-8
source/tensor/core/getandset/ConvertDataType.cpp
+20
-19
source/tensor/core/getandset/Select.cpp
+6
-4
source/tensor/core/math/Binary.cpp
+4
-2
source/tensor/core/math/Clip.cpp
+14
-12
source/tensor/core/math/Normalize.cpp
+6
-4
source/tensor/core/math/ScaleAndShift.cpp
+6
-4
source/tensor/core/math/Unary.cpp
+4
-2
source/tensor/core/movement/CopyIndexed.cpp
+15
-11
source/tensor/core/movement/CopyValues.cpp
+3
-1
source/tensor/core/movement/Gather.cpp
+5
-2
source/tensor/core/reduce/ReduceMax.cpp
+5
-3
source/tensor/core/reduce/ReduceMean.cpp
+5
-3
source/tensor/core/reduce/ReduceSum.cpp
+14
-10
source/tensor/core/reduce/ReduceSumSquared.cpp
+5
-3
source/tensor/core/reduce/ReduceVariance.cpp
+5
-3
source/tensor/core/shape/Concatenate.cpp
+18
-10
source/tensor/core/shape/Merge.cpp
+14
-8
source/tensor/core/shape/Reshape.cpp
+4
-2
source/tensor/core/shape/Split.cpp
+14
-9
source/tensor/core/shape/Squeeze.cpp
+4
-2
source/tensor/core/shape/Transpose.cpp
+5
-3
source/tensor/core/shape/Unsqueeze.cpp
+6
-4
source/tensor/function/DropoutWithIndex.cpp
+4
-2
source/tensor/function/HardTanH.cpp
+4
-2
source/tensor/function/Identity.cpp
+4
-2
source/tensor/function/LogSoftmax.cpp
+5
-3
source/tensor/function/Rectify.cpp
+4
-2
source/tensor/function/Sigmoid.cpp
+4
-2
source/tensor/function/Softmax.cpp
+5
-3
source/tensor/loss/CrossEntropy.cpp
+10
-4
source/tensor/test/TSetData.cpp
+1
-1
没有找到文件。
source/tensor/XLink.cpp
查看文件 @
d221ef9d
...
@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
...
@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
if
(
h
==
NULL
)
if
(
h
==
NULL
)
return
;
return
;
if
(
!
t1
->
enableGrad
)
return
;
TensorList
list
(
2
);
TensorList
list
(
2
);
list
.
Add
((
XTensor
*
)
t1
);
list
.
Add
((
XTensor
*
)
t1
);
list
.
Add
((
XTensor
*
)
t2
);
list
.
Add
((
XTensor
*
)
t2
);
...
@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,
...
@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,
if
(
h
==
NULL
)
if
(
h
==
NULL
)
return
;
return
;
if
(
!
t1
->
enableGrad
||
!
t2
->
enableGrad
)
return
;
TensorList
list
(
3
);
TensorList
list
(
3
);
list
.
Add
((
XTensor
*
)
t1
);
list
.
Add
((
XTensor
*
)
t1
);
list
.
Add
((
XTensor
*
)
t2
);
list
.
Add
((
XTensor
*
)
t2
);
...
@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors
...
@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors
*/
*/
void
XLink
::
MakeLink
(
XTensor
*
t
,
TensorList
*
list
,
int
id
)
void
XLink
::
MakeLink
(
XTensor
*
t
,
TensorList
*
list
,
int
id
)
{
{
if
(
!
t
->
enableGrad
)
return
;
/* forward */
/* forward */
for
(
int
i
=
0
;
i
<
list
->
count
;
i
++
){
for
(
int
i
=
0
;
i
<
list
->
count
;
i
++
){
XTensor
*
h
=
(
XTensor
*
)
list
->
GetItem
(
i
);
XTensor
*
h
=
(
XTensor
*
)
list
->
GetItem
(
i
);
...
...
source/tensor/XList.cpp
查看文件 @
d221ef9d
...
@@ -23,15 +23,11 @@
...
@@ -23,15 +23,11 @@
*
*
*/
*/
#include "
XList
.h"
#include "
time
.h"
#include "XMem.h"
#include "XMem.h"
#include "XList.h"
#include "XGlobal.h"
#include "XGlobal.h"
#include <ctime>
#include <utility>
#include <algorithm>
/* the nts (NiuTrans.Tensor) namespace */
/* the nts (NiuTrans.Tensor) namespace */
namespace
nts
{
namespace
nts
{
...
@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
...
@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
template
<
typename
T
>
template
<
typename
T
>
TensorListBase
<
T
>::~
TensorListBase
()
TensorListBase
<
T
>::~
TensorListBase
()
{
{
delete
[]
items
;
if
(
items
&&
mem
)
delete
[]
items
;
}
}
...
@@ -103,6 +100,13 @@ void TensorListBase<T>::Add(T&& item)
...
@@ -103,6 +100,13 @@ void TensorListBase<T>::Add(T&& item)
items
[
count
++
]
=
item
;
items
[
count
++
]
=
item
;
}
}
/* return number of elements */
template
<
typename
T
>
size_t
TensorListBase
<
T
>::
Size
()
{
return
count
;
}
/*
/*
add an item into the list
add an item into the list
>> item - a const reference to the item
>> item - a const reference to the item
...
@@ -130,7 +134,7 @@ add a number of items into the list
...
@@ -130,7 +134,7 @@ add a number of items into the list
>> inputItemCount - number of input items
>> inputItemCount - number of input items
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
TensorListBase
<
T
>::
Add
(
T
*
inputItems
,
int
inputItemCount
)
void
TensorListBase
<
T
>::
Add
(
const
T
*
inputItems
,
int
inputItemCount
)
{
{
if
(
count
+
inputItemCount
>=
maxNum
)
{
if
(
count
+
inputItemCount
>=
maxNum
)
{
int
newMaxNum
=
(
count
+
inputItemCount
)
*
2
+
1
;
int
newMaxNum
=
(
count
+
inputItemCount
)
*
2
+
1
;
...
@@ -206,10 +210,10 @@ void TensorListBase<T>::Insert(int pos, T&& item)
...
@@ -206,10 +210,10 @@ void TensorListBase<T>::Insert(int pos, T&& item)
template
<
typename
T
>
template
<
typename
T
>
T
&
TensorListBase
<
T
>::
GetItem
(
int
i
)
const
T
&
TensorListBase
<
T
>::
GetItem
(
int
i
)
const
{
{
CheckNTErrors
(
i
>=
-
1
&&
i
<
count
,
"Index of a list item is out of scope!"
);
CheckNTErrors
(
i
>=
-
count
&&
i
<
count
,
"Index of a list item is out of scope!"
);
CheckNTErrors
(
count
>
0
,
"Cannt index the item in an empty list!"
);
CheckNTErrors
(
count
>
0
,
"Cannt index the item in an empty list!"
);
if
(
i
==
-
1
)
if
(
i
<
0
)
return
items
[
count
-
1
];
return
items
[
count
+
i
];
else
else
return
items
[
i
];
return
items
[
i
];
}
}
...
@@ -226,7 +230,7 @@ template<typename T>
...
@@ -226,7 +230,7 @@ template<typename T>
inline
void
TensorListBase
<
T
>::
SetItem
(
int
i
,
T
&&
item
)
inline
void
TensorListBase
<
T
>::
SetItem
(
int
i
,
T
&&
item
)
{
{
if
(
i
>=
0
&&
i
<
count
)
if
(
i
>=
0
&&
i
<
count
)
items
[
i
]
=
std
::
move
(
item
)
;
items
[
i
]
=
item
;
}
}
/*
/*
...
@@ -245,6 +249,26 @@ inline int TensorListBase<T>::FindFirst(const T& item)
...
@@ -245,6 +249,26 @@ inline int TensorListBase<T>::FindFirst(const T& item)
return
-
1
;
return
-
1
;
}
}
template
<>
inline
int
TensorListBase
<
Example
>::
FindFirst
(
const
Example
&
item
)
{
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
if
(
item
.
id
==
items
[
i
].
id
)
return
i
;
}
return
-
1
;
}
template
<>
inline
int
TensorListBase
<
Result
>::
FindFirst
(
const
Result
&
item
)
{
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
if
(
item
.
id
==
items
[
i
].
id
)
return
i
;
}
return
-
1
;
}
/* clear the data array */
/* clear the data array */
template
<
typename
T
>
template
<
typename
T
>
void
TensorListBase
<
T
>::
Clear
()
void
TensorListBase
<
T
>::
Clear
()
...
@@ -294,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
...
@@ -294,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
count
--
;
count
--
;
}
}
template
<
typename
T
>
void
TensorListBase
<
T
>::
Reserve
(
int
n
)
{
if
(
items
)
{
/* reserve failed */
return
;
}
items
=
new
T
[
n
];
}
/*
/*
copy the list
copy the list
>> myMem - memory pool used for allocating the data in the new list
>> myMem - memory pool used for allocating the data in the new list
...
@@ -348,6 +383,8 @@ template struct TensorListBase<long>;
...
@@ -348,6 +383,8 @@ template struct TensorListBase<long>;
template
struct
TensorListBase
<
float
>
;
template
struct
TensorListBase
<
float
>
;
template
struct
TensorListBase
<
short
>
;
template
struct
TensorListBase
<
short
>
;
template
struct
TensorListBase
<
XTensor
*>
;
template
struct
TensorListBase
<
XTensor
*>
;
template
struct
TensorListBase
<
Result
>
;
template
struct
TensorListBase
<
Example
>
;
template
struct
TensorListBase
<
void
*>
;
template
struct
TensorListBase
<
void
*>
;
}
/* end of the nts (NiuTrans.Tensor) namespace */
}
/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
source/tensor/XList.h
查看文件 @
d221ef9d
...
@@ -66,11 +66,14 @@ public:
...
@@ -66,11 +66,14 @@ public:
/* add an item into the list */
/* add an item into the list */
void
Add
(
T
&&
item
);
void
Add
(
T
&&
item
);
/* return number of elements */
size_t
Size
();
/* add an item into the list */
/* add an item into the list */
void
Add
(
const
T
&
item
);
void
Add
(
const
T
&
item
);
/* add a number of items into the list */
/* add a number of items into the list */
void
Add
(
T
*
inputItems
,
int
inputItemCount
);
void
Add
(
const
T
*
inputItems
,
int
inputItemCount
);
/* append a list to the current list */
/* append a list to the current list */
void
AddList
(
TensorListBase
*
l
);
void
AddList
(
TensorListBase
*
l
);
...
@@ -105,6 +108,9 @@ public:
...
@@ -105,6 +108,9 @@ public:
/* remove the item at position i */
/* remove the item at position i */
void
Remove
(
int
i
);
void
Remove
(
int
i
);
/* reserve space for data entry */
void
Reserve
(
int
n
);
/* copy the list */
/* copy the list */
TensorListBase
*
Copy
(
XMem
*
myMem
);
TensorListBase
*
Copy
(
XMem
*
myMem
);
...
@@ -112,22 +118,33 @@ public:
...
@@ -112,22 +118,33 @@ public:
void
Shuffle
(
int
nround
=
10
,
int
beg
=
-
1
,
int
len
=
0
);
void
Shuffle
(
int
nround
=
10
,
int
beg
=
-
1
,
int
len
=
0
);
/* short */
/* short */
T
&
operator
[]
(
int
i
)
{
T
&
operator
[]
(
int
i
)
{
return
GetItem
(
i
);
};
return
GetItem
(
i
);
};
T
&
Get
(
int
i
)
{
return
GetItem
(
i
);
};
T
&
Get
(
int
i
)
{
return
GetItem
(
i
);
};
void
Set
(
int
i
,
T
item
)
{
SetItem
(
i
,
item
);
};
void
Set
(
int
i
,
T
item
)
{
SetItem
(
i
,
item
);
};
};
};
struct
XTensor
;
struct
XTensor
;
typedef
TensorListBase
<
void
*>
XList
;
typedef
TensorListBase
<
int
>
IntList
;
typedef
TensorListBase
<
int
>
IntList
;
typedef
TensorListBase
<
char
>
CharList
;
typedef
TensorListBase
<
char
>
CharList
;
typedef
TensorListBase
<
char
*>
StrList
;
typedef
TensorListBase
<
char
*>
StrList
;
typedef
TensorListBase
<
long
>
LongList
;
typedef
TensorListBase
<
long
>
LongList
;
typedef
TensorListBase
<
float
>
FloatList
;
typedef
TensorListBase
<
float
>
FloatList
;
typedef
TensorListBase
<
short
>
ShortList
;
typedef
TensorListBase
<
short
>
ShortList
;
typedef
TensorListBase
<
void
*>
XList
;
struct
Example
{
int
id
;
IntList
data
;
};
struct
Result
{
int
id
;
IntList
data
;
};
typedef
TensorListBase
<
Result
>
ResultList
;
typedef
TensorListBase
<
Example
>
ExampleList
;
typedef
TensorListBase
<
XTensor
*>
TensorList
;
typedef
TensorListBase
<
XTensor
*>
TensorList
;
}
/* end of the nts (NiuTrans.Tensor) namespace */
}
/* end of the nts (NiuTrans.Tensor) namespace */
...
...
source/tensor/XTensor.cpp
查看文件 @
d221ef9d
...
@@ -1916,6 +1916,26 @@ void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, cons
...
@@ -1916,6 +1916,26 @@ void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, cons
}
}
/*
/*
dump data to a binary file
>> file - where to dump the data
*/
void
XTensor
::
BinaryDump
(
FILE
*
file
)
{
XTensor
tmp
;
InitTensorOnCPU
(
&
tmp
,
this
);
_CopyValues
(
this
,
&
tmp
);
switch
(
dataType
)
{
case
X_INT
:
{
fwrite
(
tmp
.
data
,
sizeof
(
int
),
unitNum
,
file
);
}
default
:
{
fwrite
(
tmp
.
data
,
sizeof
(
float
),
unitNum
,
file
);
}
}
}
/*
read data from a file
read data from a file
>> file - where to load the data
>> file - where to load the data
>> label - label of the tensor
>> label - label of the tensor
...
@@ -2027,6 +2047,30 @@ void XTensor::Read(FILE * file, const char * label)
...
@@ -2027,6 +2047,30 @@ void XTensor::Read(FILE * file, const char * label)
delete
[](
char
*
)
dataBuf
;
delete
[](
char
*
)
dataBuf
;
}
}
/*
read data from a binary file
>>> file - the file stream pointer
>>> offset - the distance from the start to this tensor
*/
void
XTensor
::
BinaryRead
(
FILE
*
file
,
size_t
offset
)
{
fseek
(
file
,
offset
,
0
);
switch
(
dataType
)
{
case
X_INT
:
{
int
*
d
=
new
int
[
unitNum
];
fread
(
d
,
sizeof
(
int
),
unitNum
,
file
);
SetData
(
d
,
unitNum
);
delete
[]
d
;
}
default
:
{
float
*
d
=
new
float
[
unitNum
];
fread
(
d
,
sizeof
(
float
),
unitNum
,
file
);
SetData
(
d
,
unitNum
);
delete
[]
d
;
}
}
}
/*
/*
flush the data to the target device
flush the data to the target device
>> targetMem - memory pool on the target device
>> targetMem - memory pool on the target device
...
...
source/tensor/XTensor.h
查看文件 @
d221ef9d
...
@@ -433,9 +433,15 @@ public:
...
@@ -433,9 +433,15 @@ public:
static
static
void
Dump
(
const
XTensor
*
tensor
,
FILE
*
file
,
const
char
*
label
=
NULL
,
const
int
n
=
-
1
,
const
int
beg
=
0
,
const
int
verbose
=
0
);
void
Dump
(
const
XTensor
*
tensor
,
FILE
*
file
,
const
char
*
label
=
NULL
,
const
int
n
=
-
1
,
const
int
beg
=
0
,
const
int
verbose
=
0
);
/* dump data to a binary file */
void
BinaryDump
(
FILE
*
file
);
/* read data from a file */
/* read data from a file */
void
Read
(
FILE
*
file
,
const
char
*
label
=
NULL
);
void
Read
(
FILE
*
file
,
const
char
*
label
=
NULL
);
/* read data from a binary file */
void
BinaryRead
(
FILE
*
file
,
size_t
offset
);
/* flush the data to the target device */
/* flush the data to the target device */
void
FlushToMem
(
XMem
*
targetMem
);
void
FlushToMem
(
XMem
*
targetMem
);
...
...
source/tensor/core/arithmetic/Div.cpp
查看文件 @
d221ef9d
...
@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
...
@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
_Div
(
&
a
,
&
b
,
&
c
,
alpha
,
leadingDim
);
_Div
(
&
a
,
&
b
,
&
c
,
alpha
,
leadingDim
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIV
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIV
);
XLink
::
AddParamToHeadInt
(
&
c
,
leadingDim
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadInt
(
&
c
,
leadingDim
);
}
}
}
else
if
(
n
>=
0
&&
n
<
a
.
order
){
else
if
(
n
>=
0
&&
n
<
a
.
order
){
/* call _DivDim function */
/* call _DivDim function */
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
}
}
else
{
else
{
ShowNTErrors
(
"Something is wrong!"
);
ShowNTErrors
(
"Something is wrong!"
);
...
@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
...
@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _Div function */
/* call _Div function */
_Div
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
_Div
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIV
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIV
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
...
@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
...
@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _DivDim function */
/* call _DivDim function */
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
...
source/tensor/core/arithmetic/DivDim.cpp
查看文件 @
d221ef9d
...
@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
...
@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
return
c
;
return
c
;
}
}
...
@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
...
@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
/* call _Div function */
/* call _Div function */
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_DivDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
if
(
c
.
enableGrad
==
true
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_DIVDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
...
source/tensor/core/arithmetic/Mask.cpp
查看文件 @
d221ef9d
...
@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
...
@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
_Mask
(
&
a
,
&
mask
,
&
c
,
alpha
);
_Mask
(
&
a
,
&
mask
,
&
c
,
alpha
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
mask
,
&
c
,
MATH_MASK
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
MakeLink
(
&
a
,
&
mask
,
&
c
,
MATH_MASK
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
return
c
;
return
c
;
}
}
...
@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
...
@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
/* call _Mask function */
/* call _Mask function */
_Mask
(
&
a
,
&
mask
,
&
c
,
alpha
);
_Mask
(
&
a
,
&
mask
,
&
c
,
alpha
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
)
{
XLink
::
MakeLink
(
&
a
,
&
mask
,
&
c
,
MATH_MASK
);
XLink
::
MakeLink
(
&
a
,
&
mask
,
&
c
,
MATH_MASK
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
}
...
...
source/tensor/core/arithmetic/MatrixMul.cpp
查看文件 @
d221ef9d
...
@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
...
@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
...
@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/* call _MatrixMul function */
/* call _MatrixMul function */
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
beta
,
parallelRunner
);
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
beta
,
parallelRunner
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
...
@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
...
@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
_MatrixMul
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
_MatrixMul
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
...
@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
/* call _MatrixMul function */
/* call _MatrixMul function */
_MatrixMul
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
_MatrixMul
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
...
...
source/tensor/core/arithmetic/MatrixMulBatched.cpp
查看文件 @
d221ef9d
...
@@ -314,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
...
@@ -314,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
_MatrixMulBatched
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
_MatrixMulBatched
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -370,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
...
@@ -370,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
_MatrixMulBatched
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
_MatrixMulBatched
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/arithmetic/MulAndShift.cpp
查看文件 @
d221ef9d
...
@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
...
@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
}
}
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
x
,
&
w
,
&
b
,
&
c
,
MATH_MULANDSHIFT
);
if
(
w
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
x
,
&
w
,
&
b
,
&
c
,
MATH_MULANDSHIFT
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
//XLink::AddParamToHead(&c, beta);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
}
/* destroy variables */
delete
[]
dimSize
;
DelTensorBuf
(
tmp
);
return
c
;
}
/*
operation c = x * w + b MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor
MulAndShift
(
const
XTensor
&
x
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
w
,
MATRIX_TRANS_TYPE
transposedB
,
const
XTensor
&
b
,
DTYPE
alpha
,
XPRunner
*
parallelRunner
)
{
CheckNTErrors
(
x
.
dataType
==
w
.
dataType
,
"Input tensors should have the same data type!"
);
CheckNTErrors
(
x
.
order
>=
2
&&
w
.
order
>=
2
,
"Input tensors must have a order >= 2!"
);
int
xn
=
transposedA
==
X_TRANS
?
x
.
dimSizeRDI
[
0
]
:
x
.
dimSizeRDI
[
1
];
int
xm
=
transposedA
==
X_TRANS
?
x
.
dimSizeRDI
[
1
]
:
x
.
dimSizeRDI
[
0
];
int
wn
=
transposedB
==
X_TRANS
?
w
.
dimSizeRDI
[
0
]
:
w
.
dimSizeRDI
[
1
];
int
wm
=
transposedB
==
X_TRANS
?
w
.
dimSizeRDI
[
1
]
:
w
.
dimSizeRDI
[
0
];
int
order
=
x
.
order
+
w
.
order
-
2
;
int
sub
=
0
;
int
*
dimSize
=
new
int
[
order
];
for
(
int
i
=
2
;
i
<
x
.
order
;
i
++
)
dimSize
[
sub
++
]
=
x
.
dimSizeRDI
[
x
.
order
+
1
-
i
];
for
(
int
i
=
2
;
i
<
w
.
order
;
i
++
)
dimSize
[
sub
++
]
=
w
.
dimSizeRDI
[
w
.
order
+
1
-
i
];
dimSize
[
sub
++
]
=
xn
;
dimSize
[
sub
++
]
=
wm
;
float
dr
=
(
!
x
.
isSparse
||
!
w
.
isSparse
)
?
1.0
F
:
MAX
(
x
.
denseRatio
,
w
.
denseRatio
);
XTensor
*
tmp
=
NewTensorBuf
(
order
,
dimSize
,
x
.
dataType
,
dr
,
x
.
devID
,
x
.
mem
);
/* call _MatrixMul function */
_MatrixMul
(
&
x
,
transposedA
,
&
w
,
transposedB
,
tmp
,
alpha
,
0
,
parallelRunner
);
XTensor
c
(
tmp
);
c
.
SetTMPFlag
();
int
n
=
GetSumIndex
(
tmp
,
b
);
if
(
n
==
-
1
)
{
/* call _Sum function */
_Sum
(
tmp
,
&
b
,
&
c
);
// TODO!!
ShowNTErrors
(
"TODO!"
);
}
else
if
(
n
>=
0
&&
n
<
tmp
->
order
)
{
/* call _SumDim function */
_SumDim
(
tmp
,
&
b
,
&
c
,
n
);
}
else
{
ShowNTErrors
(
"Something is wrong!"
);
}
/* tensor connections */
if
(
w
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
MakeLink
(
&
x
,
&
w
,
&
b
,
&
c
,
MATH_MULANDSHIFT
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/arithmetic/MulAndShift.h
查看文件 @
d221ef9d
...
@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor
MulAndShift
(
const
XTensor
&
x
,
const
XTensor
&
w
,
const
XTensor
&
b
,
XTensor
MulAndShift
(
const
XTensor
&
x
,
const
XTensor
&
w
,
const
XTensor
&
b
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
XTensor
MulAndShift
(
const
XTensor
&
x
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
w
,
MATRIX_TRANS_TYPE
transposedB
,
const
XTensor
&
b
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/arithmetic/Multiply.cpp
查看文件 @
d221ef9d
...
@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
...
@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
_Multiply
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
_Multiply
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLY
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLY
);
XLink
::
AddParamToHeadInt
(
&
c
,
leadingDim
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadInt
(
&
c
,
leadingDim
);
}
}
}
else
if
(
n
>=
0
&&
n
<
a
.
order
){
else
if
(
n
>=
0
&&
n
<
a
.
order
){
/* call _MultiplyDim function */
/* call _MultiplyDim function */
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
}
}
}
else
{
else
{
ShowNTErrors
(
"Something is wrong!"
);
ShowNTErrors
(
"Something is wrong!"
);
...
@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
...
@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _Multiply function */
/* call _Multiply function */
_Multiply
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
_Multiply
(
&
a
,
&
b
,
&
c
,
0
,
leadingDim
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLY
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLY
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
...
@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
...
@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _MultiplyDim function */
/* call _MultiplyDim function */
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
alpha
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
...
source/tensor/core/arithmetic/MultiplyDim.cpp
查看文件 @
d221ef9d
...
@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
...
@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
0
);
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
0
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
AddParamToHead
(
&
c
,
0
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
0
);
}
return
c
;
return
c
;
}
}
...
@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n)
...
@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n)
/* call _Multiply function */
/* call _Multiply function */
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
0
);
_MultiplyDim
(
&
a
,
&
b
,
&
c
,
n
,
0
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
...
@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
_MultiplyBroadcast
(
&
a
,
&
b
,
&
c
,
0
);
_MultiplyBroadcast
(
&
a
,
&
b
,
&
c
,
0
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYBROADCAST
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
0
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYBROADCAST
);
XLink
::
AddParamToHead
(
&
c
,
0
);
}
return
c
;
return
c
;
}
}
...
@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c)
...
@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c)
/* call _SumBroadcast function */
/* call _SumBroadcast function */
_MultiplyBroadcast
(
&
a
,
&
b
,
&
c
,
0
);
_MultiplyBroadcast
(
&
a
,
&
b
,
&
c
,
0
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYBROADCAST
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MULTIPLYBROADCAST
);
XLink
::
AddParamToHead
(
&
c
,
0
);
XLink
::
AddParamToHead
(
&
c
,
0
);
...
...
source/tensor/core/arithmetic/Sub.cpp
查看文件 @
d221ef9d
...
@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
...
@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
_Sub
(
&
a
,
&
b
,
&
c
,
beta
);
_Sub
(
&
a
,
&
b
,
&
c
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUB
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUB
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
}
}
else
if
(
n
>=
0
&&
n
<
a
.
order
){
else
if
(
n
>=
0
&&
n
<
a
.
order
){
/* call _SubDim function */
/* call _SubDim function */
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
}
}
else
{
else
{
ShowNTErrors
(
"Something is wrong!"
);
ShowNTErrors
(
"Something is wrong!"
);
...
@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
...
@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _Sub function */
/* call _Sub function */
_Sub
(
&
a
,
&
b
,
&
c
,
beta
);
_Sub
(
&
a
,
&
b
,
&
c
,
beta
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUB
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUB
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
...
@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
...
@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SubDim function */
/* call _SubDim function */
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
...
source/tensor/core/arithmetic/SubDim.cpp
查看文件 @
d221ef9d
...
@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
...
@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
return
c
;
return
c
;
}
}
...
@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
...
@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _Sub function */
/* call _Sub function */
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SubDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUBDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
...
source/tensor/core/arithmetic/Sum.cpp
查看文件 @
d221ef9d
...
@@ -224,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
...
@@ -224,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
}
}
else
if
(
n
>=
0
&&
n
<
a
.
order
){
else
if
(
n
>=
0
&&
n
<
a
.
order
){
/* call _SumDim function */
/* call _SumDim function */
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
}
}
else
{
else
{
ShowNTErrors
(
"Something is wrong!"
);
ShowNTErrors
(
"Something is wrong!"
);
...
@@ -261,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
...
@@ -261,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
if
(
n
==
-
1
)
{
if
(
n
==
-
1
)
{
/* call _Sum function */
/* call _Sum function */
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
if
(
c
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
}
...
@@ -271,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
...
@@ -271,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
else
if
(
n
>=
0
&&
n
<
a
.
order
)
{
else
if
(
n
>=
0
&&
n
<
a
.
order
)
{
/* call _SumDim function */
/* call _SumDim function */
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
if
(
c
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
...
...
source/tensor/core/arithmetic/SumDim.cpp
查看文件 @
d221ef9d
...
@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
...
@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
return
c
;
return
c
;
}
}
...
@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
...
@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _SumDim function */
/* call _SumDim function */
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
_SumDim
(
&
a
,
&
b
,
&
c
,
n
,
beta
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
...
@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
...
@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
_SumBroadcast
(
&
a
,
&
b
,
&
c
,
beta
);
_SumBroadcast
(
&
a
,
&
b
,
&
c
,
beta
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMBROADCAST
);
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMBROADCAST
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
}
return
c
;
return
c
;
}
}
...
@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
...
@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SumBroadcast function */
/* call _SumBroadcast function */
_SumBroadcast
(
&
a
,
&
b
,
&
c
,
beta
);
_SumBroadcast
(
&
a
,
&
b
,
&
c
,
beta
);
if
(
c
.
enableGrad
)
{
if
(
a
.
enableGrad
&&
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMBROADCAST
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMBROADCAST
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
...
...
source/tensor/core/getandset/ConvertDataType.cpp
查看文件 @
d221ef9d
/* NiuTrans.Tensor - an open-source tensor library
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
* All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing, software
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
/*
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
*/
#include "../../XTensor.h"
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XName.h"
...
@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
...
@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
_ConvertDataType
(
&
input
,
&
output
);
_ConvertDataType
(
&
input
,
&
output
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
GETANDSET_CONVERTDATATYPE
);
if
(
input
.
enableGrad
)
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
GETANDSET_CONVERTDATATYPE
);
return
output
;
return
output
;
}
}
...
@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE d
...
@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE d
_ConvertDataType
(
&
input
,
&
output
);
_ConvertDataType
(
&
input
,
&
output
);
/* tensor connection */
/* tensor connection */
if
(
out
put
.
enableGrad
)
if
(
in
put
.
enableGrad
)
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
GETANDSET_CONVERTDATATYPE
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
GETANDSET_CONVERTDATATYPE
);
}
}
...
...
source/tensor/core/getandset/Select.cpp
查看文件 @
d221ef9d
...
@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
...
@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
_SelectRange
(
&
a
,
&
c
,
dim
,
low
,
high
);
_SelectRange
(
&
a
,
&
c
,
dim
,
low
,
high
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
c
,
GETANDSET_SELECT
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
c
,
dim
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
c
,
GETANDSET_SELECT
);
XLink
::
AddParamToHeadInt
(
&
c
,
low
);
XLink
::
AddParamToHeadInt
(
&
c
,
dim
);
XLink
::
AddParamToHeadInt
(
&
c
,
high
);
XLink
::
AddParamToHeadInt
(
&
c
,
low
);
XLink
::
AddParamToHeadInt
(
&
c
,
high
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/math/Binary.cpp
查看文件 @
d221ef9d
...
@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num)
...
@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num)
XTensor b(&a); \
XTensor b(&a); \
b.SetTMPFlag(); \
b.SetTMPFlag(); \
_funcName(&a, &b, num); \
_funcName(&a, &b, num); \
XLink::MakeLink(&a, NULL, &b, operationId); \
if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
XLink::AddParamToHead(&b, num); \
XLink::AddParamToHead(&b, num); \
return b; \
return b; \
} \
} \
...
@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
...
@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
InitTensor(&b, &a); \
InitTensor(&b, &a); \
} \
} \
_funcName(&a, &b, num); \
_funcName(&a, &b, num); \
if (
b
.enableGrad) { \
if (
a
.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, num); \
XLink::AddParamToHead(&b, num); \
} \
} \
...
...
source/tensor/core/math/Clip.cpp
查看文件 @
d221ef9d
...
@@ -67,7 +67,7 @@ keep the result in the input tensor a and return nothing
...
@@ -67,7 +67,7 @@ keep the result in the input tensor a and return nothing
*/
*/
void
_ClipMe
(
XTensor
*
a
,
DTYPE
lower
,
DTYPE
upper
)
void
_ClipMe
(
XTensor
*
a
,
DTYPE
lower
,
DTYPE
upper
)
{
{
_Clip
(
a
,
a
,
lower
,
upper
);
_Clip
(
a
,
a
,
lower
,
upper
);
}
}
/*
/*
...
@@ -92,18 +92,20 @@ make a new tensor to keep the result and return it
...
@@ -92,18 +92,20 @@ make a new tensor to keep the result and return it
*/
*/
XTensor
Clip
(
const
XTensor
&
a
,
DTYPE
lower
,
DTYPE
upper
)
XTensor
Clip
(
const
XTensor
&
a
,
DTYPE
lower
,
DTYPE
upper
)
{
{
XTensor
b
(
&
a
);
XTensor
b
(
&
a
);
b
.
SetTMPFlag
();
b
.
SetTMPFlag
();
/* call _Clip function */
/* call _Clip function */
_Clip
(
&
a
,
&
b
,
lower
,
upper
);
_Clip
(
&
a
,
&
b
,
lower
,
upper
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_CLIP
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
b
,
lower
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_CLIP
);
XLink
::
AddParamToHead
(
&
b
,
upper
);
XLink
::
AddParamToHead
(
&
b
,
lower
);
XLink
::
AddParamToHead
(
&
b
,
upper
);
}
return
b
;
return
b
;
}
}
void
Clip
(
const
XTensor
&
a
,
XTensor
&
b
,
DTYPE
lower
,
DTYPE
upper
)
void
Clip
(
const
XTensor
&
a
,
XTensor
&
b
,
DTYPE
lower
,
DTYPE
upper
)
...
@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
...
@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
/* call _Clip function */
/* call _Clip function */
_Clip
(
&
a
,
&
b
,
lower
,
upper
);
_Clip
(
&
a
,
&
b
,
lower
,
upper
);
if
(
b
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
if
(
a
.
enableGrad
)
{
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_CLIP
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_CLIP
);
XLink
::
AddParamToHead
(
&
b
,
lower
);
XLink
::
AddParamToHead
(
&
b
,
lower
);
XLink
::
AddParamToHead
(
&
b
,
upper
);
XLink
::
AddParamToHead
(
&
b
,
upper
);
...
...
source/tensor/core/math/Normalize.cpp
查看文件 @
d221ef9d
...
@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim,
...
@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim,
list
.
Add
((
XTensor
*
)
&
var
);
list
.
Add
((
XTensor
*
)
&
var
);
list
.
Add
((
XTensor
*
)
&
a
);
list
.
Add
((
XTensor
*
)
&
a
);
list
.
Add
((
XTensor
*
)
&
b
);
list
.
Add
((
XTensor
*
)
&
b
);
XLink
::
MakeLink
(
&
list
,
&
output
,
MATH_NORMALIZE
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
list
,
&
output
,
MATH_NORMALIZE
);
XLink
::
AddParamToHead
(
&
output
,
epsilon
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHead
(
&
output
,
epsilon
);
}
return
output
;
return
output
;
}
}
...
@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim,
...
@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim,
/* call _Normalize function */
/* call _Normalize function */
_Normalize
(
&
input
,
&
output
,
dim
,
&
mean
,
&
var
,
&
a
,
&
b
,
epsilon
);
_Normalize
(
&
input
,
&
output
,
dim
,
&
mean
,
&
var
,
&
a
,
&
b
,
epsilon
);
if
(
out
put
.
enableGrad
==
true
)
{
if
(
in
put
.
enableGrad
==
true
)
{
/* tensor connections */
/* tensor connections */
TensorList
list
(
5
);
TensorList
list
(
5
);
list
.
Add
((
XTensor
*
)
&
input
);
list
.
Add
((
XTensor
*
)
&
input
);
...
...
source/tensor/core/math/ScaleAndShift.cpp
查看文件 @
d221ef9d
...
@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
...
@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
_ScaleAndShift
(
&
a
,
&
b
,
scale
,
shift
);
_ScaleAndShift
(
&
a
,
&
b
,
scale
,
shift
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_SCALEANDSHIFT
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
b
,
scale
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_SCALEANDSHIFT
);
XLink
::
AddParamToHead
(
&
b
,
shift
);
XLink
::
AddParamToHead
(
&
b
,
scale
);
XLink
::
AddParamToHead
(
&
b
,
shift
);
}
return
b
;
return
b
;
}
}
...
@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift)
...
@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift)
/* call _ScaleAndShift function */
/* call _ScaleAndShift function */
_ScaleAndShift
(
&
a
,
&
b
,
scale
,
shift
);
_ScaleAndShift
(
&
a
,
&
b
,
scale
,
shift
);
if
(
b
.
enableGrad
)
{
if
(
a
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_SCALEANDSHIFT
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
MATH_SCALEANDSHIFT
);
XLink
::
AddParamToHead
(
&
b
,
scale
);
XLink
::
AddParamToHead
(
&
b
,
scale
);
...
...
source/tensor/core/math/Unary.cpp
查看文件 @
d221ef9d
...
@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a)
...
@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a)
XTensor b(&a); \
XTensor b(&a); \
b.SetTMPFlag(); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
return b; \
return b; \
}
}
...
@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b)
...
@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b)
InitTensor(&b, &a); \
InitTensor(&b, &a); \
} \
} \
_funcName(&a, &b); \
_funcName(&a, &b); \
if (
b
.enableGrad) { \
if (
a
.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
} \
}
}
...
...
source/tensor/core/movement/CopyIndexed.cpp
查看文件 @
d221ef9d
...
@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim,
...
@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim,
list
.
Add
((
XTensor
*
)
&
tgtIndex
);
list
.
Add
((
XTensor
*
)
&
tgtIndex
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
list
,
&
t
,
MOVEMENT_COPYINDEXED
);
if
(
s
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
t
,
dim
);
XLink
::
MakeLink
(
&
list
,
&
t
,
MOVEMENT_COPYINDEXED
);
XLink
::
AddParamToHeadInt
(
&
t
,
copyNum
);
XLink
::
AddParamToHeadInt
(
&
t
,
dim
);
XLink
::
AddParamToHeadInt
(
&
t
,
copyNum
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
...
@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
memcpy
(
saveTgtIndex
,
tgtIndex
,
indexSize
*
sizeof
(
int
));
memcpy
(
saveTgtIndex
,
tgtIndex
,
indexSize
*
sizeof
(
int
));
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
MOVEMENT_COPYINDEXED
);
if
(
s
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
t
,
dim
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
MOVEMENT_COPYINDEXED
);
XLink
::
AddParamToHeadPointer
(
&
t
,
saveSrcIndex
);
XLink
::
AddParamToHeadInt
(
&
t
,
dim
);
XLink
::
AddParamToHeadInt
(
&
t
,
indexSize
);
XLink
::
AddParamToHeadPointer
(
&
t
,
saveSrcIndex
);
XLink
::
AddParamToHeadPointer
(
&
t
,
saveTgtIndex
);
XLink
::
AddParamToHeadInt
(
&
t
,
indexSize
);
XLink
::
AddParamToHeadInt
(
&
t
,
copyNum
);
XLink
::
AddParamToHeadPointer
(
&
t
,
saveTgtIndex
);
XLink
::
AddParamToHeadInt
(
&
t
,
copyNum
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/movement/CopyValues.cpp
查看文件 @
d221ef9d
...
@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream)
...
@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream)
_CopyValues
(
&
s
,
&
t
,
stream
);
_CopyValues
(
&
s
,
&
t
,
stream
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
MOVEMENT_COPYVALUES
);
if
(
s
.
enableGrad
)
{
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
MOVEMENT_COPYVALUES
);
}
return
t
;
return
t
;
}
}
...
...
source/tensor/core/movement/Gather.cpp
查看文件 @
d221ef9d
...
@@ -93,9 +93,11 @@ XTensor Gather(XTensor &s, XTensor &index)
...
@@ -93,9 +93,11 @@ XTensor Gather(XTensor &s, XTensor &index)
_Gather
(
&
s
,
&
t
,
&
index
);
_Gather
(
&
s
,
&
t
,
&
index
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
s
,
&
index
,
&
t
,
MOVEMENT_GATHER
);
if
(
s
.
enableGrad
)
{
XLink
::
MakeLink
(
&
s
,
&
index
,
&
t
,
MOVEMENT_GATHER
);
}
return
t
;
return
t
;
}
}
}
// namespace nts(NiuTrans.Tensor)
}
//
namespace
nts
(
NiuTrans
.
Tensor
)
\ No newline at end of file
source/tensor/core/reduce/ReduceMax.cpp
查看文件 @
d221ef9d
...
@@ -181,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim)
...
@@ -181,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim)
_ReduceMax
(
&
input
,
&
output
,
dim
);
_ReduceMax
(
&
input
,
&
output
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMAX
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMAX
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -221,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim)
...
@@ -221,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMax function */
/* call _ReduceMax function */
_ReduceMax
(
&
input
,
&
output
,
dim
);
_ReduceMax
(
&
input
,
&
output
,
dim
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMAX
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMAX
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
...
source/tensor/core/reduce/ReduceMean.cpp
查看文件 @
d221ef9d
...
@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim)
...
@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim)
_ReduceMean
(
&
input
,
&
output
,
dim
);
_ReduceMean
(
&
input
,
&
output
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMEAN
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMEAN
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim)
...
@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMean function */
/* call _ReduceMean function */
_ReduceMean
(
&
input
,
&
output
,
dim
);
_ReduceMean
(
&
input
,
&
output
,
dim
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMEAN
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCEMEAN
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
...
source/tensor/core/reduce/ReduceSum.cpp
查看文件 @
d221ef9d
...
@@ -306,10 +306,12 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
...
@@ -306,10 +306,12 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
_ReduceSum
(
&
input
,
&
output
,
dim
,
&
shift
,
power
,
isExp
);
_ReduceSum
(
&
input
,
&
output
,
dim
,
&
shift
,
power
,
isExp
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUM
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
AddParamToHead
(
&
output
,
power
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadBool
(
&
output
,
isExp
);
XLink
::
AddParamToHead
(
&
output
,
power
);
XLink
::
AddParamToHeadBool
(
&
output
,
isExp
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -341,7 +343,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &sh
...
@@ -341,7 +343,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &sh
/* call _ReduceSum function */
/* call _ReduceSum function */
_ReduceSum
(
&
input
,
&
output
,
dim
,
&
shift
,
power
,
isExp
);
_ReduceSum
(
&
input
,
&
output
,
dim
,
&
shift
,
power
,
isExp
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
@@ -385,10 +387,12 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
...
@@ -385,10 +387,12 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
_ReduceSum
(
&
input
,
&
output
,
dim
,
NULL
,
power
,
isExp
);
_ReduceSum
(
&
input
,
&
output
,
dim
,
NULL
,
power
,
isExp
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCESUM
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
AddParamToHead
(
&
output
,
power
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadBool
(
&
output
,
isExp
);
XLink
::
AddParamToHead
(
&
output
,
power
);
XLink
::
AddParamToHeadBool
(
&
output
,
isExp
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -434,7 +438,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power, bool
...
@@ -434,7 +438,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power, bool
/* call _ReduceSum function */
/* call _ReduceSum function */
_ReduceSum
(
&
input
,
&
output
,
dim
,
NULL
,
power
,
isExp
);
_ReduceSum
(
&
input
,
&
output
,
dim
,
NULL
,
power
,
isExp
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
MakeLink
(
&
input
,
NULL
,
&
output
,
REDUCE_REDUCESUM
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
...
source/tensor/core/reduce/ReduceSumSquared.cpp
查看文件 @
d221ef9d
...
@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
...
@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
_ReduceSumSquared
(
&
input
,
&
output
,
dim
,
&
shift
);
_ReduceSumSquared
(
&
input
,
&
output
,
dim
,
&
shift
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUMSQUARED
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUMSQUARED
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTen
...
@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTen
/* call _ReduceSumSquared function */
/* call _ReduceSumSquared function */
_ReduceSumSquared
(
&
input
,
&
output
,
dim
,
&
shift
);
_ReduceSumSquared
(
&
input
,
&
output
,
dim
,
&
shift
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUMSQUARED
);
XLink
::
MakeLink
(
&
input
,
&
shift
,
&
output
,
REDUCE_REDUCESUMSQUARED
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
...
source/tensor/core/reduce/ReduceVariance.cpp
查看文件 @
d221ef9d
...
@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
...
@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
_ReduceVariance
(
&
input
,
&
output
,
dim
,
&
mean
);
_ReduceVariance
(
&
input
,
&
output
,
dim
,
&
mean
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
&
mean
,
&
output
,
REDUCE_REDUCEVARIANCE
);
if
(
input
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
MakeLink
(
&
input
,
&
mean
,
&
output
,
REDUCE_REDUCEVARIANCE
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTenso
...
@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTenso
/* call _ReduceVariance function */
/* call _ReduceVariance function */
_ReduceVariance
(
&
input
,
&
output
,
dim
,
&
mean
);
_ReduceVariance
(
&
input
,
&
output
,
dim
,
&
mean
);
if
(
out
put
.
enableGrad
)
{
if
(
in
put
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
input
,
&
mean
,
&
output
,
REDUCE_REDUCEVARIANCE
);
XLink
::
MakeLink
(
&
input
,
&
mean
,
&
output
,
REDUCE_REDUCEVARIANCE
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
XLink
::
AddParamToHeadInt
(
&
output
,
dim
);
...
...
source/tensor/core/shape/Concatenate.cpp
查看文件 @
d221ef9d
...
@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim)
...
@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_Merge
(
&
smalls
,
&
big
,
dim
);
_Merge
(
&
smalls
,
&
big
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE
);
if
(
tensor
->
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE
);
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim)
...
@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_ConcatenateSolely
(
&
smalls
,
&
big
,
dim
);
_ConcatenateSolely
(
&
smalls
,
&
big
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_CONCATENATE
);
if
(
tensor
->
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_CONCATENATE
);
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
...
@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_Merge
(
&
smalls
,
&
big
,
dim
);
_Merge
(
&
smalls
,
&
big
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE
);
if
(
tensor
->
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE
);
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
...
@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_ConcatenateSolely
(
&
smalls
,
&
big
,
dim
);
_ConcatenateSolely
(
&
smalls
,
&
big
,
dim
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_CONCATENATE
);
if
(
tensor
->
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_CONCATENATE
);
XLink
::
AddParamToHeadInt
(
&
big
,
dim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/shape/Merge.cpp
查看文件 @
d221ef9d
...
@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
...
@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
_Merge
(
&
s
,
&
t
,
whereToMerge
,
leadingDim
);
_Merge
(
&
s
,
&
t
,
whereToMerge
,
leadingDim
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_MERGE
);
if
(
s
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
t
,
whereToMerge
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_MERGE
);
XLink
::
AddParamToHeadInt
(
&
t
,
leadingDim
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToMerge
);
XLink
::
AddParamToHeadInt
(
&
t
,
leadingDim
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
...
@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
/* call _Merge function */
/* call _Merge function */
_Merge
(
&
s
,
&
t
,
whereToMerge
,
leadingDim
);
_Merge
(
&
s
,
&
t
,
whereToMerge
,
leadingDim
);
if
(
t
.
enableGrad
)
{
if
(
s
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_MERGE
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_MERGE
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToMerge
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToMerge
);
...
@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge)
...
@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge)
_Merge
(
&
smalls
,
&
big
,
whereToMerge
);
_Merge
(
&
smalls
,
&
big
,
whereToMerge
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE_LIST
);
if
(
tensor
->
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
whereToMerge
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE_LIST
);
XLink
::
AddParamToHeadInt
(
&
big
,
whereToMerge
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
...
@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
_Merge
(
&
smalls
,
&
big
,
whereToMerge
);
_Merge
(
&
smalls
,
&
big
,
whereToMerge
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE_LIST
);
if
(
smallA
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
big
,
whereToMerge
);
XLink
::
MakeLink
(
&
smalls
,
&
big
,
SHAPE_MERGE_LIST
);
XLink
::
AddParamToHeadInt
(
&
big
,
whereToMerge
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/shape/Reshape.cpp
查看文件 @
d221ef9d
...
@@ -43,7 +43,9 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
...
@@ -43,7 +43,9 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
t
.
Reshape
(
order
,
dimSize
);
t
.
Reshape
(
order
,
dimSize
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_RESHAPE
);
if
(
s
.
enableGrad
)
{
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_RESHAPE
);
}
return
t
;
return
t
;
}
}
...
@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
...
@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
/* call Reshape function */
/* call Reshape function */
t
.
Reshape
(
order
,
dimSize
);
t
.
Reshape
(
order
,
dimSize
);
if
(
t
.
enableGrad
)
{
if
(
s
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_RESHAPE
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_RESHAPE
);
}
}
...
...
source/tensor/core/shape/Split.cpp
查看文件 @
d221ef9d
...
@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
...
@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
_Split
(
&
s
,
&
t
,
whereToSplit
,
splitNum
);
_Split
(
&
s
,
&
t
,
whereToSplit
,
splitNum
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_SPLIT
);
if
(
s
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
t
,
whereToSplit
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_SPLIT
);
XLink
::
AddParamToHeadInt
(
&
t
,
splitNum
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToSplit
);
XLink
::
AddParamToHeadInt
(
&
t
,
splitNum
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum)
...
@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum)
/* call _Split function */
/* call _Split function */
_Split
(
&
s
,
&
t
,
whereToSplit
,
splitNum
);
_Split
(
&
s
,
&
t
,
whereToSplit
,
splitNum
);
if
(
t
.
enableGrad
)
{
if
(
s
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_SPLIT
);
XLink
::
MakeLink
(
&
s
,
NULL
,
&
t
,
SHAPE_SPLIT
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToSplit
);
XLink
::
AddParamToHeadInt
(
&
t
,
whereToSplit
);
...
@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNu
...
@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNu
/* tensor connections */
/* tensor connections */
for
(
int
i
=
0
;
i
<
smalls
.
count
;
i
++
){
for
(
int
i
=
0
;
i
<
smalls
.
count
;
i
++
){
XTensor
*
s
=
(
XTensor
*
)
smalls
.
Get
(
i
);
XTensor
*
s
=
(
XTensor
*
)
smalls
.
Get
(
i
);
XLink
::
MakeLink
(
&
big
,
NULL
,
s
,
SHAPE_SPLIT_LIST
);
XLink
::
AddParamToHeadInt
(
s
,
whereToSplit
);
/* it is tricky here that we keep the id of each
if
(
s
->
enableGrad
)
{
block, rather than the total number of the splits */
XLink
::
MakeLink
(
&
big
,
NULL
,
s
,
SHAPE_SPLIT_LIST
);
XLink
::
AddParamToHeadInt
(
s
,
i
);
XLink
::
AddParamToHeadInt
(
s
,
whereToSplit
);
/* it is tricky here that we keep the id of each
block, rather than the total number of the splits */
XLink
::
AddParamToHeadInt
(
s
,
i
);
}
}
}
}
}
...
...
source/tensor/core/shape/Squeeze.cpp
查看文件 @
d221ef9d
...
@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim)
...
@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim)
_Squeeze
(
&
source
,
&
target
,
leadingDim
);
_Squeeze
(
&
source
,
&
target
,
leadingDim
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
source
,
NULL
,
&
target
,
SHAPE_SQUEEZE
);
if
(
source
.
enableGrad
)
{
XLink
::
MakeLink
(
&
source
,
NULL
,
&
target
,
SHAPE_SQUEEZE
);
}
return
target
;
return
target
;
}
}
...
@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim)
...
@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim)
/* call _Squeeze function */
/* call _Squeeze function */
_Squeeze
(
&
source
,
&
target
,
leadingDim
);
_Squeeze
(
&
source
,
&
target
,
leadingDim
);
if
(
target
.
enableGrad
)
{
if
(
source
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
source
,
NULL
,
&
target
,
SHAPE_SQUEEZE
);
XLink
::
MakeLink
(
&
source
,
NULL
,
&
target
,
SHAPE_SQUEEZE
);
}
}
...
...
source/tensor/core/shape/Transpose.cpp
查看文件 @
d221ef9d
...
@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
...
@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
_Transpose
(
&
a
,
&
b
,
i
,
j
);
_Transpose
(
&
a
,
&
b
,
i
,
j
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_TRANSPOSE
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
b
,
i
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_TRANSPOSE
);
XLink
::
AddParamToHeadInt
(
&
b
,
j
);
XLink
::
AddParamToHeadInt
(
&
b
,
i
);
XLink
::
AddParamToHeadInt
(
&
b
,
j
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/shape/Unsqueeze.cpp
查看文件 @
d221ef9d
...
@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
...
@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
_Unsqueeze
(
&
a
,
&
b
,
dim
,
dSize
);
_Unsqueeze
(
&
a
,
&
b
,
dim
,
dSize
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_UNSQUEEZE
);
if
(
a
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
b
,
dim
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_UNSQUEEZE
);
XLink
::
AddParamToHeadInt
(
&
b
,
dSize
);
XLink
::
AddParamToHeadInt
(
&
b
,
dim
);
XLink
::
AddParamToHeadInt
(
&
b
,
dSize
);
}
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
...
@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
/* call _Unsqueeze function */
/* call _Unsqueeze function */
_Unsqueeze
(
&
a
,
&
b
,
dim
,
dSize
);
_Unsqueeze
(
&
a
,
&
b
,
dim
,
dSize
);
if
(
b
.
enableGrad
)
{
if
(
a
.
enableGrad
)
{
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_UNSQUEEZE
);
XLink
::
MakeLink
(
&
a
,
NULL
,
&
b
,
SHAPE_UNSQUEEZE
);
XLink
::
AddParamToHeadInt
(
&
b
,
dim
);
XLink
::
AddParamToHeadInt
(
&
b
,
dim
);
...
...
source/tensor/function/DropoutWithIndex.cpp
查看文件 @
d221ef9d
...
@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
...
@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
_ScaleAndShiftMe
(
&
c
,
scale
);
_ScaleAndShiftMe
(
&
c
,
scale
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
x
,
&
maskIndex
,
&
c
,
MOVEMENT_DROPOUTWITHINDEX
);
if
(
x
.
enableGrad
)
{
XLink
::
AddParamToHead
(
&
c
,
scale
);
XLink
::
MakeLink
(
&
x
,
&
maskIndex
,
&
c
,
MOVEMENT_DROPOUTWITHINDEX
);
XLink
::
AddParamToHead
(
&
c
,
scale
);
}
return
c
;
return
c
;
}
}
...
...
source/tensor/function/HardTanH.cpp
查看文件 @
d221ef9d
...
@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x)
...
@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x)
_HardTanH
(
&
x
,
&
y
);
_HardTanH
(
&
x
,
&
y
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_HARDTANH
);
if
(
x
.
enableGrad
)
{
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_HARDTANH
);
}
return
y
;
return
y
;
}
}
...
@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y)
...
@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y)
/* call _HardTanH function */
/* call _HardTanH function */
_HardTanH
(
&
x
,
&
y
);
_HardTanH
(
&
x
,
&
y
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_HARDTANH
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_HARDTANH
);
}
}
...
...
source/tensor/function/Identity.cpp
查看文件 @
d221ef9d
...
@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x)
...
@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x)
_Identity
(
&
x
,
&
y
);
_Identity
(
&
x
,
&
y
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_IDENTITY
);
if
(
x
.
enableGrad
)
{
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_IDENTITY
);
}
return
y
;
return
y
;
}
}
...
@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y)
...
@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y)
/* call _Identity function */
/* call _Identity function */
_Identity
(
&
x
,
&
y
);
_Identity
(
&
x
,
&
y
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_IDENTITY
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_IDENTITY
);
}
}
...
...
source/tensor/function/LogSoftmax.cpp
查看文件 @
d221ef9d
...
@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
...
@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
_LogSoftmax
(
&
x
,
&
y
,
ld
);
_LogSoftmax
(
&
x
,
&
y
,
ld
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_LOGSOFTMAX
);
if
(
x
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_LOGSOFTMAX
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
}
return
y
;
return
y
;
}
}
...
@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim)
...
@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim)
/* call _LogSoftmax function */
/* call _LogSoftmax function */
_LogSoftmax
(
&
x
,
&
y
,
ld
);
_LogSoftmax
(
&
x
,
&
y
,
ld
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_LOGSOFTMAX
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_LOGSOFTMAX
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
...
...
source/tensor/function/Rectify.cpp
查看文件 @
d221ef9d
...
@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x)
...
@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x)
_Rectify
(
&
x
,
&
y
);
_Rectify
(
&
x
,
&
y
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_RECTIFY
);
if
(
x
.
enableGrad
)
{
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_RECTIFY
);
}
return
y
;
return
y
;
}
}
...
@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y)
...
@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y)
/* call _Rectify function */
/* call _Rectify function */
_Rectify
(
&
x
,
&
y
);
_Rectify
(
&
x
,
&
y
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_RECTIFY
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_RECTIFY
);
}
}
...
...
source/tensor/function/Sigmoid.cpp
查看文件 @
d221ef9d
...
@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x)
...
@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x)
_Sigmoid
(
&
x
,
&
y
);
_Sigmoid
(
&
x
,
&
y
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SIGMOID
);
if
(
x
.
enableGrad
)
{
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SIGMOID
);
}
return
y
;
return
y
;
}
}
...
@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y)
...
@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y)
/* call _Sigmoid function */
/* call _Sigmoid function */
_Sigmoid
(
&
x
,
&
y
);
_Sigmoid
(
&
x
,
&
y
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SIGMOID
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SIGMOID
);
}
}
...
...
source/tensor/function/Softmax.cpp
查看文件 @
d221ef9d
...
@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim)
...
@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim)
_Softmax
(
&
x
,
&
y
,
ld
);
_Softmax
(
&
x
,
&
y
,
ld
);
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SOFTMAX
);
if
(
x
.
enableGrad
)
{
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SOFTMAX
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
}
return
y
;
return
y
;
}
}
...
@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim)
...
@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim)
/* call _Softmax function */
/* call _Softmax function */
_Softmax
(
&
x
,
&
y
,
ld
);
_Softmax
(
&
x
,
&
y
,
ld
);
if
(
y
.
enableGrad
)
{
if
(
x
.
enableGrad
)
{
/* tensor connection */
/* tensor connection */
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SOFTMAX
);
XLink
::
MakeLink
(
&
x
,
NULL
,
&
y
,
FUNC_SOFTMAX
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
XLink
::
AddParamToHeadInt
(
&
y
,
ld
);
...
...
source/tensor/loss/CrossEntropy.cpp
查看文件 @
d221ef9d
...
@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
...
@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails
.
Add
((
XTensor
*
)
&
gold
);
tails
.
Add
((
XTensor
*
)
&
gold
);
tails
.
Add
(
weight
);
tails
.
Add
(
weight
);
tails
.
Add
(
padding
);
tails
.
Add
(
padding
);
XLink
::
MakeLink
(
&
tails
,
&
loss
,
LOSS_CROSSENTROPY
);
XLink
::
AddParamToHeadInt
(
&
loss
,
dim
);
if
(
output
.
enableGrad
)
{
XLink
::
MakeLink
(
&
tails
,
&
loss
,
LOSS_CROSSENTROPY
);
XLink
::
AddParamToHeadInt
(
&
loss
,
dim
);
}
return
loss
;
return
loss
;
}
}
...
@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
...
@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails
.
Add
((
XTensor
*
)
&
gold
);
tails
.
Add
((
XTensor
*
)
&
gold
);
tails
.
Add
(
weight
);
tails
.
Add
(
weight
);
tails
.
Add
((
XTensor
*
)
&
padding
);
tails
.
Add
((
XTensor
*
)
&
padding
);
XLink
::
MakeLink
(
&
tails
,
&
loss
,
LOSS_CROSSENTROPY
);
XLink
::
AddParamToHeadInt
(
&
loss
,
dim
);
if
(
output
.
enableGrad
)
{
XLink
::
MakeLink
(
&
tails
,
&
loss
,
LOSS_CROSSENTROPY
);
XLink
::
AddParamToHeadInt
(
&
loss
,
dim
);
}
return
loss
;
return
loss
;
}
}
...
...
source/tensor/test/TSetData.cpp
查看文件 @
d221ef9d
...
@@ -421,7 +421,7 @@ bool TestSetData6()
...
@@ -421,7 +421,7 @@ bool TestSetData6()
for
(
int
i
=
0
;
i
<
order
;
i
++
)
for
(
int
i
=
0
;
i
<
order
;
i
++
)
unitNum
*=
dimSize
[
i
];
unitNum
*=
dimSize
[
i
];
DTYPE
answer
[
5
]
=
{
5.2
,
3.2
,
1.2
,
-
0.8
,
-
2.8
};
DTYPE
answer
[
5
]
=
{
5.2
F
,
3.2
F
,
1.2
F
,
-
0.8
F
,
-
2.8
F
};
/* CPU test */
/* CPU test */
bool
cpuTest
=
true
;
bool
cpuTest
=
true
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论