Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
Emmay
NiuTrans.Tensor
Commits
80cfa480
Commit
80cfa480
authored
Aug 18, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add low-tri matrix data initialization metthod
parent
45a5a936
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
143 行增加
和
0 行删除
+143
-0
source/tensor/core/getandset/SetData.cpp
+50
-0
source/tensor/core/getandset/SetData.cu
+87
-0
source/tensor/core/getandset/SetData.cuh
+3
-0
source/tensor/core/getandset/SetData.h
+3
-0
没有找到文件。
source/tensor/core/getandset/SetData.cpp
查看文件 @
80cfa480
...
...
@@ -213,6 +213,56 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
_SetDataFixed
(
tensor
,
&
p
);
}
/*
generate data as lower triangular matrics for last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void
_SetDataLowTri
(
XTensor
*
tensor
,
DTYPE
p
,
int
shift
)
{
int
n
=
tensor
->
order
;
CheckNTErrors
(
tensor
->
dataType
==
DEFAULT_DTYPE
,
"TODO!"
);
CheckNTErrors
(
n
>=
2
,
"The tensor must have a order no less than 2!"
);
CheckNTErrors
(
tensor
->
GetDim
(
n
-
1
)
==
tensor
->
GetDim
(
n
-
2
),
"The last two dimensions must be of the same size!"
);
if
(
tensor
->
devID
<
0
){
int
l
=
tensor
->
GetDim
(
-
1
);
int
blockNum
=
1
;
int
blockSize
=
l
*
l
;
for
(
int
i
=
0
;
i
<
n
-
2
;
i
++
)
blockNum
*=
tensor
->
GetDim
(
i
);
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
){
DTYPE
*
d
=
(
DTYPE
*
)
tensor
->
data
+
i
*
blockSize
;
for
(
int
row
=
0
;
row
<
l
;
row
++
){
for
(
int
col
=
0
;
col
<
row
+
shift
;
col
++
){
d
[
row
*
l
+
col
]
=
row
;
}
for
(
int
col
=
row
+
shift
;
col
<
l
;
col
++
){
d
[
row
*
l
+
col
]
=
0
;
}
}
}
}
else
{
#ifdef USE_CUDA
_CudaSetDataLowTri
(
tensor
,
p
,
shift
);
#endif
}
}
/*
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
...
...
source/tensor/core/getandset/SetData.cu
查看文件 @
80cfa480
...
...
@@ -184,6 +184,93 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
}
/*
set lower triangular matrics for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e,
a block is l * l matrix
>> blockSize - size of each block (blockSize = l * l)
>> blockNum - number of the blocks
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
__global__
void _KernalSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
int row = i / l;
int col = i % l;
DTYPE * d2 = d + blockSize * j + row * l + col;
if(col < row + shift)
*d2 = p;
else
*d2 = 0;
}
/*
generate data as lower triangular matrics for last two dimensions (cuda version)
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(n >= 2, "The tensor must have a order no less than 2!");
CheckNTErrors(tensor->GetDim(n - 1) == tensor->GetDim(n - 2),
"The last two dimensions must be of the same size!");
int l = tensor->GetDim(-1);
int blockNum = 1;
int blockSize = l * l;
for(int i = 0; i < n - 2; i++)
blockNum *= tensor->GetDim(i);
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_KernalSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
...
...
source/tensor/core/getandset/SetData.cuh
查看文件 @
80cfa480
...
...
@@ -37,6 +37,9 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* generate data as lower triangular matrics for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
...
...
source/tensor/core/getandset/SetData.h
查看文件 @
80cfa480
...
...
@@ -45,6 +45,9 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void
_SetDataFixedDouble
(
XTensor
*
tensor
,
double
p
);
/* generate data as lower triangular matrics for last two dimensions */
void
_SetDataLowTri
(
XTensor
*
tensor
,
DTYPE
p
,
int
shift
);
/* generate data items with a uniform distribution in [lower, upper] */
void
_SetDataRand
(
XTensor
*
tensor
,
DTYPE
lower
,
DTYPE
upper
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论