Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
ece0dc78
Commit
ece0dc78
authored
Jul 31, 2018
by
张裕浩
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
测试softmax优化方法
parent
dd6646ed
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
138 行增加
和
12 行删除
+138
-12
source/network/Main.cpp
+5
-4
source/tensor/function/Softmax.cu
+53
-4
source/tensor/test/TSoftmax.cpp
+78
-2
source/tensor/test/Test.cpp
+2
-2
没有找到文件。
source/network/Main.cpp
查看文件 @
ece0dc78
...
...
@@ -24,6 +24,7 @@
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../tensor/test/Test.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
...
...
@@ -36,9 +37,9 @@ using namespace samplefnnlm;
int
main
(
int
argc
,
const
char
**
argv
)
{
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-test"
))
1
;
//
Test();
else
if
(
argc
>
1
&&
!
strcmp
(
argv
[
1
],
"-fnnlm"
))
//
if(argc > 1 && !strcmp(argv[1], "-test"))
Test
();
/*
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
...
...
@@ -76,7 +77,7 @@ int main( int argc, const char ** argv )
net.Dump(stderr);
//_CrtDumpMemoryLeaks();
//_CrtDumpMemoryLeaks();
*/
return
0
;
}
source/tensor/function/Softmax.cu
查看文件 @
ece0dc78
...
...
@@ -155,6 +155,46 @@ void KernelSoftmaxComputeTensor(__half * x, __half * max, __half * sum, __half *
}
}
__device__ __forceinline__ float broadCast(float input)
{
float output;
asm(
"{"
"shfl.idx.b32 %0,%1,0x0,0x1f;"
"}"
:"=f"(output) : "f"(input)
);
return output;
}
__global__
void KernelSoftmaxComputeTensorUseBroadcast(DTYPE * input, DTYPE * max, DTYPE * sum, DTYPE * output, int stride, int strideNum, int blockNum)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
int i2 = j % stride;
int blockSize = stride * strideNum;
if (j < stride * blockNum)
{
DTYPE sumData, maxData;
if (i % 32 == 0)
{
sumData = sum[j];
maxData = max[j];
}
//sumData = __shfl_sync(0xffffffff,sumData, 0);
//maxData = __shfl_sync(0xffffffff,maxData, 0);
sumData = broadCast(sumData);
maxData = broadCast(maxData);
if (i < strideNum)
{
int offset = int(j / stride) * blockSize + i * stride + i2;
output[offset] = exp(input[offset] - maxData) / sumData;
}
}
}
/*
softmax y = e^x / \sum_{i} e^{x_i} (Cuda version)
>> x - x vector
...
...
@@ -183,15 +223,24 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s
int cudaGridSize[3];
int cudaBlockSize[3];
GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
//GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
GDevs.GetCudaThread2D(x->devID, dimensionSize, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
if (cudaBlockSize[0] % 32 != 0)
cudaBlockSize[0] += (32 - cudaBlockSize[0] % 32);
/**/
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
printf("run here\n");
/*KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);*/
KernelSoftmaxComputeTensorUseBroadcast << <dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >> >
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, blockNum);
//printf("%d %d %d %d %d %d\n", stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
}
else if(x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16){
KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
...
...
source/tensor/test/TSoftmax.cpp
查看文件 @
ece0dc78
...
...
@@ -68,7 +68,6 @@ bool TestSoftmax1()
#ifdef USE_CUDA
/* GPU test */
bool
gpuTest
=
true
;
/* create tensors */
XTensor
*
xGPU
=
NewTensor
(
order
,
dimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
yGPU
=
NewTensor
(
order
,
dimSize
,
X_FLOAT
,
1.0
F
,
0
);
...
...
@@ -81,7 +80,6 @@ bool TestSoftmax1()
/* call Softmax function */
_Softmax
(
xGPU
,
yGPU
,
1
);
yUserGPU
=
Softmax
(
*
xGPU
,
1
);
/* check result */
gpuTest
=
yGPU
->
CheckData
(
answer
,
unitNum
,
1e-4
F
)
&&
yUserGPU
.
CheckData
(
answer
,
unitNum
,
1e-4
F
);
...
...
@@ -208,6 +206,77 @@ bool TestSoftmax2()
}
/* other cases */
bool
TestSoftmax3Gpu
()
{
#ifdef USE_CUDA
/* GPU test */
bool
gpuTest
=
true
;
int
order
=
2
;
int
*
dimSize
=
new
int
[
order
];
dimSize
[
0
]
=
32
;
dimSize
[
1
]
=
1000
;
int
unitNum
=
1
;
for
(
int
i
=
0
;
i
<
order
;
i
++
)
unitNum
*=
dimSize
[
i
];
/* create tensors */
XTensor
*
xGPU
=
NewTensor
(
order
,
dimSize
,
X_FLOAT
,
1.0
F
,
0
);
XTensor
*
yGPU
=
NewTensor
(
order
,
dimSize
,
X_FLOAT
,
1.0
F
,
0
);
/* initialize variables */
FILE
*
dataFile
;
char
dataString
[
32
];
const
int
dataSize
=
32
*
1000
;
DTYPE
xData
[
dataSize
];
if
((
dataFile
=
fopen
(
"D:
\\
Work
\\
TensorFlowLearn
\\
testdata.in"
,
"r"
))
==
NULL
)
{
printf
(
"file open fail"
);
exit
(
1
);
}
for
(
int
i
=
0
;
i
<
dataSize
;
++
i
)
{
if
(
fscanf
(
dataFile
,
"%s"
,
dataString
)
!=
EOF
)
{
//printf("%s", dataString);
xData
[
i
]
=
atof
(
dataString
);
//srcTensorData[i] = i;
}
else
{
printf
(
"read wrong"
);
break
;
}
}
xGPU
->
SetData
(
xData
,
unitNum
);
yGPU
->
SetZeroAll
();
/* call Softmax function */
_Softmax
(
xGPU
,
yGPU
,
0
);
/* check result */
//gpuTest = yGPU->CheckData(yAnswer, unitNum, 1e-4F)
DTYPE
check
=
0
;
DTYPE
TensorData
[
dataSize
];
cudaMemcpy
(
TensorData
,
yGPU
->
data
,
sizeof
(
DTYPE
)
*
unitNum
,
cudaMemcpyDeviceToHost
);
//float check = 0;
for
(
int
i
=
0
;
i
<
32
;
++
i
)
{
check
+=
TensorData
[
i
];
printf
(
"%f "
,
TensorData
[
i
]);
}
printf
(
"
\n
%f
\n
"
,
check
);
/* destroy variables */
delete
xGPU
;
delete
yGPU
;
delete
[]
dimSize
;
return
gpuTest
;
#endif
}
/*
TODO!!
*/
...
...
@@ -239,6 +308,13 @@ bool TestSoftmax()
XPRINT
(
0
,
stdout
,
">> case 2 passed!
\n
"
);
/* other cases test */
caseFlag
=
TestSoftmax3Gpu
();
if
(
!
caseFlag
)
{
returnFlag
=
false
;
XPRINT
(
0
,
stdout
,
">> case 3 failed!
\n
"
);
}
else
XPRINT
(
0
,
stdout
,
">> case 3 passed!
\n
"
);
/*
TODO!!
*/
...
...
source/tensor/test/Test.cpp
查看文件 @
ece0dc78
...
...
@@ -29,7 +29,7 @@ bool Test()
bool
wrong
=
false
;
XPRINT
(
0
,
stdout
,
"Testing the XTensor utilites ...
\n\n
"
);
wrong
=
!
TestAbsolute
()
||
wrong
;
/*
wrong = !TestAbsolute() || wrong;
wrong = !TestConcatenate() || wrong;
wrong = !TestConcatenateSolely() || wrong;
wrong = !TestConvertDataType() || wrong;
...
...
@@ -70,7 +70,7 @@ bool Test()
wrong = !TestLogSoftmax() || wrong;
wrong = !TestLoss() || wrong;
wrong = !TestRectify() || wrong;
wrong
=
!
TestSigmoid
()
||
wrong
;
wrong = !TestSigmoid() || wrong;
*/
wrong
=
!
TestSoftmax
()
||
wrong
;
/* other test */
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论