Commit 95b74dbb by linye

int and int8 sum supported

parent f151f061
......@@ -17,7 +17,7 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
* $Update by: Lin Ye (linye2015@outlook.com) 2019-07-02 float16 added
* $Update by: Lin Ye (linye2015@outlook.com) 2019-07-02 float16 int added
*/
#include "../../XDevice.h"
......@@ -48,6 +48,16 @@ void KernelADD(T * a, T * b, T * c, int size, T beta)
}
__global__
void KernelADDInt(int * a, int * b, int * c, int size, DTYPE beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] + b[i] * (int)beta;
}
/*
tensor summation c = a + b * \beta (cuda version)
>> a - a tensor
......@@ -65,7 +75,9 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
"The tensors must be on the same!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE && c->dataType == DEFAULT_DTYPE) ||
(a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16 && c->dataType == X_FLOAT16),
(a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16 && c->dataType == X_FLOAT16) ||
(a->dataType == X_INT && b->dataType == X_INT && c->dataType == X_INT) ||
(a->dataType == X_INT8 && b->dataType == X_INT8 && c->dataType == X_INT8),
"The sum function does not support this datatype.");
int devIDBackup = XDevice::GetGPUDevice();
......@@ -119,7 +131,32 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
unsigned short temp = FloatToFloat16(beta);
half beta1 = *((half *)&temp);
KernelADD << <blocks, threads >> >((__half *)a->data, (__half *)b->data, (__half *)c->data, a->unitNum, beta1);
}
else if (a->dataType == X_INT &&
b->dataType == X_INT &&
c->dataType == X_INT)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int beta1 = (int)beta;
KernelADD << <blocks, threads >> >((int *)a->data, (int *)b->data, (int *)c->data, a->unitNum, beta1);
}
else if (a->dataType == X_INT8 &&
b->dataType == X_INT8 &&
c->dataType == X_INT8)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
__int8 beta1 = (__int8)beta;
KernelADD << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, (__int8 *)c->data, a->unitNum, beta1);
}
else {
// TODO!!
......
......@@ -35,10 +35,10 @@ bool Test()
//wrong = !TestConcatenate() || wrong;
//wrong = !TestConcatenateSolely() || wrong;
//wrong = !TestCos() || wrong;
wrong = !TestConvertDataType() || wrong;
//wrong = !TestConvertDataType() || wrong;
//wrong = !TestCopyIndexed() || wrong;
//wrong = !TestCopyValues() || wrong;
wrong = !TestDiv() || wrong;
//wrong = !TestDiv() || wrong;
//wrong = !TestDivDim() || wrong;
//wrong = !TestExp() || wrong;
//wrong = !TestGather() || wrong;
......@@ -70,7 +70,7 @@ bool Test()
//wrong = !TestSplit() || wrong;
//wrong = !TestSpread() || wrong;
//wrong = !TestSub() || wrong;
//wrong = !TestSum() || wrong;
wrong = !TestSum() || wrong;
//wrong = !TestSumByColumnTV() || wrong;
//wrong = !TestSumByColumnVT() || wrong;
//wrong = !TestSumDim() || wrong;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论