Commit 95b74dbb by linye

int and int8 sum supported

parent f151f061
@@ -17,7 +17,7 @@
 /*
  * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
- * $Update by: Lin Ye (linye2015@outlook.com) 2019-07-02 float16 added
+ * $Update by: Lin Ye (linye2015@outlook.com) 2019-07-02 float16 and int added
  */
 #include "../../XDevice.h"
@@ -48,6 +48,16 @@ void KernelADD(T * a, T * b, T * c, int size, T beta)
 }
 
+__global__
+void KernelADDInt(int * a, int * b, int * c, int size, DTYPE beta)
+{
+    int i = blockDim.x * blockIdx.x + threadIdx.x;
+
+    if (i < size)
+        c[i] = a[i] + b[i] * (int)beta;
+}
+
 /*
 tensor summation c = a + b * \beta (cuda version)
 >> a - a tensor
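The new KernelADDInt mirrors the templated KernelADD that already exists in this file, fixed to int elements and with beta cast down from DTYPE inside the kernel. As a reference point, here is a standalone sketch (not part of the commit; the kernel is re-declared locally and the host code is purely illustrative) that runs the same c = a + b * beta element-wise pattern on int data through the template:

/* Standalone sketch: the templated element-wise add kernel instantiated for
   int behaves the same as the dedicated KernelADDInt added above. */
#include <cstdio>
#include <cuda_runtime.h>

template<class T>
__global__
void KernelADD(T * a, T * b, T * c, int size, T beta)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        c[i] = a[i] + b[i] * beta;
}

int main()
{
    const int size = 8;
    int ha[size], hb[size], hc[size];
    for (int i = 0; i < size; i++) { ha[i] = i; hb[i] = 2 * i; }

    int *da, *db, *dc;
    cudaMalloc(&da, size * sizeof(int));
    cudaMalloc(&db, size * sizeof(int));
    cudaMalloc(&dc, size * sizeof(int));
    cudaMemcpy(da, ha, size * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, size * sizeof(int), cudaMemcpyHostToDevice);

    /* c = a + b * beta, with beta already cast to the element type on the host */
    KernelADD<int><<<1, 32>>>(da, db, dc, size, (int)2);
    cudaMemcpy(hc, dc, size * sizeof(int), cudaMemcpyDeviceToHost);

    for (int i = 0; i < size; i++)
        printf("%d + %d * 2 = %d\n", ha[i], hb[i], hc[i]);

    cudaFree(da); cudaFree(db); cudaFree(dc);
    return 0;
}

Instantiating the template with T = int produces the same arithmetic as KernelADDInt; the only difference is whether the cast of beta happens on the host or inside the kernel.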
@@ -65,7 +75,9 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
     CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
                   "The tensors must be on the same device!");
     CheckNTErrors((a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE && c->dataType == DEFAULT_DTYPE) ||
-                  (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16 && c->dataType == X_FLOAT16),
+                  (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16 && c->dataType == X_FLOAT16) ||
+                  (a->dataType == X_INT && b->dataType == X_INT && c->dataType == X_INT) ||
+                  (a->dataType == X_INT8 && b->dataType == X_INT8 && c->dataType == X_INT8),
                   "The sum function does not support this datatype.");
 
     int devIDBackup = XDevice::GetGPUDevice();
@@ -119,7 +131,32 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
         unsigned short temp = FloatToFloat16(beta);
         half beta1 = *((half *)&temp);
         KernelADD << <blocks, threads >> >((__half *)a->data, (__half *)b->data, (__half *)c->data, a->unitNum, beta1);
     }
+    else if (a->dataType == X_INT &&
+             b->dataType == X_INT &&
+             c->dataType == X_INT)
+    {
+        int gridSize[3], blockSize[3];
+
+        GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
+
+        dim3 blocks(gridSize[0]);
+        dim3 threads(blockSize[0]);
+
+        int beta1 = (int)beta;
+
+        KernelADD << <blocks, threads >> >((int *)a->data, (int *)b->data, (int *)c->data, a->unitNum, beta1);
+    }
+    else if (a->dataType == X_INT8 &&
+             b->dataType == X_INT8 &&
+             c->dataType == X_INT8)
+    {
+        int gridSize[3], blockSize[3];
+
+        GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
+
+        dim3 blocks(gridSize[0]);
+        dim3 threads(blockSize[0]);
+
+        __int8 beta1 = (__int8)beta;
+
+        KernelADD << <blocks, threads >> >((__int8 *)a->data, (__int8 *)b->data, (__int8 *)c->data, a->unitNum, beta1);
+    }
     else {
         // TODO!!
...
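Both integer branches follow the launch pattern of the float16 branch above: ask GDevs.GetCudaThread for a grid and block size covering a->unitNum elements, cast beta to the element type, and launch a one-dimensional kernel over the flat data. The standalone sketch below (not repository code) spells out the ceiling-division launch configuration such a helper presumably computes, using int8_t in place of the MSVC-specific __int8 that appears in the diff:

/* Standalone sketch of the 8-bit path: a hand-rolled launch configuration
   assumed to be equivalent to what GDevs.GetCudaThread provides. */
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

__global__
void KernelADDInt8(int8_t * a, int8_t * b, int8_t * c, int size, int8_t beta)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        c[i] = a[i] + b[i] * beta;
}

int main()
{
    const int size = 1000;
    int8_t *ha = new int8_t[size], *hb = new int8_t[size], *hc = new int8_t[size];
    for (int i = 0; i < size; i++) { ha[i] = (int8_t)(i % 50); hb[i] = (int8_t)(i % 30); }

    int8_t *da, *db, *dc;
    cudaMalloc(&da, size); cudaMalloc(&db, size); cudaMalloc(&dc, size);
    cudaMemcpy(da, ha, size, cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, size, cudaMemcpyHostToDevice);

    /* ceiling division: enough one-dimensional blocks to cover every element */
    const int threadsPerBlock = 256;
    const int blocksPerGrid = (size + threadsPerBlock - 1) / threadsPerBlock;
    dim3 blocks(blocksPerGrid);
    dim3 threads(threadsPerBlock);

    int8_t beta1 = (int8_t)2;   /* beta cast to the element type, as in _CudaSum */
    KernelADDInt8<<<blocks, threads>>>(da, db, dc, size, beta1);
    cudaMemcpy(hc, dc, size, cudaMemcpyDeviceToHost);

    printf("c[10] = %d (expected %d)\n", (int)hc[10], (int)(ha[10] + hb[10] * 2));

    cudaFree(da); cudaFree(db); cudaFree(dc);
    delete[] ha; delete[] hb; delete[] hc;
    return 0;
}

One consequence of this design is that casting beta from DTYPE to an integer type truncates toward zero, so a non-integral beta is silently rounded in the int and int8 paths.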
@@ -35,10 +35,10 @@ bool Test()
     //wrong = !TestConcatenate() || wrong;
     //wrong = !TestConcatenateSolely() || wrong;
     //wrong = !TestCos() || wrong;
-    wrong = !TestConvertDataType() || wrong;
+    //wrong = !TestConvertDataType() || wrong;
     //wrong = !TestCopyIndexed() || wrong;
     //wrong = !TestCopyValues() || wrong;
-    wrong = !TestDiv() || wrong;
+    //wrong = !TestDiv() || wrong;
     //wrong = !TestDivDim() || wrong;
     //wrong = !TestExp() || wrong;
     //wrong = !TestGather() || wrong;
@@ -70,7 +70,7 @@ bool Test()
     //wrong = !TestSplit() || wrong;
     //wrong = !TestSpread() || wrong;
     //wrong = !TestSub() || wrong;
-    //wrong = !TestSum() || wrong;
+    wrong = !TestSum() || wrong;
     //wrong = !TestSumByColumnTV() || wrong;
     //wrong = !TestSumByColumnVT() || wrong;
     //wrong = !TestSumDim() || wrong;
...
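The test driver now enables TestSum (and disables TestConvertDataType and TestDiv), so the new integer summation paths are exercised. A minimal standalone check of the same shape is sketched below; CheckIntSum and its fixed inputs are hypothetical names and data used only to illustrate an expected-answer comparison:

/* Standalone sketch of an integer summation check: compare the GPU result
   of c = a + b * beta against a hard-coded answer array. */
#include <cstdio>
#include <cuda_runtime.h>

template<class T>
__global__
void KernelADD(T * a, T * b, T * c, int size, T beta)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        c[i] = a[i] + b[i] * beta;
}

/* returns true when the GPU sum matches the expected answer exactly */
bool CheckIntSum()
{
    const int size = 6;
    int a[size]      = { 0,  1, 2,  3,  4,  5 };
    int b[size]      = { 1, -1, 2, -2,  3, -3 };
    int beta         = 2;
    int answer[size] = { 2, -1, 6, -1, 10, -1 };
    int c[size]      = { 0 };

    int *da, *db, *dc;
    cudaMalloc(&da, size * sizeof(int));
    cudaMalloc(&db, size * sizeof(int));
    cudaMalloc(&dc, size * sizeof(int));
    cudaMemcpy(da, a, size * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(db, b, size * sizeof(int), cudaMemcpyHostToDevice);

    KernelADD<int><<<1, 32>>>(da, db, dc, size, beta);
    cudaMemcpy(c, dc, size * sizeof(int), cudaMemcpyDeviceToHost);
    cudaFree(da); cudaFree(db); cudaFree(dc);

    for (int i = 0; i < size; i++)
        if (c[i] != answer[i])
            return false;
    return true;
}

int main()
{
    bool wrong = !CheckIntSum();
    printf("int sum test %s\n", wrong ? "failed" : "passed");
    return wrong ? 1 : 0;
}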