Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
4b7f7a18
Commit
4b7f7a18
authored
Jul 18, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
code update for backward propagation
parent
4452db19
显示空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
581 行增加
和
39 行删除
+581
-39
source/network/Main.cpp
+27
-1
source/network/XBackwardMath.cpp
+34
-0
source/network/XBackwardMath.h
+4
-0
source/network/XNet.cpp
+51
-1
source/network/XNet.h
+12
-2
source/tensor/XName.cpp
+14
-0
source/tensor/core/arithmetic/MatrixMul.cpp
+59
-10
source/tensor/core/arithmetic/MatrixMul.h
+10
-3
source/tensor/core/arithmetic/MatrixMul2D.cu
+5
-2
source/tensor/core/arithmetic/MatrixMul2D.h
+0
-0
source/tensor/core/arithmetic/MatrixMulBatched.cpp
+3
-3
source/tensor/core/arithmetic/MatrixMulBatched.h
+3
-3
source/tensor/core/arithmetic/XTensorBLAS.cu
+0
-0
source/tensor/core/getandset/SetData.cpp
+142
-2
source/tensor/core/getandset/SetData.cu
+149
-0
source/tensor/core/getandset/SetData.cuh
+43
-0
source/tensor/core/getandset/SetData.h
+9
-12
source/tensor/function/HardTanH.cpp
+0
-0
source/tensor/function/HardTanH.h
+2
-0
source/tensor/function/Loss.cpp
+14
-0
没有找到文件。
source/network/Main.cpp
查看文件 @
4b7f7a18
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
#include <stdio.h>
#include <stdio.h>
#include "XNet.h"
#include "XNet.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
//#define CRTDBG_MAP_ALLOC
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <stdlib.h>
...
@@ -42,7 +44,31 @@ int main( int argc, const char ** argv )
...
@@ -42,7 +44,31 @@ int main( int argc, const char ** argv )
XNet
net
;
XNet
net
;
XTensor
a
;
XTensor
a
;
net
.
Backward
(
a
);
XTensor
b
;
XTensor
c
;
InitTensor2D
(
&
a
,
2
,
2
);
InitTensor2D
(
&
b
,
2
,
2
);
InitTensor2D
(
&
c
,
2
,
2
);
a
.
SetZeroAll
();
b
.
SetZeroAll
();
c
.
SetZeroAll
();
SetDataFixed
(
a
,
1.0
F
);
a
.
Set2D
(
3.0
F
,
1
,
0
);
a
.
Set2D
(
4.0
F
,
1
,
1
);
b
=
a
+
a
;
c
=
HTanH
(
MMul
(
a
,
b
));
a
.
Dump
(
stderr
,
"a:"
);
b
.
Dump
(
stderr
,
"b:"
);
c
.
Dump
(
stderr
,
"c:"
);
net
.
Backward
(
c
);
net
.
Dump
(
stderr
);
//_CrtDumpMemoryLeaks();
//_CrtDumpMemoryLeaks();
...
...
source/network/XBackwardMath.cpp
查看文件 @
4b7f7a18
...
@@ -39,6 +39,8 @@ void XMathGrad::MakeGrad(XTensor * node)
...
@@ -39,6 +39,8 @@ void XMathGrad::MakeGrad(XTensor * node)
GradSum
(
node
);
GradSum
(
node
);
else
if
(
operID
==
MATH_MULTIPLY
)
else
if
(
operID
==
MATH_MULTIPLY
)
GradMultiply
(
node
);
GradMultiply
(
node
);
else
if
(
operID
==
MATH_MATRIXMUL
)
GradMatrixMul
(
node
);
else
{
else
{
ShowNTErrors
(
"TODO!"
);
ShowNTErrors
(
"TODO!"
);
}
}
...
@@ -102,4 +104,35 @@ void XMathGrad::GradMultiply(XTensor * node)
...
@@ -102,4 +104,35 @@ void XMathGrad::GradMultiply(XTensor * node)
_Multiply
(
node
->
grad
,
a
,
b
->
grad
);
_Multiply
(
node
->
grad
,
a
,
b
->
grad
);
}
}
/*
gradient for matrix multiply
for c = matmul(a, b) * \alpha
we have
dE/da = dE/dc * b^T * \alpha
dE/db = a^T * dE/dc * \alpha
>> node - the node (c) for backward computation
*/
void
XMathGrad
::
GradMatrixMul
(
XTensor
*
node
)
{
XLink
&
income
=
node
->
income
;
CheckNTErrors
(
income
.
tailNum
==
2
,
"Wrong input tensor number for MULTIPLY!"
);
XTensor
*
a
=
income
.
tails
[
0
];
XTensor
*
b
=
income
.
tails
[
1
];
DTYPE
alpha
=
income
.
GetParam
(
0
);
XNoder
::
MakeGrad
(
a
);
XNoder
::
MakeGrad
(
b
);
XTensor
*
dedc
=
node
->
grad
;
XTensor
*
deda
=
a
->
grad
;
XTensor
*
dedb
=
b
->
grad
;
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul
(
dedc
,
X_NOTRANS
,
b
,
X_TRANS
,
deda
,
alpha
);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul
(
a
,
X_TRANS
,
dedc
,
X_NOTRANS
,
dedb
,
alpha
);
}
}
}
\ No newline at end of file
source/network/XBackwardMath.h
查看文件 @
4b7f7a18
...
@@ -47,6 +47,10 @@ private:
...
@@ -47,6 +47,10 @@ private:
/* gradient for multiply (dot production): c = a * b */
/* gradient for multiply (dot production): c = a * b */
static
static
void
GradMultiply
(
XTensor
*
node
);
void
GradMultiply
(
XTensor
*
node
);
/* gradient for matrix multiply: c = matmul(a, b) */
static
void
GradMatrixMul
(
XTensor
*
node
);
};
};
}
}
...
...
source/network/XNet.cpp
查看文件 @
4b7f7a18
...
@@ -83,6 +83,22 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
...
@@ -83,6 +83,22 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
}
}
/*
/*
backward propagation to obtain gradient
>> root - root node (output) of the network
>> loss - name of loss function
*/
void
XNet
::
Backward
(
XTensor
&
root
,
LOSS_FUNCTION_NAME
loss
)
{
XList
roots
(
1
);
roots
.
Add
(
&
root
);
XList
golds
(
1
);
golds
.
Add
(
NULL
);
Backward
(
roots
,
golds
,
loss
);
}
/*
backward propagation to obtain gradient wrt. the loss/error function
backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes
with a number of root nodes
>> root - a list of root nodes (output) of the network
>> root - a list of root nodes (output) of the network
...
@@ -111,7 +127,7 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
...
@@ -111,7 +127,7 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
/* we compute dE/dx if the output is generated by an activation function y = f(x).
/* we compute dE/dx if the output is generated by an activation function y = f(x).
Note that we do not need to obtain dE/dy here because it is no use in the
Note that we do not need to obtain dE/dy here because it is no use in the
folloing process of back-propagation */
folloing process of back-propagation */
if
(
income
.
tailNum
==
1
&&
(
funcID
&
FUNCTION_BASE
)){
if
(
gold
!=
NULL
&&
income
.
tailNum
==
1
&&
(
funcID
&
FUNCTION_BASE
)){
XTensor
*
x
=
income
.
tails
[
0
];
XTensor
*
x
=
income
.
tails
[
0
];
XNoder
::
MakeGrad
(
x
);
XNoder
::
MakeGrad
(
x
);
lossGrad
.
Compute
(
gold
,
root
,
x
,
NULL
,
x
->
grad
,
funcID
,
params
,
loss
);
lossGrad
.
Compute
(
gold
,
root
,
x
,
NULL
,
x
->
grad
,
funcID
,
params
,
loss
);
...
@@ -133,6 +149,21 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
...
@@ -133,6 +149,21 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
}
}
/*
/*
backward propagation to obtain gradient
with a number of root nodes
>> root - a list of root nodes (output) of the network
>> loss - name of loss function
*/
void
XNet
::
Backward
(
XList
&
roots
,
LOSS_FUNCTION_NAME
loss
)
{
XList
golds
(
roots
.
count
);
for
(
int
i
=
0
;
i
<
roots
.
count
;
i
++
)
golds
.
Add
(
NULL
);
Backward
(
roots
,
golds
,
loss
);
}
/*
backward computation for a given node
backward computation for a given node
>> node - the node keeps the result of an operation (e.g., activation function)
>> node - the node keeps the result of an operation (e.g., activation function)
*/
*/
...
@@ -219,4 +250,22 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
...
@@ -219,4 +250,22 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
}
}
}
}
/*
dump network information
>> file - the file for dumping
*/
void
XNet
::
Dump
(
FILE
*
file
)
{
for
(
int
i
=
0
;
i
<
nodes
.
count
;
i
++
){
XTensor
*
node
=
(
XTensor
*
)
nodes
.
Get
(
i
);
fprintf
(
file
,
"node %d
\n
"
,
i
);
node
->
Dump
(
file
,
"tensor: "
);
if
(
node
->
grad
!=
NULL
)
node
->
grad
->
Dump
(
file
,
"grad: "
);
else
fprintf
(
file
,
"no gradient!
\n
"
);
fprintf
(
file
,
"
\n
"
);
}
}
}
}
\ No newline at end of file
source/network/XNet.h
查看文件 @
4b7f7a18
...
@@ -57,11 +57,18 @@ struct XNet
...
@@ -57,11 +57,18 @@ struct XNet
void
Clear
();
void
Clear
();
/* backward propagation to obtain gradient wrt. the loss/error function */
/* backward propagation to obtain gradient wrt. the loss/error function */
void
Backward
(
XTensor
&
root
,
XTensor
&
gold
=
NULLTensor
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
void
Backward
(
XTensor
&
root
,
XTensor
&
gold
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
/* backward propagation to obtain gradient */
void
Backward
(
XTensor
&
root
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
/* backward propagation to obtain gradient wrt. the loss/error function
/* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */
with a number of root nodes */
void
Backward
(
XList
&
roots
,
XList
&
golds
=
NULLList
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
void
Backward
(
XList
&
roots
,
XList
&
golds
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
/* backward propagation to obtain gradient
with a number of root nodes */
void
Backward
(
XList
&
roots
,
LOSS_FUNCTION_NAME
loss
=
NOLOSS
);
/* backward computation for a given node */
/* backward computation for a given node */
void
BackwardNode
(
XTensor
*
node
);
void
BackwardNode
(
XTensor
*
node
);
...
@@ -76,6 +83,9 @@ struct XNet
...
@@ -76,6 +83,9 @@ struct XNet
/* depth-first search given a node (Tarjan's algorithm for topological ordering) */
/* depth-first search given a node (Tarjan's algorithm for topological ordering) */
void
TarjanVisit
(
XTensor
*
node
,
XList
&
orders
,
const
unsigned
int
code
);
void
TarjanVisit
(
XTensor
*
node
,
XList
&
orders
,
const
unsigned
int
code
);
/* dump network information */
void
Dump
(
FILE
*
file
);
};
};
/* we make a unique id for every tensor */
/* we make a unique id for every tensor */
...
...
source/tensor/XName.cpp
查看文件 @
4b7f7a18
...
@@ -80,6 +80,20 @@ const char * GetOPName(int type)
...
@@ -80,6 +80,20 @@ const char * GetOPName(int type)
else
if
(
type
==
SHAPE_UNSQUEEZE
)
else
if
(
type
==
SHAPE_UNSQUEEZE
)
return
"S_UNSQUEEZE"
;
return
"S_UNSQUEEZE"
;
}
}
else
if
((
type
&
FUNCTION_BASE
)
!=
0
){
if
(
type
==
FUNC_HARDTANH
)
return
"F_HARDTANH"
;
else
if
(
type
==
FUNC_IDENTITY
)
return
"F_IDENTITY"
;
else
if
(
type
==
FUNC_LOGSOFTMAX
)
return
"F_LOGSOFTMAX"
;
else
if
(
type
==
FUNC_RECTIFY
)
return
"F_RECTIFY"
;
else
if
(
type
==
FUNC_SIGMOID
)
return
"F_SIGMOID"
;
else
if
(
type
==
FUNC_SOFTMAX
)
return
"F_SOFTMAX"
;
}
return
"NULL"
;
return
"NULL"
;
}
}
...
...
source/tensor/core/arithmetic/MatrixMul.cpp
查看文件 @
4b7f7a18
...
@@ -30,7 +30,7 @@
...
@@ -30,7 +30,7 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/*
/*
matrix multiplication
matrix multiplication
c = trans(a) * trans(b) * alpha + c * beta
For the input tensors a and b, we perform matrix multiplication on the first two dimentsions.
For the input tensors a and b, we perform matrix multiplication on the first two dimentsions.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
...
@@ -66,8 +66,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
...
@@ -66,8 +66,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
int
cn
=
c
->
dimSizeRDI
[
1
];
int
cn
=
c
->
dimSizeRDI
[
1
];
int
cm
=
c
->
dimSizeRDI
[
0
];
int
cm
=
c
->
dimSizeRDI
[
0
];
CheckNTErrors
((
am
==
bn
&&
an
==
cn
&&
bm
==
cm
),
CheckNTErrors
((
am
==
bn
&&
an
==
cn
&&
bm
==
cm
),
"Unmatched tensors in multiplication!"
);
"Unmatched tensors in multiplication!"
);
int
aBlockSize
=
a
->
dimSizeRDI
[
0
]
*
a
->
dimSizeRDI
[
1
];
int
aBlockSize
=
a
->
dimSizeRDI
[
0
]
*
a
->
dimSizeRDI
[
1
];
int
bBlockSize
=
b
->
dimSizeRDI
[
0
]
*
b
->
dimSizeRDI
[
1
];
int
bBlockSize
=
b
->
dimSizeRDI
[
0
]
*
b
->
dimSizeRDI
[
1
];
...
@@ -186,7 +185,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
...
@@ -186,7 +185,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
}
}
/*
/*
matrix multiplication (return a XTensor structure)
matrix multiplication (return a XTensor structure)
c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
make a new tensor to keep the result and return it
For the input tensors a and b, we perform matrix multiplication on the first two dimentsions.
For the input tensors a and b, we perform matrix multiplication on the first two dimentsions.
...
@@ -203,14 +202,13 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
...
@@ -203,14 +202,13 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
>> b - tensor b
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> transposedB - indicates whether teh matrices in b are transposed
>> alpha - a coefficient
>> alpha - a coefficient
>> beta - another coefficient
>> parallelRunner - parallel processing module
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
<< return - the result of matrix multiplication
*/
*/
XTensor
MatrixMul
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
MatrixMul
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
DTYPE
alpha
,
DTYPE
beta
,
XPRunner
*
parallelRunner
)
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
DTYPE
alpha
,
XPRunner
*
parallelRunner
)
{
{
CheckNTErrors
(
&
a
!=
&
NULLTensor
&&
&
b
!=
&
NULLTensor
,
"Empty input tensors!"
);
CheckNTErrors
(
a
.
dataType
==
b
.
dataType
,
"Input tensors should have the same data type!"
);
CheckNTErrors
(
a
.
dataType
==
b
.
dataType
,
"Input tensors should have the same data type!"
);
CheckNTErrors
(
a
.
order
>=
2
&&
b
.
order
>=
2
,
"Input tensors must have a order >= 2!"
);
CheckNTErrors
(
a
.
order
>=
2
&&
b
.
order
>=
2
,
"Input tensors must have a order >= 2!"
);
...
@@ -236,14 +234,65 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor
...
@@ -236,14 +234,65 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor
c
.
SetTMP
();
c
.
SetTMP
();
/* call _MatrixMul function */
/* call _MatrixMul function */
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
beta
,
parallelRunner
);
_MatrixMul
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
/* destroy variables */
delete
[]
dimSize
;
return
c
;
}
/*
matrix multiplication with no transposition c = a * b * alpha
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> alpha - a coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor
MatrixMul
(
const
XTensor
&
a
,
const
XTensor
&
b
,
DTYPE
alpha
,
XPRunner
*
parallelRunner
)
{
CheckNTErrors
(
a
.
dataType
==
b
.
dataType
,
"Input tensors should have the same data type!"
);
CheckNTErrors
(
a
.
order
>=
2
&&
b
.
order
>=
2
,
"Input tensors must have a order >= 2!"
);
int
an
=
a
.
dimSizeRDI
[
0
];
int
am
=
a
.
dimSizeRDI
[
1
];
int
bn
=
b
.
dimSizeRDI
[
0
];
int
bm
=
b
.
dimSizeRDI
[
1
];
CheckNTErrors
(
am
==
bn
,
"Unmatched tensors in multiplication!"
);
int
order
=
a
.
order
+
b
.
order
-
2
;
int
sub
=
0
;
int
*
dimSize
=
new
int
[
order
];
for
(
int
i
=
2
;
i
<
a
.
order
;
i
++
)
dimSize
[
sub
++
]
=
a
.
dimSizeRDI
[
a
.
order
+
1
-
i
];
for
(
int
i
=
2
;
i
<
b
.
order
;
i
++
)
dimSize
[
sub
++
]
=
b
.
dimSizeRDI
[
b
.
order
+
1
-
i
];
dimSize
[
sub
++
]
=
an
;
dimSize
[
sub
++
]
=
bm
;
float
dr
=
(
!
a
.
isSparse
||
!
b
.
isSparse
)
?
1.0
F
:
MAX
(
a
.
denseRatio
,
b
.
denseRatio
);
XTensor
c
(
order
,
dimSize
,
a
.
dataType
,
dr
,
a
.
devID
,
a
.
mem
);
c
.
SetTMP
();
/* call _MatrixMul function */
_MatrixMul
(
&
a
,
X_NOTRANS
,
&
b
,
X_NOTRANS
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMUL
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHeadTrans
(
&
c
,
X_NOTRANS
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/arithmetic/MatrixMul.h
查看文件 @
4b7f7a18
...
@@ -26,8 +26,10 @@
...
@@ -26,8 +26,10 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
#define MMul MatrixMul
/*
/*
matrix multiplication
matrix multiplication
c = trans(a) * trans(b) * alpha + c * beta
For the input tensors a and b, we perform matrix multiplicationon the first two dimentsions.
For the input tensors a and b, we perform matrix multiplicationon the first two dimentsions.
E.g., let A be a tensor of size y * z * m and B bea tensor of size x * y * n.
E.g., let A be a tensor of size y * z * m and B bea tensor of size x * y * n.
...
@@ -42,7 +44,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor
...
@@ -42,7 +44,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
/*
/*
matrix multiplication (return a XTensor structure)
matrix multiplication (return a XTensor structure)
c = trans(a) * trans(b) * alpha
make a new tensor c to keep the result and return it
make a new tensor c to keep the result and return it
For the input tensors a and b, we perform matrix multiplicationon the first two dimentsions.
For the input tensors a and b, we perform matrix multiplicationon the first two dimentsions.
...
@@ -55,7 +57,12 @@ C should be a tensor of z * x * n * m.
...
@@ -55,7 +57,12 @@ C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
*/
*/
XTensor
MatrixMul
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
MatrixMul
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
/* matrix multiplication with no transposition c = a * b * alpha*/
XTensor
MatrixMul
(
const
XTensor
&
a
,
const
XTensor
&
b
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/arithmetic/MatrixMul2D.cu
查看文件 @
4b7f7a18
...
@@ -158,8 +158,11 @@ void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
...
@@ -158,8 +158,11 @@ void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasSetStream(*handle, stream->stream);
cublasSetStream(*handle, stream->stream);
if (a->dataType == X_FLOAT && b->dataType == X_FLOAT && c->dataType == X_FLOAT) {
if (a->dataType == X_FLOAT && b->dataType == X_FLOAT && c->dataType == X_FLOAT) {
_CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType, b->data, transposedB, a->dataType, c->data, c->dataType,
_CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType,
a->dimSize[0], a->dimSize[1], b->dimSize[0], b->dimSize[1], c->dimSize[0], c->dimSize[1],
b->data, transposedB, a->dataType, c->data, c->dataType,
a->dimSize[0], a->dimSize[1],
b->dimSize[0], b->dimSize[1],
c->dimSize[0], c->dimSize[1],
alpha, beta);
alpha, beta);
}
}
else {
else {
...
...
source/tensor/core/arithmetic/MatrixMul2D.h
查看文件 @
4b7f7a18
source/tensor/core/arithmetic/MatrixMulBatched.cpp
查看文件 @
4b7f7a18
...
@@ -156,6 +156,7 @@ void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
...
@@ -156,6 +156,7 @@ void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
/*
/*
matrix multiplication of the two tensors (do it on site)
matrix multiplication of the two tensors (do it on site)
c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
for each 2-dimensional data array in a (denoted as ai) and
...
@@ -173,7 +174,7 @@ where trans() returns the transposed matrix if the flag is fired.
...
@@ -173,7 +174,7 @@ where trans() returns the transposed matrix if the flag is fired.
<< return - the result of matrix multiplication of the two tensors
<< return - the result of matrix multiplication of the two tensors
*/
*/
XTensor
MatrixMulBatched
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
MatrixMulBatched
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
DTYPE
alpha
,
DTYPE
beta
,
XPRunner
*
parallelRunner
)
DTYPE
alpha
,
XPRunner
*
parallelRunner
)
{
{
CheckNTErrors
(
&
a
!=
&
NULLTensor
&&
&
b
!=
&
NULLTensor
,
"Empty input tensors!"
);
CheckNTErrors
(
&
a
!=
&
NULLTensor
&&
&
b
!=
&
NULLTensor
,
"Empty input tensors!"
);
CheckNTErrors
(
a
.
dataType
==
b
.
dataType
,
"Input tensors should have the same data type!"
);
CheckNTErrors
(
a
.
dataType
==
b
.
dataType
,
"Input tensors should have the same data type!"
);
...
@@ -200,14 +201,13 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
...
@@ -200,14 +201,13 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
c
.
SetTMP
();
c
.
SetTMP
();
/*call _MatrixMulBatched function */
/*call _MatrixMulBatched function */
_MatrixMulBatched
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
beta
,
parallelRunner
);
_MatrixMulBatched
(
&
a
,
transposedA
,
&
b
,
transposedB
,
&
c
,
alpha
,
0
,
parallelRunner
);
/* tensor connections */
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_MATRIXMULBATCHED
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedA
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHeadTrans
(
&
c
,
transposedB
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
alpha
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
/* destroy variables */
/* destroy variables */
delete
[]
dimSize
;
delete
[]
dimSize
;
...
...
source/tensor/core/arithmetic/MatrixMulBatched.h
查看文件 @
4b7f7a18
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/*
/*
matrix multiplication of the two tensors
matrix multiplication of the two tensors
c = trans(a) * trans(b) * alpha + c * beta
for each 2-dimensional data array in a (denoted as ai) and
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
each 2-dimensional data array in b (denoted as bi), we have
...
@@ -38,7 +38,7 @@ void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const X
...
@@ -38,7 +38,7 @@ void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const X
XTensor
*
c
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
XTensor
*
c
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
/*
/*
matrix multiplication of the two tensors (return a XTensor structure)
matrix multiplication of the two tensors (return a XTensor structure)
c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
for each 2-dimensional data array in a (denoted as ai) and
...
@@ -47,7 +47,7 @@ ci = trans(ai) * trans(bi) * alpha + cm * beta
...
@@ -47,7 +47,7 @@ ci = trans(ai) * trans(bi) * alpha + cm * beta
where trans() returns the transposed matrix if the flag is fired
where trans() returns the transposed matrix if the flag is fired
*/
*/
XTensor
MatrixMulBatched
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
MatrixMulBatched
(
const
XTensor
&
a
,
MATRIX_TRANS_TYPE
transposedA
,
const
XTensor
&
b
,
MATRIX_TRANS_TYPE
transposedB
,
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
DTYPE
beta
=
0
,
XPRunner
*
parallelRunner
=
NULL
);
DTYPE
alpha
=
(
DTYPE
)
1
.
0
,
XPRunner
*
parallelRunner
=
NULL
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/core/arithmetic/XTensorBLAS.cu
查看文件 @
4b7f7a18
source/tensor/core/getandset/SetData.cpp
查看文件 @
4b7f7a18
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
*/
*/
#include "SetData.h"
#include "SetData.h"
#include "SetData.cuh"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
#include "../movement/CopyValues.h"
#if !defined( WIN32 ) && !defined( _WIN32 )
#if !defined( WIN32 ) && !defined( _WIN32 )
...
@@ -36,12 +38,150 @@
...
@@ -36,12 +38,150 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/*
/*
generate data items with a fixed value p
>> tensor - the tensor whose data array would be initialized
>> p - pointer to the number for initializing the tensor
*/
void
_SetDataFixed
(
XTensor
*
tensor
,
void
*
valuePointer
)
{
int
num
=
tensor
->
unitNum
;
if
(
tensor
->
dataType
==
X_INT
){
int
p
=
*
(
int
*
)
valuePointer
;
if
(
tensor
->
devID
<
0
){
int
*
d
=
(
int
*
)
tensor
->
data
;
if
(
num
%
4
==
0
){
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
){
d
[
i
]
=
p
;
d
[
i
+
1
]
=
p
;
d
[
i
+
2
]
=
p
;
d
[
i
+
3
]
=
p
;
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
d
[
i
]
=
p
;
}
}
else
{
#ifdef USE_CUDA
CudaSetDataFixedInt
(
tensor
,
p
);
#endif
}
}
else
if
(
tensor
->
dataType
==
X_FLOAT
){
float
p
=
*
(
float
*
)
valuePointer
;
if
(
tensor
->
devID
<
0
){
float
*
d
=
(
float
*
)
tensor
->
data
;
if
(
num
%
4
==
0
){
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
){
d
[
i
]
=
p
;
d
[
i
+
1
]
=
p
;
d
[
i
+
2
]
=
p
;
d
[
i
+
3
]
=
p
;
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
d
[
i
]
=
p
;
}
}
else
{
#ifdef USE_CUDA
CudaSetDataFixedFloat
(
tensor
,
p
);
#endif
}
}
else
if
(
tensor
->
dataType
==
X_DOUBLE
){
double
p
=
*
(
double
*
)
valuePointer
;
if
(
tensor
->
devID
<
0
){
double
*
d
=
(
double
*
)
tensor
->
data
;
if
(
num
%
4
==
0
){
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
){
d
[
i
]
=
p
;
d
[
i
+
1
]
=
p
;
d
[
i
+
2
]
=
p
;
d
[
i
+
3
]
=
p
;
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
d
[
i
]
=
p
;
}
}
else
{
#ifdef USE_CUDA
CudaSetDataFixedDouble
(
tensor
,
p
);
#endif
}
}
else
{
ShowNTErrors
(
"TODO"
);
}
}
/*
generate data items with a fixed value p (in default type)
>> tensor - the tensor whose data array would be initialized
>> p - number in default type
*/
void
SetDataFixed
(
XTensor
&
tensor
,
DTYPE
p
)
{
_SetDataFixed
(
&
tensor
,
&
p
);
}
/*
generate data items with a fixed value p (in integer)
>> tensor - the tensor whose data array would be initialized
>> p - an int-valued number
*/
void
_SetDataFixedInt
(
XTensor
*
tensor
,
int
p
)
{
CheckNTErrors
(
tensor
->
dataType
==
X_INT
,
"the tensor must be in X_INT"
);
if
(
p
==
0
)
tensor
->
SetZeroAll
();
else
_SetDataFixed
(
tensor
,
&
p
);
}
/*
generate data items with a fixed value p (in float)
>> tensor - the tensor whose data array would be initialized
>> p - a float-valued number
*/
void
_SetDataFixedFloat
(
XTensor
*
tensor
,
float
p
)
{
CheckNTErrors
(
tensor
->
dataType
==
X_FLOAT
,
"the tensor must be in X_INT"
);
if
(
p
==
0
)
tensor
->
SetZeroAll
();
else
_SetDataFixed
(
tensor
,
&
p
);
}
/*
generate data items with a fixed value p (in double)
>> tensor - the tensor whose data array would be initialized
>> p - a double-valued number
*/
void
_SetDataFixedDouble
(
XTensor
*
tensor
,
double
p
)
{
CheckNTErrors
(
tensor
->
dataType
==
X_DOUBLE
,
"the tensor must be in X_INT"
);
if
(
p
==
0
)
tensor
->
SetZeroAll
();
else
_SetDataFixed
(
tensor
,
&
p
);
}
/*
generate data items with a uniform distribution in [low,high]
generate data items with a uniform distribution in [low,high]
>> tensor - the tensor whose data array would be initialized
>> tensor - the tensor whose data array would be initialized
>> low - lower value of the range
>> low - lower value of the range
>> high - higher value of the range
>> high - higher value of the range
*/
*/
void
SetDataRand
(
XTensor
*
tensor
,
DTYPE
low
,
DTYPE
high
)
void
_
SetDataRand
(
XTensor
*
tensor
,
DTYPE
low
,
DTYPE
high
)
{
{
if
(
tensor
==
NULL
)
if
(
tensor
==
NULL
)
return
;
return
;
...
@@ -76,7 +216,7 @@ void SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
...
@@ -76,7 +216,7 @@ void SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
*/
*/
else
{
else
{
XTensor
*
t2
=
NewTensor
(
tensor
->
order
,
tensor
->
dimSize
,
tensor
->
dataType
,
tensor
->
denseRatio
,
-
1
);
XTensor
*
t2
=
NewTensor
(
tensor
->
order
,
tensor
->
dimSize
,
tensor
->
dataType
,
tensor
->
denseRatio
,
-
1
);
SetDataRand
(
t2
,
low
,
high
);
_
SetDataRand
(
t2
,
low
,
high
);
_CopyValues
(
t2
,
tensor
);
_CopyValues
(
t2
,
tensor
);
delete
t2
;
delete
t2
;
}
}
...
...
source/tensor/core/getandset/SetData.cu
查看文件 @
4b7f7a18
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-18
* I'm surprised that I did not write this file till today.
*/
#include "SetData.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set an integer data array with a fixed value p (in int)
>> d - pointer to the data array
>> size - size of the array
>> p - the initial value
*/
__global__
void KernelSetDataFixedInt(int * d, int size, int p)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
d[i] = p;
}
/*
generate data items with a fixed value p (in int)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedInt(XTensor * tensor, int p)
{
CheckNTErrors(tensor->dataType == X_INT, "the tensor must be in X_INT!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataFixedInt <<<blocks, threads >>>((int*)tensor->data, tensor->unitNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set a float data array with a fixed value p (in int)
>> d - pointer to the data array
>> size - size of the array
>> p - the initial value
*/
__global__
void KernelSetDataFixedFloat(float * d, int size, float p)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
d[i] = p;
}
/*
generate data items with a fixed value p (in float)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedFloat(XTensor * tensor, float p)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataFixedFloat <<<blocks, threads >>>((float*)tensor->data, tensor->unitNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set a double data array with a fixed value p (in int)
>> d - pointer to the data array
>> size - size of the array
>> p - the initial value
*/
__global__
void KernelSetDataFixedDouble(double * d, int size, double p)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
d[i] = p;
}
/*
generate data items with a fixed value p (in double)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedDouble(XTensor * tensor, double p)
{
CheckNTErrors(tensor->dataType == X_DOUBLE, "the tensor must be in X_DOUBLE!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataFixedDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
source/tensor/core/getandset/SetData.cuh
查看文件 @
4b7f7a18
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-18
* I'm surprised that I did not write this file till today.
*/
#ifndef __SETDATA_CUH__
#define __SETDATA_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* generate data items with a fixed value p (in int) */
void CudaSetDataFixedInt(XTensor * tensor, int p);
/* generate data items with a fixed value p (in float) */
void CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void CudaSetDataFixedDouble(XTensor * tensor, double p);
} // namespace nts(NiuTrans.Tensor)
#endif // __SETDATA_CUH__
\ No newline at end of file
source/tensor/core/getandset/SetData.h
查看文件 @
4b7f7a18
...
@@ -28,28 +28,25 @@
...
@@ -28,28 +28,25 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* generate data items with a fixed value p */
/* generate data items with a fixed value p */
extern
"C"
void
_SetDataFixed
(
XTensor
*
tensor
,
void
*
valuePointer
);
void
SetDataFixed
(
XTensor
*
tensor
,
void
*
valuePointer
);
/* generate data items with a fixed value p (in default type) */
void
SetDataFixed
(
XTensor
&
tensor
,
DTYPE
p
);
/* generate data items with a fixed value p (in int) */
/* generate data items with a fixed value p (in int) */
extern
"C"
void
_SetDataFixedInt
(
XTensor
*
tensor
,
int
p
);
void
SetDataFixedInt
(
XTensor
*
tensor
,
int
p
);
/* generate data items with a fixed value p (in float) */
/* generate data items with a fixed value p (in float) */
extern
"C"
void
_SetDataFixedFloat
(
XTensor
*
tensor
,
float
p
);
void
SetDataFixedFloat
(
XTensor
*
tensor
,
float
p
);
/* generate data items with a fixed value p (in double) */
/* generate data items with a fixed value p (in double) */
extern
"C"
void
_SetDataFixedDouble
(
XTensor
*
tensor
,
double
p
);
void
SetDataFixedDouble
(
XTensor
*
tensor
,
double
p
);
/* generate data items with a uniform distribution in [low,high] */
/* generate data items with a uniform distribution in [low,high] */
extern
"C"
void
_SetDataRand
(
XTensor
*
tensor
,
DTYPE
low
,
DTYPE
high
);
void
SetDataRand
(
XTensor
*
tensor
,
DTYPE
low
,
DTYPE
high
);
/* generate data items with a normal distribution with specified mean and standard deviation */
/* generate data items with a normal distribution with specified mean and standard deviation */
extern
"C"
void
_SetDataRandN
(
XTensor
*
tensor
,
DTYPE
mean
,
DTYPE
standardDeviation
);
void
SetDataRandN
(
XTensor
*
tensor
,
DTYPE
mean
,
DTYPE
standardDeviation
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
...
...
source/tensor/function/HardTanH.cpp
查看文件 @
4b7f7a18
source/tensor/function/HardTanH.h
查看文件 @
4b7f7a18
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
#define HTanH HardTanH
/*
/*
hard tanh function
hard tanh function
y = 1 if x > 1
y = 1 if x > 1
...
...
source/tensor/function/Loss.cpp
查看文件 @
4b7f7a18
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include <math.h>
#include <math.h>
#include "Loss.h"
#include "Loss.h"
#include "Loss.cuh"
#include "Loss.cuh"
#include "../core/getandset/SetData.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
...
@@ -383,6 +384,19 @@ void LossBackward(XTensor * dedy, XTensor * t, XTensor * y,
...
@@ -383,6 +384,19 @@ void LossBackward(XTensor * dedy, XTensor * t, XTensor * y,
LOSS_FUNCTION_NAME
LFName
,
LOSS_FUNCTION_NAME
LFName
,
int
leadDim
,
int
tBeg
,
int
tLen
,
int
yBeg
)
int
leadDim
,
int
tBeg
,
int
tLen
,
int
yBeg
)
{
{
if
(
t
==
NULL
){
if
(
dedy
->
dataType
==
X_FLOAT
)
_SetDataFixedFloat
(
dedy
,
1.0
F
);
else
if
(
dedy
->
dataType
==
X_DOUBLE
)
_SetDataFixedDouble
(
dedy
,
1.0
);
else
if
(
dedy
->
dataType
==
X_INT
)
_SetDataFixedInt
(
dedy
,
1
);
else
{
ShowNTErrors
(
"TODO"
);
}
return
;
}
if
(
t
->
order
<
0
)
if
(
t
->
order
<
0
)
return
;
return
;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论