Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
229db8c6
Commit
229db8c6
authored
Nov 23, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve the implementation of Unsqueeze
parent
fe90b454
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
112 行增加
和
18 行删除
+112
-18
source/network/XBackwardMath.cpp
+8
-8
source/tensor/core/shape/Unsqueeze.cu
+104
-10
没有找到文件。
source/network/XBackwardMath.cpp
查看文件 @
229db8c6
...
...
@@ -1323,8 +1323,8 @@ gradient for reduceSumSquared
for
c = \sum_i (a_i - b)^2
we have
dE/da
= Unsqueeze(dE/dc) * 2a
dE/db = dE/dc * -2 * n *
b
dE/da
_i = Unsqueeze(dE/dc) * 2 * (a_i - b)
dE/db = dE/dc * -2 * n *
\sum_i (a_i - b)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
...
...
@@ -1352,12 +1352,12 @@ void XMathGrad::GradReduceSumSquared(XTensor * node, bool isEfficient)
_Sub
(
a
,
c
,
d
);
_ReduceSum
(
d
,
f
,
dim
);
/* dE/da
= Unsqueeze(dE/dc) * 2(a-
b) */
/* dE/da
_i = Unsqueeze(dE/dc) * 2 * (a_i -
b) */
_ScaleAndShiftMe
(
d
,
2.0
F
);
_Unsqueeze
(
node
->
grad
,
e
,
dim
,
n
);
_Multiply
(
d
,
e
,
a
->
grad
,
1.0
F
);
/* dE/db = dE/dc * -2 *
(a-b*n
) */
/* dE/db = dE/dc * -2 *
n * \sum_i (a_i - b
) */
_ScaleAndShiftMe
(
f
,
-
2.0
F
);
_Multiply
(
node
->
grad
,
f
,
b
->
grad
,
1.0
F
);
...
...
@@ -1375,8 +1375,8 @@ for
c = (sum_i (a_i - b)^2) * 1/n
where b is the mean, and n is the size of a
we have
dE/da
= Unsqueeze(dE/dc) * 2a
/n
dE/db = dE/dc * -2 *
b
dE/da
_i = Unsqueeze(dE/dc) * 2 * (a_i - b)
/n
dE/db = dE/dc * -2 *
\sum_i (a_i - b)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
...
...
@@ -1404,12 +1404,12 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
_Sub
(
a
,
c
,
d
);
_ReduceSum
(
d
,
f
,
dim
);
/* dE/da
= Unsqueeze(dE/dc) * 2 (a-
b) / n */
/* dE/da
_i = Unsqueeze(dE/dc) * 2 * (a_i -
b) / n */
_ScaleAndShiftMe
(
d
,
2.0
F
/
n
);
_Unsqueeze
(
node
->
grad
,
e
,
dim
,
n
);
_Multiply
(
d
,
e
,
a
->
grad
,
1.0
F
);
/* dE/db = dE/dc * -2 *
(a-
b) */
/* dE/db = dE/dc * -2 *
\sum_i (a_i -
b) */
_ScaleAndShiftMe
(
f
,
-
2.0
F
/
n
);
_Multiply
(
node
->
grad
,
f
,
b
->
grad
,
1.0
F
);
...
...
source/tensor/core/shape/Unsqueeze.cu
查看文件 @
229db8c6
...
...
@@ -127,7 +127,7 @@ insert a dimension by copying the blocks for n times (where n is the size of the
>> s - pointer to the source data array
>> blockSize - size of a block
>> blockNum - number of the blocks
>> totalSize - total size of the blocks (i.e., blockS
I
ze * n)
>> totalSize - total size of the blocks (i.e., blockS
i
ze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
...
...
@@ -155,6 +155,75 @@ void KernelUnsqueeze(void * s, int blockSize, int blockNum, int totalSize, void
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
This is special case where we actually copy a v-dimentional column vector by n times to form a v * n matrix
>> s - pointer to the source data array
>> rowNum - number of rows (i.e., dimension size of s)
>> colNum - number of columns (i.e., number of copies)
>> t - pointer to the target data array
*/
template<class T>
__global__
void KernelUnsqueezeByCol(void * s, int rowNum, int colNum, void * t)
{
__shared__ T values[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T * ts[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* column index */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* row index */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= colNum || j >= rowNum)
return;
if(threadIdx.x == 0){
values[threadIdx.y] = ((T*)s)[j];
ts[threadIdx.y] = (T*)t + colNum * j;
}
__syncthreads();
ts[threadIdx.y][i] = values[threadIdx.y];
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
This is special case where we actually copy a v-dimentional column vector by n times to form a v * n matrix
And a row is very big so that it occupies the cuda threads in a block
>> s - pointer to the source data array
>> rowNum - number of rows (i.e., dimension size of s)
>> colNum - number of columns (i.e., number of copies)
>> t - pointer to the target data array
*/
template<class T>
__global__
void KernelUnsqueezeByColBigRow(void * s, int rowNum, int colNum, void * t)
{
__shared__ T value;
__shared__ T * tData;
/* column index */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* row index */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= colNum || j >= rowNum)
return;
if (threadIdx.x == 0) {
value = ((T*)s)[j];
tData = (T*)t + colNum * j;
}
__syncthreads();
tData[i] = value;
}
/*
insert a dimension by copying the blocks for x times (where x is the size of the inerted dimension)
>> a - input tensor
>> b - output tensor
...
...
@@ -181,14 +250,39 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if(blockNumA > 1){
if (dimRDI == 0) {
GDevs.GetCudaThread2D(a->devID, dSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && b->dataType == X_FLOAT) {
if (cudaBlocks[1] == 1)
KernelUnsqueezeByColBigRow<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
else
KernelUnsqueezeByCol<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else if (a->dataType == X_INT && b->dataType == X_INT) {
if (cudaBlocks[1] == 1)
KernelUnsqueezeByColBigRow<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
else
KernelUnsqueezeByCol<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA > 1){
GDevs.GetCudaThread2D(a->devID, blockSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT &&
a
->dataType == X_FLOAT) {
if (a->dataType == X_FLOAT &&
b
->dataType == X_FLOAT) {
KernelUnsqueeze<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT &&
a
->dataType == X_INT) {
else if (a->dataType == X_INT &&
b
->dataType == X_INT) {
KernelUnsqueeze<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
...
...
@@ -199,11 +293,11 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
else if(blockNumA == 1 && blockSize < MAX_CUDA_THREAD_NUM_PER_BLOCK){
GDevs.GetCudaThread2D(a->devID, blockSize, dSize, MAX_CUDA_THREAD_NUM_PER_BLOCK/4, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT &&
a
->dataType == X_FLOAT) {
if (a->dataType == X_FLOAT &&
b
->dataType == X_FLOAT) {
KernelUnsqueezeFlat2D<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT &&
a
->dataType == X_INT) {
else if (a->dataType == X_INT &&
b
->dataType == X_INT) {
KernelUnsqueezeFlat2D<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
...
...
@@ -214,11 +308,11 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
else if(blockNumA == 1 && blockSize % 2 == 0){
GDevs.GetCudaThread(a->devID, blockSize/2, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT &&
a
->dataType == X_FLOAT) {
if (a->dataType == X_FLOAT &&
b
->dataType == X_FLOAT) {
KernelUnsqueezeFlatBigram<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT &&
a
->dataType == X_INT) {
else if (a->dataType == X_INT &&
b
->dataType == X_INT) {
KernelUnsqueezeFlatBigram<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
...
...
@@ -229,11 +323,11 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
else if(blockNumA == 1){
GDevs.GetCudaThread(a->devID, blockSize, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT &&
a
->dataType == X_FLOAT) {
if (a->dataType == X_FLOAT &&
b
->dataType == X_FLOAT) {
KernelUnsqueezeFlat<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT &&
a
->dataType == X_INT) {
else if (a->dataType == X_INT &&
b
->dataType == X_INT) {
KernelUnsqueezeFlat<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论