Commit 5c0d8bfd by xiaotong

protection of cuda code

Merge branch 'xiaotong-working' of 47.105.50.196:NiuTrans/NiuTrans.Tensor into xiaotong-working

# Conflicts:
#	.gitignore
parent 2bb8754f
......@@ -186,6 +186,7 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
#ifdef USE_CUDA
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Input tensors should have the same data type!");
......@@ -226,6 +227,7 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
a->dimSizeRDI[1], a->dimSizeRDI[0],
b->dimSizeRDI[1], b->dimSizeRDI[0],
c->dimSizeRDI[1], c->dimSizeRDI[0], alpha, beta);
#endif
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论