Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
daa2f801
Commit
daa2f801
authored
Jul 29, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
full code of SumDim
parent
454bd870
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
64 行增加
和
6 行删除
+64
-6
source/network/Main.cpp
+3
-5
source/tensor/XName.h
+2
-1
source/tensor/core/arithmetic/SumDim.cpp
+51
-0
source/tensor/core/arithmetic/SumDim.cu
+4
-0
source/tensor/core/arithmetic/SumDim.cuh
+4
-0
没有找到文件。
source/network/Main.cpp
查看文件 @
daa2f801
...
...
@@ -181,11 +181,9 @@ void SumDimTest()
int
b
=
7
;
int
c
=
3
;
XDevice
::
SetGPUDevice
(
0
);
InitTensor3D
(
&
x
,
a
,
b
,
c
,
X_FLOAT
,
0
);
InitTensor1D
(
&
y
,
c
,
X_FLOAT
,
0
);
InitTensor3D
(
&
z
,
a
,
b
,
c
,
X_FLOAT
,
0
);
InitTensor3D
(
&
x
,
a
,
b
,
c
,
X_FLOAT
,
-
1
);
InitTensor1D
(
&
y
,
c
,
X_FLOAT
,
-
1
);
InitTensor3D
(
&
z
,
a
,
b
,
c
,
X_FLOAT
,
-
1
);
x
.
SetZeroAll
();
y
.
SetZeroAll
();
...
...
source/tensor/XName.h
查看文件 @
daa2f801
...
...
@@ -37,8 +37,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_SIGN MATH_NEGATE + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUMDIM MATH_SUM + 1
#define MATH_LOG MATH_SUM + 1
#define MATH_LOG MATH_SUM
DIM
+ 1
#define MATH_NORMALIZE MATH_LOG + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
...
...
source/tensor/core/arithmetic/SumDim.cpp
查看文件 @
daa2f801
...
...
@@ -22,6 +22,7 @@
#include "Sum.h"
#include "SumDim.h"
#include "SumDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
...
...
@@ -60,7 +61,11 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet
}
if
(
a
->
devID
>=
0
||
b
->
devID
>=
0
||
c
->
devID
>=
0
){
#ifdef USE_CUDA
_CudaSumDim
(
a
,
b
,
c
,
n
,
beta
);
#else
ShowNTErrors
(
"Please specify USE_CUDA and recompile the code!"
);
#endif
}
else
{
int
stride
=
1
;
...
...
@@ -110,4 +115,50 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet
}
}
/*
tensor summation (on site)
a = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
*/
void
_SumDim
(
XTensor
*
a
,
const
XTensor
*
b
,
int
n
,
DTYPE
beta
)
{
_SumDim
(
a
,
b
,
a
,
n
,
beta
);
}
/*
tensor summation (return a structure and make tensor connections)
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
XTensor
SumDim
(
const
XTensor
&
a
,
const
XTensor
&
b
,
int
n
,
DTYPE
beta
)
{
XTensor
c
(
&
a
);
c
.
SetTMP
();
/* call _Sum function */
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
/* tensor connections */
XLink
::
MakeLink
(
&
a
,
&
b
,
&
c
,
MATH_SUMDIM
);
XLink
::
AddParamToHeadInt
(
&
c
,
n
);
XLink
::
AddParamToHead
(
&
c
,
beta
);
return
c
;
}
}
source/tensor/core/arithmetic/SumDim.cu
查看文件 @
daa2f801
...
...
@@ -24,6 +24,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor summation of a tensor and a row vector
c = a + b * \beta
...
...
@@ -168,5 +170,7 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
}
}
#endif
} // namespace nts(NiuTrans.Tensor)
source/tensor/core/arithmetic/SumDim.cuh
查看文件 @
daa2f801
...
...
@@ -26,10 +26,14 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* tensor summation c = a + b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting (cuda version) */
void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta = (DTYPE)1.0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMDIM_CUH__
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论