Commit 1f1413ca by liyinqiao

Support the fp16 data type for more operations and fix minor errors. (Don't use this! It's an incomplete version.)
parent 1bf5cc90
......@@ -97,7 +97,7 @@ void _CudaCopyBlocksSelected(void * source, int unitSize, int blockSize, int * s
GDevs.GetCudaThread2D(devID, bSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if (unitSize == 4)
KernelCopyBlocksSelected <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)source, bSize, sourceBlocksTMP, blockNum, (DTYPE*)target, targetBlocksTMP);
((float*)source, bSize, sourceBlocksTMP, blockNum, (float*)target, targetBlocksTMP);
else if (unitSize == 2)
KernelCopyBlocksSelected <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((half*)source, bSize, sourceBlocksTMP, blockNum, (half*)target, targetBlocksTMP);
......
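
As a reference for the pattern above: the element width (unitSize) selects which instantiation of the templated kernel gets launched, 4-byte units going to the float version and 2-byte units to the half version. A minimal self-contained sketch of that dispatch, with illustrative names rather than the library's:

#include <cuda_fp16.h>

/* illustrative templated block copy: one thread per element */
template <class T>
__global__ void KernelCopyBlocksSketch(T * src, int blockSize, int blockNum, T * tgt)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;   /* offset inside a block */
    int j = blockDim.y * blockIdx.y + threadIdx.y;   /* block index */
    if (i < blockSize && j < blockNum)
        tgt[j * blockSize + i] = src[j * blockSize + i];
}

/* host-side dispatch on the element width, mirroring the unitSize checks above */
void CopyBlocksDispatchSketch(void * src, int unitSize, int blockSize, int blockNum,
                              void * tgt, dim3 grids, dim3 blocks)
{
    if (unitSize == 4)        /* 4-byte elements -> float */
        KernelCopyBlocksSketch<float><<<grids, blocks>>>((float*)src, blockSize, blockNum, (float*)tgt);
    else if (unitSize == 2)   /* 2-byte elements -> half */
        KernelCopyBlocksSketch<half><<<grids, blocks>>>((half*)src, blockSize, blockNum, (half*)tgt);
}
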
......@@ -128,9 +128,9 @@ void _CudaCopyIndexed(const XTensor * s, XTensor * t, int dim,
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
if (s->dataType == X_FLOAT && t->dataType == X_FLOAT) {
DTYPE * sData = (float*)s->data;
DTYPE * tData = (float*)t->data;
if (s->dataType == DEFAULT_DTYPE && t->dataType == DEFAULT_DTYPE) {
DTYPE * sData = (DTYPE*)s->data;
DTYPE * tData = (DTYPE*)t->data;
int * sIndex = (int*)srcIndex->data;
int * tIndex = (int*)tgtIndex->data;
......
......@@ -401,11 +401,11 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcI
else
sIndex = (int *)srcIndex->data;
if (source->dataType == X_FLOAT && collection->dataType == X_FLOAT)
if (source->dataType == DEFAULT_DTYPE && collection->dataType == DEFAULT_DTYPE)
{
DTYPE * sData = (float*)source->data;
DTYPE * cData = (float*)collection->data;
KernelSpreadForGather<float> << <blocks, threads >> >(sData, cData, sIndex, indexSize, stride);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
KernelSpreadForGather<DTYPE> << <blocks, threads >> >(sData, cData, sIndex, indexSize, stride);
}
else if (source->dataType == X_FLOAT16 && collection->dataType == X_FLOAT16)
{
......
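
The spread/gather hunk follows the same idea, but keys the dispatch on the tensors' dataType fields and picks the instantiation with an explicit template argument (KernelSpreadForGather<DTYPE> versus a half instantiation). For illustration only, a hypothetical gather-style kernel with an assumed indexing scheme (not taken from the library):

#include <cuda_fp16.h>

/* hypothetical gather: entry i of the collection is filled from source row sIndex[i],
   stride contiguous elements per row */
template <class T>
__global__ void KernelGatherSketch(T * sData, T * cData, int * sIndex, int indexSize, int stride)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;        /* which index entry */
    int offset = blockDim.y * blockIdx.y + threadIdx.y;   /* position inside the row */
    if (i < indexSize && offset < stride)
        cData[i * stride + offset] = sData[sIndex[i] * stride + offset];
}

/* dataType-based dispatch with explicit template arguments, as in the diff above */
void GatherDispatchSketch(void * s, void * c, int * sIndex, int indexSize, int stride,
                          bool isFloat16, dim3 blocks, dim3 threads)
{
    if (!isFloat16)
        KernelGatherSketch<float><<<blocks, threads>>>((float*)s, (float*)c, sIndex, indexSize, stride);
    else
        KernelGatherSketch<__half><<<blocks, threads>>>((__half*)s, (__half*)c, sIndex, indexSize, stride);
}
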
......@@ -96,7 +96,7 @@ crossing of the i-th columne and the j-th row.
>> power - power of the item in the array
>> isExp - specify if we perform exp() on the input
*/
__global__
void KernelReduceSum(DTYPE * input, DTYPE * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum,
......@@ -152,7 +152,7 @@ void KernelReduceSum(DTYPE * input, DTYPE * output,
output[(k * reducedStrideNum + blockIdx.x) * stride + iOffset] = iData[threadIdx.y * blockDim.x];
}
/*
reduce a tensor to another that keeps the sum along a dimension - slow version
This is for float16 reduction.
Given a block of data, we go over each dimension i in the stride and we have
......@@ -171,7 +171,7 @@ crossing of the i-th columne and the j-th row.
>> power - power of the item in the array
>> isExp - specify if we perform exp() on the input
*/
__global__
void KernelReduceSum(__half * input, __half * output,
int stride, int strideNum, int reducedStrideNum,
int blockSize, int blockNum,
......@@ -726,7 +726,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data : NULL;
if (stride == 1 && blockNum >= 10) {
if (stride == 1 && blockNum >= 10 && input->dataType == DEFAULT_DTYPE) {
dim3 grids;
dim3 blocks;
continuousStorageThreadAllocation(grids, blocks, (long long)blockNum, strideNum);
......@@ -742,7 +742,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
strideNum, blockNum, sp, power, isExp);
}
}
else if (stride != 1 && stride * blockNum > 4096) {
else if (stride != 1 && stride * blockNum > 4096 && input->dataType == DEFAULT_DTYPE){
//GDevs->GetGridAndBlockSize2D(devID, stride * blockNum, strideNum,MAX_INT, cudaGridSize, cudaBlockSize);
//unsigned int* goutput = (unsigned int *)input->data;
//convert2uintV2 << <dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >> > ((float*)input->data, goutput, stride, strideNum, blockNum, strideNum*blockNum*stride);
......@@ -766,15 +766,15 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
DTYPE * oData = NULL;
if (iter == 0) {
iData = (DTYPE*)input->data;
oData = buf1;
oData = (DTYPE*)buf1;
}
else if (iter % 2 == 1) {
iData = buf1;
oData = buf2;
iData = (DTYPE*)buf1;
oData = (DTYPE*)buf2;
}
else {
iData = buf2;
oData = buf1;
iData = (DTYPE*)buf2;
oData = (DTYPE*)buf1;
}
/* unroll the reduction procedure. The code is messy but it is faster. */
if (strideNum <= 32) {
......@@ -830,8 +830,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
__half * buf1ft16 = (__half *)buf1;
__half * buf2ft16 = (__half *)buf2;
__half * spft16 = (__half *)sp;
unsigned short power2 = FloatToFloat16(power);
__half * powerft16p = (__half*)&power2;
__half powerft16p = __float2half(power);
__half * iData = NULL;
__half * oData = NULL;
if (iter == 0) {
......@@ -854,7 +853,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
if (cudaGridSize[0] == 1)
oData = (__half*)output->data;
KernelReduceSum <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, *powerft16p, isExp);
blockSize, blockNum, spft16, powerft16p, isExp);
}
else if (strideNum < 128) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 64), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
......@@ -863,7 +862,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (__half*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 64), "Incorrect thread number when calling the cuda kernel!");
KernelReduceSumFast<64> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, *powerft16p, isExp);
blockSize, blockNum, spft16, powerft16p, isExp);
}
else if (strideNum < 256) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 128), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
......@@ -872,7 +871,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (__half*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 128), "Incorrect thread number when calling the cuda kernel!");
KernelReduceSumFast<128> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, *powerft16p, isExp);
blockSize, blockNum, spft16, powerft16p, isExp);
}
else if (strideNum < 512) {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 256), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
......@@ -881,7 +880,7 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (__half*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 256), "Incorrect thread number when calling the cuda kernel!");
KernelReduceSumFast<256> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, *powerft16p, isExp);
blockSize, blockNum, spft16, powerft16p, isExp);
}
else {
GDevs.GetCudaThread2D(devID, MAX(strideNum / 2 + 1, 512), stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
......@@ -890,9 +889,12 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
oData = (__half*)output->data;
CheckNTErrors((cudaBlockSize[0] >= 512), "Incorrect thread number when calling the cuda kernel!");
KernelReduceSumFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, *powerft16p, isExp);
blockSize, blockNum, spft16, powerft16p, isExp);
}
}
else {
ShowNTErrors("Unsupported dataType!");
}
strideNum = cudaGridSize[0];
blockSize = cudaGridSize[0];
......
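
One detail worth calling out in the ReduceSum changes: the float16 path now converts the scalar power on the host with __float2half and passes the resulting __half to the kernels by value, replacing the earlier FloatToFloat16 plus pointer-reinterpretation step. A minimal sketch of that conversion (the kernel here is a placeholder; only the parameter passing is the point):

#include <cuda_fp16.h>

/* placeholder kernel that takes a scalar parameter of the element type by value */
template <class T>
__global__ void KernelWithScalarSketch(T * data, int size, T power)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        data[i] = power;   /* trivial body: just demonstrates the by-value half scalar */
}

void LaunchWithHalfScalarSketch(__half * data, int size, float power)
{
    /* __float2half (cuda_fp16.h) converts the float on the host; the __half value
       is then passed straight to the kernel, no unsigned short or pointer cast needed */
    __half powerHalf = __float2half(power);
    KernelWithScalarSketch<__half><<<(size + 255) / 256, 256>>>(data, size, powerHalf);
}
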
......@@ -268,6 +268,14 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueezeByCol<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
if (cudaBlocks[1] == 1)
KernelUnsqueezeByColBigRow<__half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
else
KernelUnsqueezeByCol<__half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else {
ShowNTErrors("TODO!");
}
......@@ -285,6 +293,10 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueeze<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
KernelUnsqueeze<half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
......@@ -300,6 +312,10 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueezeFlat2D<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
KernelUnsqueezeFlat2D<half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
......@@ -315,6 +331,10 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueezeFlatBigram<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
KernelUnsqueezeFlatBigram<half> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
......@@ -330,6 +350,10 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueezeFlat<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
KernelUnsqueezeFlat<half> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
......
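
The Unsqueeze hunks extend every kernel variant (by-column, by-row, flat, bigram) with an X_FLOAT16 branch that simply instantiates the existing template with half/__half. As a hedged sketch of the flat case, assuming its semantics are "the output holds dSize copies of the input block" (an assumption, not the library's code):

#include <cuda_fp16.h>

/* hypothetical flat unsqueeze: the output holds dSize back-to-back copies of the input block */
template <class T>
__global__ void KernelUnsqueezeFlatSketch(T * a, int blockSize, T * b, int dSize)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;   /* index into the input block */
    if (i < blockSize) {
        T v = a[i];
        for (int d = 0; d < dSize; d++)              /* write the same value into every copy */
            b[d * blockSize + i] = v;
    }
}

/* the fp16 branch in the diff then reduces to KernelUnsqueezeFlatSketch<half><<<...>>>(...) */
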
......@@ -863,6 +863,11 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k,
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
(DTYPE*)b->data, (int*)index->data, isSorted);
}
else if (a->dataType == X_FLOAT16) {
KernelTopK3<__half> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((__half*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
(__half*)b->data, (int*)index->data, isSorted);
}
else {
ShowNTErrors("TODO!");
}
......
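
In the TopK hunk the half instantiation receives DTYPE_MIN, a float constant, as its minimum-value sentinel; it is converted to __half at the call, and a float far below the half range typically converts to -inf, which still compares below any finite half. If a finite, half-representable sentinel were preferred, it could be built explicitly. A suggestion only, not part of this commit:

#include <cuda_fp16.h>

/* -65504.0f is the most negative finite value representable in IEEE binary16 */
inline __half HalfLowestSketch()
{
    return __float2half(-65504.0f);
}
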
......@@ -128,16 +128,14 @@ void CudaGPUToCPUFlush(XTensor * tensor, int devID, XMem * CPUMem)
/* copy the data from GPU memory to CPU memory ((dataHost)) and do not delete the data */
void CudaGPUToCPUFlush(XTensor * tensor)
{
CheckNTErrors((sizeof(DTYPE) == tensor->unitSize), "Unsupported data type.");
if (tensor->dataHost != NULL)
delete[](char*)tensor->dataHost;
if (tensor->isSparse) {
int num = int(tensor->unitNum * tensor->denseRatio + 1);
cudaMemcpy(&num, (DTYPE*)tensor->data, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num, tensor->data, sizeof(int), cudaMemcpyDeviceToHost);
int tupleSize = sizeof(int) + sizeof(DTYPE);
int tupleSize = sizeof(int) + tensor->unitSize;
int size = sizeof(int) + tupleSize*(num);
CheckNTErrors((size >= 0), "Illegal data size in the sparse matrix!");
......
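
The sparse branch of CudaGPUToCPUFlush now sizes each (index, value) tuple from tensor->unitSize instead of sizeof(DTYPE), which is what lets fp16 tensors flush correctly. A small worked sketch of that size computation:

#include <cstddef>

/* host buffer for a sparse dump: a leading int (the tuple count) followed by
   num tuples of (int index, value of unitSize bytes) */
inline size_t SparseHostBufferSizeSketch(int num, int unitSize)
{
    size_t tupleSize = sizeof(int) + (size_t)unitSize;   /* e.g. 4 + 2 = 6 bytes for fp16 */
    return sizeof(int) + tupleSize * (size_t)num;        /* e.g. 4 + 6 * 1000 = 6004 bytes */
}
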
......@@ -36,17 +36,18 @@ y = 1 if x > 1
>> y - output data array
>> size - size of input/output
*/
template <class T>
__global__
void KernelHardtanhCompute(DTYPE * x, DTYPE * y, int size)
void KernelHardtanhCompute(T * x, T * y, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE p = x[i];
if(p > (DTYPE)1.0)
p = (DTYPE)1.0;
else if(p < (DTYPE)-1.0)
p = (DTYPE)-1.0;
T p = x[i];
if (p >(T)1.0)
p = (T)1.0;
else if (p < (T)-1.0)
p = (T)-1.0;
y[i] = p;
}
}
......@@ -71,7 +72,16 @@ void _CudaHardTanH(const XTensor * x, XTensor * y)
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
KernelHardtanhCompute<<<dim3(gridSize[0]), dim3(blockSize[0])>>>((DTYPE*)x->data, (DTYPE*)y->data, x->unitNum);
}
else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
KernelHardtanhCompute<<<dim3(gridSize[0]), dim3(blockSize[0])>>>((__half *)x->data, (__half *)y->data, x->unitNum);
}
else {
//TODO!
ShowNTErrors("TODO!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......@@ -84,18 +94,18 @@ dy/dx = 1 if -1 <= x <= 1
>> dedy - dE/dy
>> dedx - dE/dx
>> y - y of the function
>> x - x of the function
>> size - size of y/x
*/
template <class T>
__global__
void KernelHardtanhBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * x, int size)
void KernelHardtanhBackward(T * dedy, T * dedx, T * x, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE s = x[i];
if(s > (DTYPE)1.0 || s < (DTYPE)-1.0)
T s = x[i];
if(s > (T)1.0 || s < (T)-1.0)
dedx[i] = 0;
else
dedx[i] = dedy[i];
......@@ -129,12 +139,25 @@ void _CudaHardTanHBackward(XTensor * y, XTensor * x,
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* dE/dx = dE/dy * dy/dx */
KernelHardtanhBackward<<<dim3(gridSize[0]),dim3(blockSize[0])>>>
((DTYPE*)dedy->data,
(DTYPE*)dedx->data,
(DTYPE*)x->data,
x->unitNum);
}
else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
/* dE/dx = dE/dy * dy/dx */
KernelHardtanhBackward<<<dim3(gridSize[0]), dim3(blockSize[0])>>>
((half*)dedy->data,
(half*)dedx->data,
(half*)x->data,
x->unitNum);
}
else {
ShowNTErrors("Unsupported dataType!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......
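
For a quick end-to-end check of the templated HardTanH path on half data, a self-contained harness along these lines can be used; the kernel is re-declared here so the snippet compiles on its own, the names are illustrative, and the half comparison operators assume an sm_53-or-newer target:

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template <class T>
__global__ void HardtanhSketch(T * x, T * y, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        T p = x[i];
        if (p > (T)1.0)
            p = (T)1.0;
        else if (p < (T)-1.0)
            p = (T)-1.0;
        y[i] = p;
    }
}

int main()
{
    const int n = 4;
    float hostIn[n] = { -2.5f, -0.5f, 0.5f, 2.5f };
    __half hostHalf[n];
    for (int i = 0; i < n; i++)
        hostHalf[i] = __float2half(hostIn[i]);

    __half *dX = NULL, *dY = NULL;
    cudaMalloc(&dX, n * sizeof(__half));
    cudaMalloc(&dY, n * sizeof(__half));
    cudaMemcpy(dX, hostHalf, n * sizeof(__half), cudaMemcpyHostToDevice);

    HardtanhSketch<__half><<<1, 64>>>(dX, dY, n);

    cudaMemcpy(hostHalf, dY, n * sizeof(__half), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; i++)   /* expected: -1, -0.5, 0.5, 1 */
        printf("%f\n", __half2float(hostHalf[i]));

    cudaFree(dX);
    cudaFree(dY);
    return 0;
}
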
......@@ -17,6 +17,7 @@
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-26
* $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-01 float16 added
*/
#include "LogSoftmax.h"
......@@ -27,6 +28,7 @@
#include "../core/reduce/ReduceMax.cuh"
#include "../core/shape/IsSameShaped.h"
#include "../XDevice.h"
#include <device_launch_parameters.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -58,11 +60,12 @@ y_{i,j} = log(e^x_{i,j} / \sum_{i} e^{x_{i,j})
>> rowNum - row number of the matrix
>> colNum - column number of the matrix
*/
template <class T ,TENSOR_DATA_TYPE dataType>
__global__
void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y, int rowNum, int colNum)
void KernelLogSoftmaxComputeByRow(T * x, T * max, T * sum, T * y, int rowNum, int colNum)
{
__shared__ DTYPE inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.y * blockIdx.y + threadIdx.y;
int j = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -79,7 +82,8 @@ void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
/* y_{i,j} = log(e^(s_{i,j} - max_{j}) / \sum_{k} e^{s_{k,j} - max_{j}}) */
if (i < rowNum && j < colNum) {
int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.x]) / inputSum[threadIdx.x]);
if (dataType == DEFAULT_DTYPE) {
DTYPE r = log((DTYPE)exp((DTYPE)(x[key] - inputMax[threadIdx.x])) / (DTYPE)inputSum[threadIdx.x]);
if (isnan(r))
r = LOGPROB_MIN;
......@@ -88,6 +92,11 @@ void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
y[key] = MAX(r, LOGPROB_MIN);
}
else if (dataType == X_FLOAT16) {
half r = hlog((half)hexp(x[key] - inputMax[threadIdx.x]) / (half)inputSum[threadIdx.x]);
y[key] = r;
}
}
}
/*
......@@ -105,11 +114,12 @@ y_{i,j} = log(e^x_{i,j} / \sum_{j} e^{x_{i,j})
>> rowNum - row number of the matrix
>> colNum - column number of the matrix
*/
template <class T ,TENSOR_DATA_TYPE dataType>
__global__
void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y, int rowNum, int colNum)
void KernelLogSoftmaxComputeByCol(T * x, T * max, T * sum, T * y, int rowNum, int colNum)
{
__shared__ DTYPE inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.y * blockIdx.y + threadIdx.y;
int j = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -126,7 +136,8 @@ void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
/* y_{i,j} = log(e^(s_{i,j} - max_{i}) / \sum_{k} e^{s_{i,k} - max_{i}}) */
if (i < rowNum && j < colNum) {
int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.y]) / inputSum[threadIdx.y]);
if (dataType == DEFAULT_DTYPE) {
DTYPE r = log((DTYPE)exp((DTYPE)(x[key] - inputMax[threadIdx.y])) / (DTYPE)inputSum[threadIdx.y]);
/*if (r < LOGPROB_MIN)
{
......@@ -140,6 +151,11 @@ void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
y[key] = MAX(r, LOGPROB_MIN);
}
else if (dataType == X_FLOAT16) {
half r = hlog((half)hexp(x[key] - inputMax[threadIdx.y]) / (half)inputSum[threadIdx.y]);
y[key] = r;
}
}
}
/*
......@@ -174,17 +190,37 @@ void _CudaLogSoftmaxSumMax(XTensor * x, XTensor * y, int leadDim, XTensor * sum,
GDevs.GetCudaThread2D(x->devID, n, m, MAX_INT, gridSize, blockSize);
/* y_{i,j} = log(e^(s_{i,j} - max_{j}) / \sum_{k} e^{s_{k,j} - max_{j}}) */
KernelLogSoftmaxComputeByRow << <dim3(gridSize[1], gridSize[0]), dim3(blockSize[1], blockSize[0]) >> >
KernelLogSoftmaxComputeByRow<DTYPE, DEFAULT_DTYPE> <<<dim3(gridSize[1], gridSize[0]), dim3(blockSize[1], blockSize[0])>>>
((DTYPE*)x->data, maxData, sumData, (DTYPE*)y->data, n, m);
}
else {
GDevs.GetCudaThread2D(x->devID, m, n, MAX_INT, gridSize, blockSize);
/* y_{i,j} = log(e^(s_{i,j} - max_{i}) / \sum_{k} e^{s_{i,k} - max_{i}}) */
KernelLogSoftmaxComputeByCol << <dim3(gridSize[0], gridSize[1]), dim3(blockSize[0], blockSize[1]) >> >
KernelLogSoftmaxComputeByCol<DTYPE, DEFAULT_DTYPE> <<<dim3(gridSize[0], gridSize[1]), dim3(blockSize[0], blockSize[1])>>>
((DTYPE*)x->data, maxData, sumData, (DTYPE*)y->data, n, m);
}
}
else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
int gridSize[3], blockSize[3];
int n = x->dimSize[0];
int m = x->dimSize[1];
/* allocate the buffer */
__half * maxData = (half*)max->data;
__half * sumData = (half*)sum->data;
if (leadDim == 0) {
GDevs.GetCudaThread2D(x->devID, n, m, MAX_INT, gridSize, blockSize);
/* y_{i,j} = log(e^(s_{i,j} - max_{j}) / \sum_{k} e^{s_{k,j} - max_{j}}) */
KernelLogSoftmaxComputeByRow<half, X_FLOAT16> <<<dim3(gridSize[1], gridSize[0]), dim3(blockSize[1], blockSize[0])>>>
((half*)x->data, maxData, sumData, (half *)y->data, n, m);
}
else {
GDevs.GetCudaThread2D(x->devID, m, n, MAX_INT, gridSize, blockSize);
/* y_{i,j} = log(e^(s_{i,j} - max_{i}) / \sum_{k} e^{s_{i,k} - max_{i}}) */
KernelLogSoftmaxComputeByCol<half, X_FLOAT16> <<<dim3(gridSize[0], gridSize[1]), dim3(blockSize[0], blockSize[1])>>>
((half*)x->data, maxData, sumData, (half*)y->data, n, m);
}
}
else {
ShowNTErrors("TODO!");
}
......
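
The float16 branches in LogSoftmax rely on the half math intrinsics hexp and hlog from cuda_fp16.h (device-side, sm_53 or newer). A minimal standalone kernel showing the same r = log(exp(x - max) / sum) computation in half precision, with illustrative names:

#include <cuda_fp16.h>

/* compute y[i] = log(exp(x[i] - maxVal) / sumVal) entirely in half precision;
   hexp/hlog/__hsub/__hdiv are device intrinsics from cuda_fp16.h (sm_53 or newer) */
__global__ void KernelLogSoftmaxHalfSketch(const __half * x, __half maxVal, __half sumVal,
                                           __half * y, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        __half shifted = __hsub(x[i], maxVal);
        y[i] = hlog(__hdiv(hexp(shifted), sumVal));
    }
}
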
......@@ -34,15 +34,16 @@ rectify : y = x if x >= 0
>> output - output tensor
>> size - size of input/output
*/
template<class T>
__global__
void KernelRectify(DTYPE * x, DTYPE * y, int size)
void KernelRectify(T * x, T * y, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE p = x[i];
if(p < 0)
p = 0;
T p = x[i];
if(p < (T)0.0)
p = (T)0.0;
y[i] = p;
}
}
......@@ -61,8 +62,18 @@ void _CudaRectify(const XTensor * x, XTensor * y)
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if (x->dataType == DEFAULT_DTYPE) {
KernelRectify<<<dim3(gridSize[0]), dim3(blockSize[0])>>>
((DTYPE*)x->data, (DTYPE*)y->data, x->unitNum);
}
else if (x->dataType == X_FLOAT16) {
KernelRectify<<<dim3(gridSize[0]), dim3(blockSize[0]) >> >
((__half*)x->data, (__half*)y->data, x->unitNum);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......@@ -78,17 +89,18 @@ dy/dx = 1 if x >= 0
>> x - input of the function
>> size - size of output/input
*/
template<class T>
__global__
void KernelRectifyBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * x, int size)
void KernelRectifyBackward(T * dedy, T * dedx, T * x, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
DTYPE s = x[i];
if(s >= 0)
T s = x[i];
if(s >= (T)0.0)
dedx[i] = dedy[i];
else
dedx[i] = 0;
dedx[i] = (T)0.0;
}
}
......@@ -119,11 +131,24 @@ void _CudaRectifyBackward(XTensor * y, XTensor * x,
ProtectCudaDev(x->devID, devIDBackup);
/* dE/ds = dE/dy * dy/ds */
if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
KernelRectifyBackward<<<dim3(gridSize[0]),dim3(blockSize[0])>>>
((DTYPE*)dedy->data,
(DTYPE*)dedx->data,
(DTYPE*)x->data,
x->unitNum);
}
else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
KernelRectifyBackward<<<dim3(gridSize[0]), dim3(blockSize[0]) >> >
((__half*)dedy->data,
(__half*)dedx->data,
(__half*)x->data,
x->unitNum);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(x->devID, devIDBackup);
}
......
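
One portability note on the templated kernels above: with T = __half, comparisons such as p < (T)0.0 rely on the __half comparison operators, which on older toolkits are only available in device code built for sm_53 and newer. A hedged sketch of a fallback that compares in float instead, under the assumption that older architectures must also be supported (not part of this commit):

#include <cuda_fp16.h>

/* rectify for half data without using __half comparison operators:
   the comparison is done after converting to float, the stored value stays half */
__global__ void KernelRectifyHalfPortableSketch(const __half * x, __half * y, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        float p = __half2float(x[i]);   /* the conversion intrinsics work on all architectures */
        y[i] = (p < 0.0f) ? __float2half(0.0f) : x[i];
    }
}
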