Commit f151f061 by linye

no message

parent 70308bd9
......@@ -291,6 +291,28 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else if (dataTypeA == X_INT8 && dataTypeB == X_INT8 && dataTypeC == X_INT) {
int alpha2 = (int)alpha;
int beta2 = (int)beta;
/*
CUDA requires that the dimension of two tensor( lda, ldb ) should be multiples of 4.
details in https://devtalk.nvidia.com/default/topic/999101/about-cublasgemm-int8-support/
*/
if (mb % 4 != 0 || ma % 4 != 0) {
ShowNTErrors("mb, ma( lda, ldb ) should be multiples of 4!");
return;
}
cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_C_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_C_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_C_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
else if (transposedA == X_TRANS && transposedB == X_TRANS)
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_C_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else {
ShowNTErrors("Unsupported data type!");
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论