Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
e155205c
Commit
e155205c
authored
Feb 18, 2020
by
huchi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
replace some functions for GPU
parent
9ebfb0be
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
47 行增加
和
54 行删除
+47
-54
source/tensor/core/arithmetic/Sum.cu
+20
-0
source/tensor/core/math/Clip.cu
+7
-18
source/tensor/core/math/Clip.cuh
+2
-2
source/tensor/core/math/ScaleAndShift.cu
+18
-34
没有找到文件。
source/tensor/core/arithmetic/Sum.cu
查看文件 @
e155205c
...
@@ -45,6 +45,15 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
...
@@ -45,6 +45,15 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
c[i] = a[i] + b[i] * beta;
c[i] = a[i] + b[i] * beta;
}
}
__global__
void KernelADD(int * a, int * b, int * c, int size, int beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] + b[i] * beta;
}
/*
/*
tensor summation c = a + b * \beta (cuda version)
tensor summation c = a + b * \beta (cuda version)
>> a - a tensor
>> a - a tensor
...
@@ -100,6 +109,17 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
...
@@ -100,6 +109,17 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
KernelADD << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, a->unitNum, beta);
KernelADD << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, a->unitNum, beta);
}
}
}
}
else if (a->dataType == X_INT &&
b->dataType == X_INT &&
c->dataType == X_INT)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelADD << <blocks, threads >> >((int*)a->data, (int*)b->data, (int*)c->data, a->unitNum, (int)beta);
}
else {
else {
// TODO!!
// TODO!!
ShowNTErrors("TODO!");
ShowNTErrors("TODO!");
...
...
source/tensor/core/math/Clip.cu
查看文件 @
e155205c
...
@@ -36,8 +36,9 @@ set each entry to its clip value (CUDA Kernel)
...
@@ -36,8 +36,9 @@ set each entry to its clip value (CUDA Kernel)
>> upper - the upper border
>> upper - the upper border
>> size - size of the data array
>> size - size of the data array
*/
*/
template <class T>
__global__
__global__
void KernelClip(
DTYPE * a, DTYPE * b, DTYPE lower, DTYPE
upper, int size)
void KernelClip(
T * a, T * b, T lower, T
upper, int size)
{
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
@@ -52,21 +53,6 @@ void KernelClip(DTYPE * a, DTYPE * b, DTYPE lower, DTYPE upper, int size)
...
@@ -52,21 +53,6 @@ void KernelClip(DTYPE * a, DTYPE * b, DTYPE lower, DTYPE upper, int size)
}
}
/*
/*
set each entry to its clip value with float16 data type value (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> lower - the lower border
>> upper - the upper border
>> size - size of the data array
*/
__global__
void KernelClip(__half * a, __half * b, DTYPE lower, DTYPE upper, int size)
{
return;
}
/*
set each entry to its clip value
set each entry to its clip value
>> a - input tensor we are processing
>> a - input tensor we are processing
>> b - output tensor we are processing
>> b - output tensor we are processing
...
@@ -92,8 +78,11 @@ void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
...
@@ -92,8 +78,11 @@ void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
if (a->dataType == DEFAULT_DTYPE) {
if (a->dataType == DEFAULT_DTYPE) {
KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum);
KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum);
}
}
else if (a->dataType == X_FLOAT16) {
else if (a->dataType == X_INT) {
KernelClip << <blocks, threads >> >((__half*)a->data, (__half*)b->data, lower, upper, a->unitNum);
int lower1 = (int)lower;
int upper1 = (int)upper;
KernelClip << <blocks, threads >> >((int *)a->data, (int *)b->data, lower1, upper1, a->unitNum);
}
}
else {
else {
ShowNTErrors("TODO!");
ShowNTErrors("TODO!");
...
...
source/tensor/core/math/Clip.cuh
查看文件 @
e155205c
...
@@ -29,8 +29,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -29,8 +29,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
#ifdef USE_CUDA
/* set each entry to its clip value (CUDA Kernel) */
/* set each entry to its clip value (CUDA Kernel) */
__global__
template <class T>
__global__
void KernelClip(
DTYPE * a, DTYPE * b, DTYPE lower, DTYPE
upper, int size);
void KernelClip(
T * a, T * b, T lower, T
upper, int size);
/* set each entry to its clip value (CUDA Kernel) with float16 data type*/
/* set each entry to its clip value (CUDA Kernel) with float16 data type*/
__global__
__global__
...
...
source/tensor/core/math/ScaleAndShift.cu
查看文件 @
e155205c
...
@@ -34,9 +34,9 @@ scale and shift all tensor entires b = a * scale + shift (CUDA Kernel)
...
@@ -34,9 +34,9 @@ scale and shift all tensor entires b = a * scale + shift (CUDA Kernel)
>> scale - how much we want to scale it
>> scale - how much we want to scale it
>> shift - how much we want to shift it
>> shift - how much we want to shift it
*/
*/
template<bool isUnitScale, bool isZeroShift>
template<
class T,
bool isUnitScale, bool isZeroShift>
__global__
__global__
void KernelScaleAndShift(
DTYPE * a, DTYPE * b, int size, DTYPE scale, DTYPE
shift)
void KernelScaleAndShift(
T * a, T * b, int size, T scale, T
shift)
{
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.x * blockIdx.x + threadIdx.x;
...
@@ -56,28 +56,6 @@ void KernelScaleAndShift(DTYPE * a, DTYPE * b, int size, DTYPE scale, DTYPE shif
...
@@ -56,28 +56,6 @@ void KernelScaleAndShift(DTYPE * a, DTYPE * b, int size, DTYPE scale, DTYPE shif
}
}
}
}
/*
scale and shift all tensor entires p = p * scale + shift (CUDA Kernel)
This is for float16 computation
>> a - the input data array
>> b - the output data array
>> size - the size of d
>> scale - how much we want to scale it
>> shift - how much we want to shift it
*/
__global__
void KernelScaleAndShift(__half * a, __half * b, int size, __half scale, __half shift)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if(i < size)
b[i] = __hadd(__hmul(a[i], scale), shift);
#else
if (i < size)
b[i] = __float2half(__half2float(a[i]) * __half2float(scale) + __half2float(shift));
#endif
}
/*
/*
scale and shift all tensor entires
scale and shift all tensor entires
...
@@ -108,20 +86,26 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift
...
@@ -108,20 +86,26 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift
if(a->dataType == DEFAULT_DTYPE){
if(a->dataType == DEFAULT_DTYPE){
if(scale == 1.0F && shift == 0)
if(scale == 1.0F && shift == 0)
KernelScaleAndShift<true, true> <<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
true, true> <<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else if (scale == 1.0F && shift != 0)
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<true, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
true, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else if(scale != 1.0F && shift == 0)
else if(scale != 1.0F && shift == 0)
KernelScaleAndShift<false, true> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
false, true> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
else
else
KernelScaleAndShift<false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
KernelScaleAndShift<
DTYPE,
false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
}
}
else if(a->dataType == X_FLOAT16){
else if (a->dataType == X_INT) {
unsigned short scale2 = FloatToFloat16(scale);
int scale2 = int(scale);
unsigned short shift2 = FloatToFloat16(shift);
int shift2 = int(shift);
__half * scaleft16p = (__half*)&scale2;
__half * shiftft16p = (__half*)&shift2;
if (scale == 1.0F && shift == 0)
KernelScaleAndShift<<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
KernelScaleAndShift<int, true, true><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<int, true, false><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else if (scale != 1.0F && shift == 0)
KernelScaleAndShift<int, false, true><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
else
KernelScaleAndShift<int, false, false><<<blocks, threads>>>((int *)a->data, (int *)b->data, a->unitNum, scale2, shift2);
}
}
else{
else{
ShowNTErrors("TODO!");
ShowNTErrors("TODO!");
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论