Commit ba4c8180 by liyinqiao

Bugs fixed and code cleaned up. (IMPORTANT)

1. Fix bugs in the MatrixMul functions that could produce NaN in some cases (the output tensor was not zeroed when beta == 0, so uninitialized memory could be accumulated into the result). (IMPORTANT)
2. Fix minor errors.
parent a85e2079
...@@ -159,6 +159,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -159,6 +159,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
"The code must be run on the same GPU!"); "The code must be run on the same GPU!");
int devIDBackup; int devIDBackup;
if (beta == 0)
c->SetZeroAll();
ProtectCudaDev(a->devID, devIDBackup); ProtectCudaDev(a->devID, devIDBackup);
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID); cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
......
...@@ -156,6 +156,9 @@ void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -156,6 +156,9 @@ void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
if (stream != NULL) if (stream != NULL)
cublasSetStream(*handle, stream->stream); cublasSetStream(*handle, stream->stream);
if (beta == 0)
c->SetZeroAll();
if (a->dataType == X_FLOAT && b->dataType == X_FLOAT && c->dataType == X_FLOAT) { if (a->dataType == X_FLOAT && b->dataType == X_FLOAT && c->dataType == X_FLOAT) {
_CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType, _CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType,
b->data, transposedB, a->dataType, c->data, c->dataType, b->data, transposedB, a->dataType, c->data, c->dataType,
......
...@@ -54,6 +54,9 @@ void _MatrixMul2DParallel(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -54,6 +54,9 @@ void _MatrixMul2DParallel(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
int aColNum = am; int aColNum = am;
int bColNum = bm; int bColNum = bm;
if (beta == 0)
c->SetZeroAll();
/* a * b */ /* a * b */
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS) { if (transposedA == X_NOTRANS && transposedB == X_NOTRANS) {
RunParallel2D(parallelRunner, (void*)_MatrixMul2DMultiTheading, an * am * bm, RunParallel2D(parallelRunner, (void*)_MatrixMul2DMultiTheading, an * am * bm,
......
...@@ -118,6 +118,9 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -118,6 +118,9 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
blockNum *= a->dimSize[i]; blockNum *= a->dimSize[i];
} }
if (beta == 0)
c->SetZeroAll();
int devIDBackup = 0; int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup); ProtectCudaDev(a->devID, devIDBackup);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论