Commit 02e8a2c1 by linye

update float16 datatype of Rectify

parent 5bc8e96b
......@@ -34,15 +34,16 @@ rectify : y = x if x >= 0
>> output - output tensor
>> size - size of input/output
*/
template<class T>
__global__
void KernelRectify(DTYPE * x, DTYPE * y, int size)
void KernelRectify(T * x, T * y, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE p = x[i];
if(p < 0)
p = 0;
T p = x[i];
if(p < (T)0.0)
p = (T)0.0;
y[i] = p;
}
}
......@@ -61,8 +62,18 @@ void _CudaRectify(const XTensor * x, XTensor * y)
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if (x->dataType == DEFAULT_DTYPE) {
KernelRectify<<<dim3(gridSize[0]), dim3(blockSize[0])>>>
((DTYPE*)x->data, (DTYPE*)y->data, x->unitNum);
}
else if (x->dataType == X_FLOAT16) {
KernelRectify<<<dim3(gridSize[0]), dim3(blockSize[0]) >> >
((__half*)x->data, (__half*)y->data, x->unitNum);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......@@ -78,17 +89,18 @@ dy/dx = 1 if x >= 0
>> x - input of the function
>> size - size of output/input
*/
template<class T>
__global__
void KernelRectifyBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * x, int size)
void KernelRectifyBackward(T * dedy, T * dedx, T * x, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE s = x[i];
if(s >= 0)
T s = x[i];
if(s >= (T)0.0)
dedx[i] = dedy[i];
else
dedx[i] = 0;
dedx[i] = (T)0.0;
}
}
......@@ -119,11 +131,24 @@ void _CudaRectifyBackward(XTensor * y, XTensor * x,
ProtectCudaDev(x->devID, devIDBackup);
/* dE/ds = dE/dy * dy/ds */
if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
KernelRectifyBackward<<<dim3(gridSize[0]),dim3(blockSize[0])>>>
((DTYPE*)dedy->data,
(DTYPE*)dedx->data,
(DTYPE*)x->data,
x->unitNum);
}
else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
KernelRectifyBackward<<<dim3(gridSize[0]), dim3(blockSize[0]) >> >
((__half*)dedy->data,
(__half*)dedx->data,
(__half*)x->data,
x->unitNum);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......
......@@ -20,6 +20,7 @@
*/
#include "TRectify.h"
#include "../core/getandset/ConvertDataType.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -199,6 +200,164 @@ bool TestRectify2()
#endif // USE_CUDA
}
/*
case 3: float16 test rectify function
In this case, y = max(0, x)
*/
bool TestRectify3()
{
/* a tensor of size (2, 3) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 2;
dimSize[1] = 3;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE xData[2][3] = { {0.0F, -1.0F, 2.0F},
{3.0F, -4.0F, -5.0F} };
DTYPE answer[2][3] = { {0.0F, 0.0F, 2.0F},
{3.0F, 0.0F, 0.0F} };
/* CPU test */
bool cpuTest = true;
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * xGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * yGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor yUserGPU;
/* create float16 tensor */
XTensor xHalfGPU;
XTensor yHalfGPU;
XTensor yUserHalfGPU;
/* Initialize variables */
xGPU->SetData(xData, unitNum);
yGPU->SetZeroAll();
/* convert data type from float to float16 */
xHalfGPU = ConvertDataType(*xGPU, X_FLOAT16);
yHalfGPU = ConvertDataType(*yGPU, X_FLOAT16);
/* call Rectify function */
_Rectify(&xHalfGPU, &yHalfGPU);
yUserHalfGPU = Rectify(xHalfGPU);
/* convert data type from float16 to float */
_ConvertDataType(&yHalfGPU, yGPU);
yUserGPU = ConvertDataType(yUserHalfGPU, X_FLOAT);
/* check results */
gpuTest = yGPU->CheckData(answer, unitNum, 1e-4F) &&
yUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete xGPU;
delete yGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 4: float16 backward computation
dE/dx = dE/dy * dy/dx
rectified: y = max(0, x)
In this case, lossName=CROSSENTROPY.
*/
bool TestRectify4()
{
/* a tensor of size (2, 3) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 2;
dimSize[1] = 3;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE xData[2][3] = { {-1.0F, 1.0F, 2.0F},
{-2.0F, 4.0F, 5.0F} };
DTYPE yData[2][3] = { {0.0F, 1.0F, 2.0F},
{0.0F, 4.0F, 5.0F} };
DTYPE dedyData[2][3] = { {-0.5F, -0.5F, -0.25F},
{-0.25F, -0.125F, -0.1F} };
DTYPE dedxAnswer[2][3] = { {0.0F, -0.5F, -0.25F},
{0.0F, -0.125F, -0.1F} };
/* CPU test */
bool cpuTest = true;
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * xGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * yGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * dedyGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * dedxGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* create float16 tensor */
XTensor xHalfGPU;
XTensor yHalfGPU;
XTensor dedyHalfGPU;
XTensor dedxHalfGPU;
/* initialize variables */
xGPU->SetData(xData, unitNum);
yGPU->SetData(yData, unitNum);
dedyGPU->SetData(dedyData, unitNum);
dedxGPU->SetZeroAll();
/* convert data type from float to float16 */
xHalfGPU = ConvertDataType(*xGPU, X_FLOAT16);
yHalfGPU = ConvertDataType(*yGPU, X_FLOAT16);
dedyHalfGPU = ConvertDataType(*dedyGPU, X_FLOAT16);
dedxHalfGPU = ConvertDataType(*dedxGPU, X_FLOAT16);
/* call Rectify function */
_Rectify(&xHalfGPU, &yHalfGPU);
/* call rectifybackward function */
_RectifyBackward(&yHalfGPU, &xHalfGPU, &dedyHalfGPU, &dedxHalfGPU);
/* convert data type from float16 to float */
_ConvertDataType(&dedxHalfGPU, dedxGPU);
/* check results */
gpuTest = dedxGPU->CheckData(dedxAnswer, unitNum, 1e-4F);
/* destroy variables */
delete xGPU;
delete yGPU;
delete dedyGPU;
delete dedxGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -230,6 +389,26 @@ bool TestRectify()
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestRectify3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* case 4 test */
caseFlag = TestRectify4();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 4 failed!\n");
}
else
XPRINT(0, stdout, ">> case 4 passed!\n");
/* other cases test */
/*
TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论