Commit 70e478c4 by xiaotong

make a default stream for each device

parent a027f72e
...@@ -40,6 +40,7 @@ XDevManager GDevs; ...@@ -40,6 +40,7 @@ XDevManager GDevs;
/* constructor */ /* constructor */
XDevice::XDevice() XDevice::XDevice()
{ {
stream = NULL;
Clear(); Clear();
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -55,6 +56,8 @@ XDevice::~XDevice() ...@@ -55,6 +56,8 @@ XDevice::~XDevice()
MUTEX_DELE(cublasMutex); MUTEX_DELE(cublasMutex);
if(isHandleReady) if(isHandleReady)
cublasDestroy(cublasHandle); cublasDestroy(cublasHandle);
if(stream != NULL)
delete stream;
#endif #endif
} }
...@@ -118,6 +121,8 @@ void XDevice::Init(int myDevID) ...@@ -118,6 +121,8 @@ void XDevice::Init(int myDevID)
} }
else else
sprintf(name2, "GPU-%d %s", devID, name); sprintf(name2, "GPU-%d %s", devID, name);
stream = new XStream(0, devID);
#endif #endif
} }
...@@ -161,6 +166,14 @@ cublasHandle_t * XDevice::GetCublasHandle() ...@@ -161,6 +166,14 @@ cublasHandle_t * XDevice::GetCublasHandle()
return &cublasHandle; return &cublasHandle;
} }
/*
get the cuda stream of the device, i.e., the address of the
"stream" member of the XStream object kept by this device
*/
cudaStream_t * XDevice::GetCudaStream()
{
    CheckNTErrors(stream != NULL, "the stream is not initialized!");

    XStream * deviceStream = stream;

    return &(deviceStream->stream);
}
#endif // USE_CUDA #endif // USE_CUDA
/* switch to a device */ /* switch to a device */
...@@ -311,11 +324,19 @@ void XDevManager::Clear() ...@@ -311,11 +324,19 @@ void XDevManager::Clear()
/* get the handle of GPU */ /* get the handle of GPU */
cublasHandle_t * XDevManager::GetCudaHandle(const int devID) cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
{ {
CheckNTErrors((devID < nGPU), "index of GPU is out of range."); CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCublasHandle(); return GPUs[devID].GetCublasHandle();
} }
/*
get the cuda stream of a GPU
devID  - index of the device in the GPU list (GPUs)
return - the pointer given by GetCudaStream() of that device
*/
cudaStream_t * XDevManager::GetCudaStream(const int devID)
{
    /* a negative index would read before the start of GPUs[],
       so both bounds are checked */
    CheckNTErrors((devID >= 0 && devID < nGPU), "index of GPU is out of range.");

    return GPUs[devID].GetCudaStream();
}
#endif #endif
/* /*
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define __XDEVICE_H__ #define __XDEVICE_H__
#include "XThread.h" #include "XThread.h"
#include "XStream.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -93,6 +94,9 @@ public: ...@@ -93,6 +94,9 @@ public:
/* specify whether Unified Virtual Address Space (UVA) is supported */ /* specify whether Unified Virtual Address Space (UVA) is supported */
bool isUVASupported; bool isUVASupported;
/* default stream for the device */
XStream * stream;
#ifdef USE_CUDA #ifdef USE_CUDA
/* mutex for handle (GPU cublas) */ /* mutex for handle (GPU cublas) */
MUTEX_HANDLE cublasMutex; MUTEX_HANDLE cublasMutex;
...@@ -121,6 +125,9 @@ public: ...@@ -121,6 +125,9 @@ public:
#ifdef USE_CUDA #ifdef USE_CUDA
/* get cublas handle */ /* get cublas handle */
cublasHandle_t * GetCublasHandle(); cublasHandle_t * GetCublasHandle();
/* get the stream of cuda */
cudaStream_t * GetCudaStream();
#endif #endif
/* switch to a device */ /* switch to a device */
...@@ -178,6 +185,9 @@ public: ...@@ -178,6 +185,9 @@ public:
#ifdef USE_CUDA #ifdef USE_CUDA
/* get the handle of GPU */ /* get the handle of GPU */
cublasHandle_t * GetCudaHandle(const int devID); cublasHandle_t * GetCudaHandle(const int devID);
/* get the stream of cuda */
cudaStream_t * GetCudaStream(const int devID);
#endif #endif
/* get grid and block sizes that max potential */ /* get grid and block sizes that max potential */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论