Commit 70e478c4 by xiaotong

make a default stream for each device

parent a027f72e
......@@ -40,6 +40,7 @@ XDevManager GDevs;
/* constructor */
XDevice::XDevice()
{
stream = NULL;
Clear();
#ifdef USE_CUDA
......@@ -55,6 +56,8 @@ XDevice::~XDevice()
MUTEX_DELE(cublasMutex);
if(isHandleReady)
cublasDestroy(cublasHandle);
if(stream != NULL)
delete stream;
#endif
}
......@@ -118,6 +121,8 @@ void XDevice::Init(int myDevID)
}
else
sprintf(name2, "GPU-%d %s", devID, name);
stream = new XStream(0, devID);
#endif
}
......@@ -161,6 +166,14 @@ cublasHandle_t * XDevice::GetCublasHandle()
return &cublasHandle;
}
/* get the stream of cuda */
cudaStream_t * XDevice::GetCudaStream()
{
CheckNTErrors(stream != NULL, "the stream is not initialized!");
return &stream->stream;
}
#endif // USE_CUDA
/* switch to a device */
......@@ -311,11 +324,19 @@ void XDevManager::Clear()
/* get the handle of GPU */
cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
{
CheckNTErrors((devID < nGPU), "index of GPU is out of range.");
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCublasHandle();
}
/* get the stream of cuda */
cudaStream_t * XDevManager::GetCudaStream(const int devID)
{
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCudaStream();
}
#endif
/*
......
......@@ -25,6 +25,7 @@
#define __XDEVICE_H__
#include "XThread.h"
#include "XStream.h"
#ifdef USE_CUDA
......@@ -93,6 +94,9 @@ public:
/* specify whether Unified Virtual Address Space (UVA) is supported */
bool isUVASupported;
/* default stream for the device */
XStream * stream;
#ifdef USE_CUDA
/* mutex for handle (GPU cublas) */
MUTEX_HANDLE cublasMutex;
......@@ -121,6 +125,9 @@ public:
#ifdef USE_CUDA
/* get cublas handle */
cublasHandle_t * GetCublasHandle();
/* get the stream of cuda */
cudaStream_t * GetCudaStream();
#endif
/* switch to a device */
......@@ -178,6 +185,9 @@ public:
#ifdef USE_CUDA
/* get the handle of GPU */
cublasHandle_t * GetCudaHandle(const int devID);
/* get the stream of cuda */
cudaStream_t * GetCudaStream(const int devID);
#endif
/* get grid and block sizes that max potential */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论