Commit 8f665e61 by xiaotong

improve the code of _CudaBLASMatrixMULList

parent ac620226
...@@ -253,14 +253,14 @@ void _CudaBLASMatrixMULList(cublasHandle_t * handle, ...@@ -253,14 +253,14 @@ void _CudaBLASMatrixMULList(cublasHandle_t * handle,
if (isUniform) { if (isUniform) {
XMem * mem = a0->mem; XMem * mem = a0->mem;
if (isStrided && a->count > 1) { if (isStrided) {
_CudaBLASMatrixMULBatchedStrided(handle, _CudaBLASMatrixMULBatchedStrided(handle,
a0->data, transposedA, a0->dataType, strideA / a0->unitSize, a0->data, transposedA, a0->dataType, strideA / a0->unitSize,
b0->data, transposedB, b0->dataType, strideB / b0->unitSize, b0->data, transposedB, b0->dataType, strideB / b0->unitSize,
c0->data, c0->dataType, strideC / c0->unitSize, a->count, c0->data, c0->dataType, strideC / c0->unitSize, a->count,
a0->dimSize[0], a0->dimSize[1], a0->dimSize[0], a0->dimSize[1],
b0->dimSize[0], b0->dimSize[1], b0->dimSize[0], b0->dimSize[1],
c0->dimSize[0], c0->dimSize[1], alpha, beta); c0->dimSize[0], c0->dimSize[1], alpha, beta);
} }
else { else {
DTYPE ** ap = new DTYPE*[a->count]; DTYPE ** ap = new DTYPE*[a->count];
...@@ -324,12 +324,12 @@ void _CudaBLASMatrixMULList(cublasHandle_t * handle, ...@@ -324,12 +324,12 @@ void _CudaBLASMatrixMULList(cublasHandle_t * handle,
XTensor * ci = (XTensor*)c->GetItem(i); XTensor * ci = (XTensor*)c->GetItem(i);
_CudaBLASMatrixMUL(handle, _CudaBLASMatrixMUL(handle,
ai->data, transposedA, ai->dataType, ai->data, transposedA, ai->dataType,
bi->data, transposedB, bi->dataType, bi->data, transposedB, bi->dataType,
ci->data, ci->dataType, ci->data, ci->dataType,
ai->dimSize[0], ai->dimSize[1], ai->dimSize[0], ai->dimSize[1],
bi->dimSize[0], bi->dimSize[1], bi->dimSize[0], bi->dimSize[1],
ci->dimSize[0], ci->dimSize[1], alpha, beta); ci->dimSize[0], ci->dimSize[1], alpha, beta);
} }
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论