Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
T
Tensor.LowPrecision
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
linye
Tensor.LowPrecision
Commits
f151f061
Commit
f151f061
authored
5 years ago
by
linye
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
no message
parent
70308bd9
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
22 行增加
和
0 行删除
+22
-0
source/tensor/core/arithmetic/XTensorBLAS.cu
+22
-0
没有找到文件。
source/tensor/core/arithmetic/XTensorBLAS.cu
查看文件 @
f151f061
...
@@ -291,6 +291,28 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
...
@@ -291,6 +291,28 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (__int8*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (__int8*)&beta2, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
}
else if (dataTypeA == X_INT8 && dataTypeB == X_INT8 && dataTypeC == X_INT) {
    /* int8 x int8 -> int32 batched GEMM path.
       cuBLAS reads alpha/beta according to computeType (CUDA_R_32I here),
       so they must be passed as 32-bit integers, not as __int8. */
    int alpha2 = (int)alpha;
    int beta2 = (int)beta;
    /*
    cuBLAS int8 GEMM requires the leading dimensions (lda, ldb) to be
    multiples of 4. See:
    https://devtalk.nvidia.com/default/topic/999101/about-cublasgemm-int8-support/
    */
    if (mb % 4 != 0 || ma % 4 != 0) {
        ShowNTErrors("mb, ma( lda, ldb ) should be multiples of 4!");
        return;
    }
    cublasSetMathMode(*handle, CUBLAS_TENSOR_OP_MATH);
    /* The row-major C = op(A) * op(B) is computed as the column-major
       C^T = op(B)^T * op(A)^T, hence b is passed before a.
       NOTE: the output type is CUDA_R_32I (real 32-bit int) — the original
       CUDA_C_32I (COMPLEX 32-bit int) was incorrect for a real int32 result. */
    if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
        cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (int*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (int*)&beta2, c, CUDA_R_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
        cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (int*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (int*)&beta2, c, CUDA_R_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
        cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (int*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (int*)&beta2, c, CUDA_R_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    else if (transposedA == X_TRANS && transposedB == X_TRANS)
        cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (int*)&alpha2, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, (int*)&beta2, c, CUDA_R_32I, mc, strideC, count, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    /* restore the default math mode so later calls are unaffected */
    cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
}
else {
else {
ShowNTErrors("Unsupported data type!");
ShowNTErrors("Unsupported data type!");
}
}
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论