Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
T
Tensor.LowPrecision
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
魏冰浩
Tensor.LowPrecision
Commits
6a3d713a
Commit
6a3d713a
authored
Jul 08, 2019
by
linye
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
no message
parent
30217de4
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
178 行增加
和
54 行删除
+178
-54
source/tensor/core/arithmetic/MultiplyDim.cu
+29
-0
source/tensor/core/arithmetic/Negate.cu
+22
-21
source/tensor/function/LogSoftmax.cu
+126
-32
source/tensor/test/Test.cpp
+1
-1
没有找到文件。
source/tensor/core/arithmetic/MultiplyDim.cu
查看文件 @
6a3d713a
...
...
@@ -169,6 +169,35 @@ void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n,
ShowNTErrors("Something is wrong!");
}
}
if (a->dataType == X_FLOAT16) {
unsigned short temp = FloatToFloat16(alpha);
half alpha1 = *((half *)&temp);
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == 0.0F)
KernelMultiplyWithCol<__half, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha1);
else
KernelMultiplyWithCol<__half, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha1);
}
else if (stride == 1) {
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == 0.0F)
KernelMultiplyWithRow<__half, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data,
blockNum, blockSize, alpha1);
else
KernelMultiplyWithRow<__half, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data,
blockNum, blockSize, alpha1);
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
...
...
source/tensor/core/arithmetic/Negate.cu
查看文件 @
6a3d713a
...
...
@@ -33,8 +33,9 @@ set each entry to its negtive value (CUDA Kernel)
>> b - pointer to the output data array
>> size - size of the data array
*/
template <class T>
__global__
void KernelNegate(
DTYPE * a, DTYPE
* b, int size)
void KernelNegate(
T * a, T
* b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
...
@@ -42,26 +43,26 @@ void KernelNegate(DTYPE * a, DTYPE * b, int size)
b[i] = -a[i];
}
/*
set each entry to its negtive value (CUDA Kernel)
This is for float16 computation
>> a - pointer to the input data array
>> b - pointer to the output data array
>> size - size of the data array
*/
__global__
void KernelNegate(__half * a, __half * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if (i < size)
b[i] = __hsub(__float2half(0), a[i]);
#else
if (i < size)
b[i] = __float2half(-__half2float(a[i]));
#endif
}
/
//
*
//
set each entry to its negtive value (CUDA Kernel)
//
This is for float16 computation
//
>> a - pointer to the input data array
//
>> b - pointer to the output data array
//
>> size - size of the data array
//
*/
//
__global__
//
void KernelNegate(__half * a, __half * b, int size)
//
{
//
int i = blockDim.x * blockIdx.x + threadIdx.x;
//
//
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
//
if (i < size)
//
b[i] = __hsub(__float2half(0), a[i]);
//
#else
//
if (i < size)
//
b[i] = __float2half(-__half2float(a[i]));
//
#endif
//
}
/*
set each entry to its negtive value
...
...
source/tensor/function/LogSoftmax.cu
查看文件 @
6a3d713a
差异被折叠。
点击展开。
source/tensor/test/Test.cpp
查看文件 @
6a3d713a
...
...
@@ -84,7 +84,7 @@ bool Test()
//wrong = !TestDropout() || wrong;
//wrong = !TestHardTanH() || wrong;
//wrong = !TestIdentity() || wrong;
//
wrong = !TestLogSoftmax() || wrong;
wrong
=
!
TestLogSoftmax
()
||
wrong
;
//wrong = !TestLoss() || wrong;
//wrong = !TestRectify() || wrong;
//wrong = !TestSigmoid() || wrong;
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论