Commit feb29d20 by xiaotong

updates of SetDevice

parent 2b03a447
...@@ -268,6 +268,28 @@ int XDevice::GetGPUDevice() ...@@ -268,6 +268,28 @@ int XDevice::GetGPUDevice()
#endif #endif
} }
/*
swith to a device (CPU or GPU)
>> devID - device id
*/
void XDevice::SetDevice(int devID)
{
if(devID >= 0)
SetGPUDevice(devID);
}
/*
swith to a device (CPU or GPU) with a backup of the device id
>> devID - device id
>> backupDevID - backup of the device id
*/
void XDevice::SetDevice(int devID, int &backupDevID)
{
backupDevID = GetGPUDevice();
if (devID >= 0)
SetGPUDevice(devID);
}
/* reset cuda flag for more efficient cuda execution. It should be called after "SetGPUDevice" when /* reset cuda flag for more efficient cuda execution. It should be called after "SetGPUDevice" when
no GPU context has been established. */ no GPU context has been established. */
void XDevice::SetFastFlags() void XDevice::SetFastFlags()
......
...@@ -138,7 +138,7 @@ public: ...@@ -138,7 +138,7 @@ public:
cublasHandle_t * GetCublasHandle(); cublasHandle_t * GetCublasHandle();
#endif #endif
/* switch to a device */ /* switch to a GPU device */
static static
void SetGPUDevice(int devID); void SetGPUDevice(int devID);
...@@ -146,10 +146,18 @@ public: ...@@ -146,10 +146,18 @@ public:
static static
void SetGPUDeviceFast(int devID); void SetGPUDeviceFast(int devID);
/* switch to a get current dev */ /* get current dev */
static static
int GetGPUDevice(); int GetGPUDevice();
/* swith to a device (CPU or GPU) */
static
void SetDevice(int devID);
/* swith to a device (CPU or GPU) with a backup of the device id */
static
void SetDevice(int devID, int &backupDevID);
/* reset cuda flag for more efficient cuda execution */ /* reset cuda flag for more efficient cuda execution */
static static
void SetFastFlags(); void SetFastFlags();
......
...@@ -170,7 +170,7 @@ void XQueue::RunJobConsumer(int jobDevID) ...@@ -170,7 +170,7 @@ void XQueue::RunJobConsumer(int jobDevID)
isJobQueue = true; isJobQueue = true;
jobDequeuerArgs->Clear(); jobDequeuerArgs->Clear();
// warning: this may cause unknown error /* warning: this may cause unknown errors */
jobDequeuerArgs->Add(this); jobDequeuerArgs->Add(this);
jobDequeuerArgs->Add(jobDevID >= 0 ? (devids + jobDevID) : &cpuid); jobDequeuerArgs->Add(jobDevID >= 0 ? (devids + jobDevID) : &cpuid);
...@@ -214,10 +214,8 @@ void XQueue::DequeueJobs(XList * args) ...@@ -214,10 +214,8 @@ void XQueue::DequeueJobs(XList * args)
XQueue * q = (XQueue*)args->GetItem(0); XQueue * q = (XQueue*)args->GetItem(0);
int devID = *(int*)args->GetItem(1); int devID = *(int*)args->GetItem(1);
int devIDBackup = XDevice::GetGPUDevice(); int devIDBackup = -1;
XDevice::SetDevice(devID, devIDBackup);
if(devID >= 0)
XDevice::SetGPUDevice(devID);
while(1){ while(1){
JobQueueNode * node = (JobQueueNode*)q->Dequeue(); JobQueueNode * node = (JobQueueNode*)q->Dequeue();
...@@ -238,8 +236,7 @@ void XQueue::DequeueJobs(XList * args) ...@@ -238,8 +236,7 @@ void XQueue::DequeueJobs(XList * args)
} }
if(devID >= 0) XDevice::SetDevice(devIDBackup);
XDevice::SetGPUDevice(devIDBackup);
} }
/* get the break flag */ /* get the break flag */
......
...@@ -129,7 +129,7 @@ public: ...@@ -129,7 +129,7 @@ public:
void WaitForEmptyJobQueue(); void WaitForEmptyJobQueue();
/* run the job consumer */ /* run the job consumer */
void RunJobConsumer(int jobDevID = 0); void RunJobConsumer(int jobDevID = -1);
/* stop the job consumer */ /* stop the job consumer */
void StopJobConsumer(); void StopJobConsumer();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论