NiuTrans / NiuTrans.Tensor / Commits

Commit 4e8872e9
authored 6 years ago by xiaotong
parent f21e1b48

bug fixes in matrix multiplication
Showing 5 changed files with 76 additions and 15 deletions.
source/network/XBackwardMath.cpp            +44 -4
source/network/XBackwardMath.h              +7  -2
source/sample/transformer/T2TUtility.cpp    +2  -3
source/sample/transformer/T2TUtility.h      +2  -3
source/tensor/core/arithmetic/MatrixMul.cpp +21 -3
source/network/XBackwardMath.cpp (view file @ 4e8872e9)
...
@@ -259,16 +259,58 @@ void XMathGrad::GradMatrixMul(XTensor * node)
    XNoder::MakeGrad(a);
    XNoder::MakeGrad(b);

    XTensor * c = node;
    XTensor * dedc = node->grad;
    XTensor * deda = a->grad;
    XTensor * dedb = b->grad;

    if(deda->order == 2 && dedb->order == 2)
        GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);
    else if(transA == X_NOTRANS && deda->order > 2 && dedb->order == 2){
        int orderBackupA = a->order;
        int orderBackupC = c->order;
        int dimsBackupA[MAX_TENSOR_DIM_NUM];
        int dimsBackupC[MAX_TENSOR_DIM_NUM];
        memcpy(dimsBackupA, a->dimSize, sizeof(int) * a->order);
        memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);

        int dimsA[2] = {a->unitNum / a->GetDim(-1), a->GetDim(-1)};
        int dimsC[2] = {c->unitNum / c->GetDim(-1), c->GetDim(-1)};

        a->Reshape(2, dimsA);
        c->Reshape(2, dimsC);
        deda->Reshape(2, dimsA);
        dedc->Reshape(2, dimsC);

        GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);

        a->Reshape(orderBackupA, dimsBackupA);
        c->Reshape(orderBackupC, dimsBackupC);
        deda->Reshape(orderBackupA, dimsBackupA);
        dedc->Reshape(orderBackupC, dimsBackupC);
    }
    else{
        ShowNTErrors("TODO!");
    }

    node->visitMark = NODE_FINISHED;
}

/*
gradient for matrix multiply: c = matmul(a, b) * \alpha
>> a - as it is
>> deda - dE/da
>> b - as it is
>> dedb - dE/db
>> dedc - dE/dc
>> alpha - the scalar
*/
void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
                              XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
                              XTensor * dedc, DTYPE alpha)
{
    /* c = a * b * \alpha */
    if(transA == X_NOTRANS && transB == X_NOTRANS){

        /* dE/da = dE/dc * b^T * \alpha */
        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);

        /* dE/db = a^T * dE/dc * \alpha */
        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
    }
...
@@ -302,8 +344,6 @@ void XMathGrad::GradMatrixMul(XTensor * node)
        /* dE/db = a * dE/dc * \alpha */
        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
    }
}

/*
...
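For context, the two _MatrixMul calls in the new helper implement the standard matrix-calculus result (a reference derivation, not part of the commit): for a scalar loss E with C = \alpha A B,

\frac{\partial E}{\partial a_{ik}} = \sum_j \frac{\partial E}{\partial c_{ij}} \, \alpha \, b_{kj}
\quad\Longrightarrow\quad
\frac{\partial E}{\partial A} = \alpha \, \frac{\partial E}{\partial C} \, B^{\top},
\qquad
\frac{\partial E}{\partial B} = \alpha \, A^{\top} \, \frac{\partial E}{\partial C}.

The trailing 1.0F argument is the beta of _MatrixMul; assuming the usual GEMM convention c = \alpha \, a b + \beta \, c, each call accumulates into deda/dedb rather than overwriting gradients that other paths in the graph may already have contributed.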
source/network/XBackwardMath.h (view file @ 4e8872e9)
...
@@ -56,6 +56,12 @@ private:
    /* gradient for matrix multiply: c = matmul(a, b) */
    static
    void GradMatrixMul(XTensor * node);

    /* gradient for matrix multiply: c = matmul(a, b) */
    static
    void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
                       XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
                       XTensor * dedc, DTYPE alpha);

    /* gradient for log: c = log(a) */
    static
...
@@ -128,4 +134,4 @@ private:
}

#endif
\ No newline at end of file
source/sample/transformer/T2TUtility.cpp (view file @ 4e8872e9)
...
@@ -26,7 +26,7 @@
namespace transformer
{

void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
    char vname[128];
    vname[0] = '-';
...
@@ -108,4 +108,4 @@ void ShowParams(int argc, const char ** argv)
    fprintf(stderr, "\n");
}

}
\ No newline at end of file
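The substantive change in this file is the last parameter of LoadParamString: char * defaultP becomes const char * defaultP. A minimal self-contained sketch of why this matters; the stub body and the call in main are hypothetical, only the signature mirrors the commit:

#include <cstring>

/* stub with the new signature; the real argv lookup lives in T2TUtility.cpp */
void LoadParamString(int argc, const char ** argv, const char * name,
                     char * p, const char * defaultP)
{
    strcpy(p, defaultP);    /* pretend "-name" was not found: use the default */
}

int main(int argc, const char ** argv)
{
    char path[1024];
    /* a string literal has type const char[N]; passing it through a plain
       char * parameter needs a const-stripping conversion that ISO C++
       rejects, so the const-qualified parameter makes this call well-formed */
    LoadParamString(argc, argv, "model", path, "model.bin");
    return 0;
}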
source/sample/transformer/T2TUtility.h (view file @ 4e8872e9)
...
@@ -28,7 +28,7 @@ namespace transformer
{

/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
...
@@ -38,4 +38,4 @@ void ShowParams(int argc, const char ** argv);
}

#endif
\ No newline at end of file
source/tensor/core/arithmetic/MatrixMul.cpp (view file @ 4e8872e9)
...
@@ -53,11 +53,29 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
                const XTensor * b, MATRIX_TRANS_TYPE transposedB,
                XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
    CheckNTErrors(a && b && c, "Empty input tensors!");
    CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
                  "Input tensors should have the same data type!");
    CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
                  "Input tensors must have a order >= 2!");
    CheckNTErrors(c->order == a->order + b->order - 2, "wrong tensor order");

    /* we transform a higher order tensor to a matrix to kill the number
       of calls of matrix multiplication */
    if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
        int ncolA = a->dimSize[a->order - 1];
        int ncolC = c->dimSize[c->order - 1];
        XTensor * a2 = NewTensor2D(a->unitNum / ncolA, -ncolA, a->dataType, a->devID, a->mem);
        XTensor * c2 = NewTensor2D(c->unitNum / ncolC, -ncolC, c->dataType, c->devID, c->mem);
        a2->data = a->data;
        c2->data = c->data;

        _MatrixMul2D(a2, transposedA, b, transposedB, c2, alpha, beta, parallelRunner);

        a2->data = NULL;
        c2->data = NULL;
        delete a2;
        delete c2;

        return;
    }

    int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
    int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
...
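Both the fast path in MatrixMul.cpp and the reshape trick in the backward code rely on the same observation: with row-major storage, a stack of products c_i = a_i * b against a shared matrix b equals one 2-D product over a flattened view of a. A small self-contained sketch of that equivalence, in plain C++ and independent of NiuTensor's API:

#include <cstdio>

/* naive row-major 2-D matrix product: z (xr x yc) = x (xr x xc) * y (xc x yc) */
void MatMul2D(const float * x, int xr, int xc,
              const float * y, int yc, float * z)
{
    for (int i = 0; i < xr; i++)
        for (int j = 0; j < yc; j++) {
            float s = 0;
            for (int k = 0; k < xc; k++)
                s += x[i * xc + k] * y[k * yc + j];
            z[i * yc + j] = s;
        }
}

int main()
{
    const int batch = 2, m = 2, k = 3, n = 2;
    float a[batch * m * k] = {1,2,3, 4,5,6, 7,8,9, 10,11,12};  /* (batch, m, k) */
    float b[k * n]         = {1,0, 0,1, 1,1};                  /* shared (k, n) */
    float c[batch * m * n];                                    /* (batch, m, n) */

    /* one call on the flattened (batch*m, k) view of `a`,
       instead of `batch` separate (m, k) x (k, n) calls */
    MatMul2D(a, batch * m, k, b, n, c);

    for (int i = 0; i < batch * m * n; i++)
        printf("%g ", c[i]);
    printf("\n");
    return 0;
}

This is why the commit only needs to reshape a, c, deda, and dedc to order 2 (and restore the backed-up shapes afterwards) to reuse the existing 2-D gradient routine.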