Commit d3d9ae88 by linye

1. XTensorBLAS bug fixed 2. gather backward float16 supported

parent 54adb703
......@@ -161,11 +161,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
......@@ -173,11 +173,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else if (dataTypeA == X_INT8 && dataTypeB == X_INT8 && dataTypeC == X_FLOAT) {
......@@ -193,11 +193,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else {
......@@ -246,11 +246,11 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
......@@ -278,11 +278,11 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else {
......
......@@ -234,7 +234,6 @@ void _SpreadForGather(XTensor * source, XTensor * collection, XTensor * index)
int dim = 0;
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(collection->GetDim(-1) == source->GetDim(-1), "Illegal dimension!");
CheckNTErrors(collection->unitNum/collection->GetDim(-1) == index->unitNum,
"Illegal dimension!");
......
......@@ -330,12 +330,13 @@ Care of the operator "+=" instead of "=".
>> indexSize - the number of index
>> stride - stride of a data block
*/
template <class T, TENSOR_DATA_TYPE datatype>
__global__
void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int * srcIndex,
void KernelSpreadForGather(T * sData, T * cData, int * srcIndex,
int indexSize, int stride)
{
__shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -353,13 +354,18 @@ void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int * srcIndex,
__syncthreads();
DTYPE * s = sp[threadIdx.x];
DTYPE * c = cp[threadIdx.x];
T * s = sp[threadIdx.x];
T * c = cp[threadIdx.x];
//DTYPE * s = sData + srcIndex[i] * stride;
//DTYPE * c = cData + i * stride;
atomicAdd(s + offset, c[offset]);
if (datatype == X_FLOAT) {
atomicAdd(((DTYPE*)s + offset), *((DTYPE*)c + offset));
}
else if (datatype == X_FLOAT16) {
atomicAdd(((__half2*)s + offset), *((__half2*)c + offset));
}
}
/*
......@@ -372,6 +378,10 @@ And this is a special spread function for backward computation of gather functio
*/
void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex)
{
CheckNTErrors((source->dataType == X_FLOAT) ||
(source->dataType == X_FLOAT16),
"Unmatched tensors in gather!");
int devID = source->devID;
XMem * mem = source->mem;
......@@ -384,8 +394,6 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcI
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
int * sIndex = NULL;
GDevs.GetCudaThread2D(devID, indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
......@@ -402,7 +410,19 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcI
else
sIndex = (int *)srcIndex->data;
KernelSpreadForGather<<<blocks, threads >>>(sData, cData, sIndex, indexSize, stride);
if (source->dataType == DEFAULT_DTYPE && collection->dataType == DEFAULT_DTYPE)
{
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
KernelSpreadForGather<DTYPE, X_FLOAT> << <blocks, threads >> >(sData, cData, sIndex, indexSize, stride);
}
else if (source->dataType == X_FLOAT16 && collection->dataType == X_FLOAT16)
{
__half2 * sData = (__half2*)source->data;
__half2 * cData = (__half2*)collection->data;
KernelSpreadForGather<__half2, X_FLOAT16> << <blocks, threads >> >(sData, cData, sIndex, indexSize, stride);
}
if (srcIndex->devID < 0) {
if(mem != NULL)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论