Commit 9fd6a28f by xiaotong

bug fixes

parent 02b6c379
@@ -71,7 +71,6 @@ void _CudaBLASMatrixMUL(cublasHandle_t * handle,
cublasSgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, (const float*)a, ma, &beta2, (float*)c, mc);
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
- #if CUDACC_VER_MAJOR >= 10
__half alpha2 = __float2half(alpha);
__half beta2 = __float2half(beta);
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
@@ -84,9 +83,6 @@ void _CudaBLASMatrixMUL(cublasHandle_t * handle,
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha2, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta2, c, CUDA_R_16F, mc, CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
float alpha2 = (float)alpha;
@@ -149,7 +145,6 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
cublasSgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float**)b, mb, (const float**)a, ma, &beta2, (float**)c, mc, count);
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
- #if CUDACC_VER_MAJOR >= 10
__half alpha2 = __float2half(alpha);
__half beta2 = __float2half(beta);
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
@@ -162,12 +157,8 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha2, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta2, c, CUDA_R_16F, mc, count, CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
- #if CUDACC_VER_MAJOR >= 10
float alpha2 = (float)alpha;
float beta2 = (float)beta;
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
@@ -180,9 +171,6 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha2, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta2, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else {
ShowNTErrors("Unsupported data type!");
@@ -226,7 +214,6 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
cublasSgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, strideB, (const float*)a, ma, strideA, &beta2, (float*)c, mc, strideC, count);
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
- #if CUDACC_VER_MAJOR >= 10
__half alpha2 = __float2half(alpha);
__half beta2 = __float2half(beta);
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
@@ -239,12 +226,8 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha2, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta2, c, CUDA_R_16F, mc, strideC, count, CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
- #if CUDACC_VER_MAJOR >= 10
float alpha2 = (float)alpha;
float beta2 = (float)beta;
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
@@ -257,9 +240,6 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha2, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta2, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else {
ShowNTErrors("Unsupported data type!");
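Note on the change above: nvcc's predefined version macro is __CUDACC_VER_MAJOR__, so unless CUDACC_VER_MAJOR was defined elsewhere in the build, the removed guards always took the #else branch and the FP16 paths of _CudaBLASMatrixMUL, _CudaBLASMatrixMULBatched and _CudaBLASMatrixMULBatchedStrided failed at runtime with "Require Cuda Version >= 10.0!". What follows is a minimal, self-contained sketch (not code from this repository) of the cublasGemmEx FP16 call pattern those paths use; the buffer names dA/dB/dC, the 4x4x4 sizes, and the CUDART_VERSION switch are illustrative assumptions, the switch being needed because cuBLAS 11 replaced the cudaDataType_t compute-type argument with cublasComputeType_t.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cstdio>

int main()
{
    const int m = 4, n = 4, k = 4;
    cublasHandle_t handle;
    cublasCreate(&handle);

    // FP16 operands on the device; contents are simply zeroed for brevity.
    __half * dA, * dB, * dC;
    cudaMalloc((void**)&dA, m * k * sizeof(__half));
    cudaMalloc((void**)&dB, k * n * sizeof(__half));
    cudaMalloc((void**)&dC, m * n * sizeof(__half));
    cudaMemset(dA, 0, m * k * sizeof(__half));
    cudaMemset(dB, 0, k * n * sizeof(__half));
    cudaMemset(dC, 0, m * n * sizeof(__half));

    // Scalars must match the compute type, hence __half, as in the diff above.
    __half alpha = __float2half(1.0f);
    __half beta  = __float2half(0.0f);

#if CUDART_VERSION >= 11000
    // cuBLAS 11+ takes a cublasComputeType_t for the compute type.
    cublasComputeType_t computeType = CUBLAS_COMPUTE_16F;
#else
    // cuBLAS 10 takes a cudaDataType_t here, matching the calls in the diff.
    cudaDataType_t computeType = CUDA_R_16F;
#endif

    // Enable tensor-core math, run C = alpha * A * B + beta * C, then restore
    // the default math mode, mirroring the pattern used in the wrappers above.
    cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
    cublasStatus_t status = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                         &alpha, dA, CUDA_R_16F, m,
                                                 dB, CUDA_R_16F, k,
                                         &beta,  dC, CUDA_R_16F, m,
                                         computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
    printf("cublasGemmEx status: %d\n", (int)status);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    cublasDestroy(handle);
    return 0;
}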
@@ -827,7 +827,6 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
}
}
else if (input->dataType == X_FLOAT16) {
- #if CUDACC_VER_MAJOR >= 10
__half * buf1ft16 = (__half *)buf1;
__half * buf2ft16 = (__half *)buf2;
__half * spft16 = (__half *)sp;
@@ -892,9 +891,6 @@ void _CudaReduceSum(const XTensor * input, XTensor * output, int dim, const XTen
KernelReduceSumFast<512> <<<blocks, threads>>> (iData, oData, stride, strideNum, blocks.y,
blockSize, blockNum, spft16, powerft16p, isExp);
}
- #else
- ShowNTErrors("Require Cuda Version >= 10.0!");
- #endif
}
else {
ShowNTErrors("Unsupported dataType!");
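The last two hunks remove the same broken guard around the FP16 branch of _CudaReduceSum. The repository's KernelReduceSumFast kernel is not reproduced here; instead, the sketch below is a generic, self-contained half-precision sum reduction (grid-stride accumulation in float, one __half partial per block) meant only to illustrate the kind of kernel launched on that path. All names and sizes are illustrative assumptions.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Toy kernel: each block reduces a strided slice of a __half array and writes
// one __half partial sum. Accumulation happens in float inside the block.
__global__ void ReduceSumHalfToy(const __half * in, __half * out, int n)
{
    __shared__ float buf[256];
    int tid = threadIdx.x;
    float sum = 0.0f;
    for (int i = blockIdx.x * blockDim.x + tid; i < n; i += blockDim.x * gridDim.x)
        sum += __half2float(in[i]);
    buf[tid] = sum;
    __syncthreads();
    // Standard shared-memory tree reduction (blockDim.x is a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s)
            buf[tid] += buf[tid + s];
        __syncthreads();
    }
    if (tid == 0)
        out[blockIdx.x] = __float2half(buf[0]);
}

int main()
{
    const int n = 1 << 16;
    const int threads = 256, blocks = 64;
    __half * dIn, * dOut;
    cudaMalloc((void**)&dIn, n * sizeof(__half));
    cudaMalloc((void**)&dOut, blocks * sizeof(__half));

    // Fill the input with ones so the expected total is exactly n.
    __half * hIn = new __half[n];
    for (int i = 0; i < n; i++)
        hIn[i] = __float2half(1.0f);
    cudaMemcpy(dIn, hIn, n * sizeof(__half), cudaMemcpyHostToDevice);

    ReduceSumHalfToy<<<blocks, threads>>>(dIn, dOut, n);

    // Finish the reduction on the host by summing the per-block partials.
    __half hOut[blocks];
    cudaMemcpy(hOut, dOut, blocks * sizeof(__half), cudaMemcpyDeviceToHost);
    float total = 0.0f;
    for (int i = 0; i < blocks; i++)
        total += __half2float(hOut[i]);
    printf("sum = %f (expected %d)\n", total, n);

    delete[] hIn;
    cudaFree(dIn);
    cudaFree(dOut);
    return 0;
}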