Commit 052a62b5 by xiaotong

bug fixes in XQueue and XThread

parent 78307f09
......@@ -176,8 +176,9 @@ void XQueue::RunJobConsumer(int jobDevID)
jobDequeuer.SetFunc((TFunction)DequeueJobs, jobDequeuerArgs);
jobDequeuer.Start();
jobDequeuer.LetItGo();
//jobDequeuer.Start();
//jobDequeuer.LetItGo();
jobDequeuer.StartNow();
}
/* stop the job consumer */
......@@ -257,4 +258,24 @@ int XQueue::GetJobNum()
return c;
}
/*
get the number of items in the queue. Note that
this function is not the same as GetJobNum() because
"items" are the real elements we put into the queue.
"jobs" only make sense when the queue is running as a
job queue.
*/
int XQueue::GetItemNum()
{
MUTEX_LOCK(enqueueMutex);
MUTEX_LOCK(dequeueMutex);
int c = itemCount;
MUTEX_UNLOCK(dequeueMutex);
MUTEX_UNLOCK(enqueueMutex);
return c;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -144,8 +144,15 @@ public:
/* get the break flag */
bool GetJobBreak();
/* get the number of jobs */
/* get the number of running jobs */
int GetJobNum();
/* get the number of items in the queue. Note that
this function is not the same as GetJobNum() because
"items" are the real elements we put into the queue.
"jobs" only make sense when the queue is running as a
job queue. */
int GetItemNum();
};
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -225,6 +225,26 @@ void XThread::LetItGo()
#endif
}
/*
create the thread and run it immediately (a combination of
Start() and LetItGo() */
bool XThread::StartNow()
{
CheckNTErrors(jobCount == 0, "Cannot start a thread again when it is running!");
jobCount++;
Start();
#ifdef _WIN32
MUTEX_LOCK(workingMutex);
COND_RESET(jobCond);
MUTEX_UNLOCK(workingMutex);
COND_SIGNAL(jobCond);
#endif
return true;
}
/* waith for a singal */
void XThread::Wait(COND_HANDLE * c, MUTEX_HANDLE * m)
{
......
......@@ -143,6 +143,10 @@ public:
/* let the thread process a job */
void LetItGo();
/* create the thread and run it immediately (a combination of
Start() and LetItGo() */
bool StartNow();
/* waith for a singal */
static
void Wait(COND_HANDLE * c, MUTEX_HANDLE * m);
......
......@@ -182,25 +182,30 @@ void XLeader::WaitForFinishing(const int* activeJobWorkers, const int isToUpdate
XWorker* worker = (XWorker*)jworkers[i];
worker->DequeueFinishedJob();
activeCount++;
CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
}
}
if (activeCount > 0 && isToUpdate) {
for (int i = 0; i < cworkers.count; i++) {
XWorker* worker = (XWorker*)cworkers[i];
for(int j = 0; j < serverModel.paramNum * activeCount; j++)
worker->DequeueFinishedJob();
CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
}
for (int i = 0; i < uworkers.count; i++) {
XWorker* worker = (XWorker*)uworkers[i];
for (int j = 0; j < serverModel.paramNum; j++)
worker->DequeueFinishedJob();
CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
}
for (int i = 0; i < bworkers.count; i++) {
XWorker* worker = (XWorker*)bworkers[i];
for (int j = 0; j < serverModel.paramNum; j++)
worker->DequeueFinishedJob();
CheckNTErrors(worker->GetFinishedNumInQueue() == 0, "Incorrect job number!");
}
}
}
......@@ -373,7 +378,6 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, XOptim
CheckNTErrors(bworkers.count > 0, "No bworkers!");
CheckNTErrors(pworkers.count > 0, "No pworkers!");
bool isDataOK = true;
bool isToUpdate = (optimizer != NULL);
int activeJobCount = 0;
int* active = new int[jworkers.count];
......@@ -526,9 +530,11 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* sp[j]->isGradFinished is true only if the model finishes the computation
(in another process) */
if (paramSource.flag == PARAM_STATE_NOT_READY && paramSource.param->isGradFinished) {
XQueue* jobQueue = (XQueue*)jobQueues.GetItem(j);
/* data transmit */
CollectP2P(paramSource.param->grad, paramServer.param->grad);
collecter->AddJobCollectDataP2P(jobQueue, paramSource.param->grad, paramServer.param->grad);
collecter->AddJobEnqueueFinished();
/* reset the flag */
paramSource.flag = PARAM_STATE_COLLECTED;
......@@ -538,21 +544,20 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
/* we call model update (in another thread) and then
broadcast the new parameters to member models
(in another thread) */
if (finishedCount[j] == memberActive->count) {
if (finishedCount[j] == members.count) {
paramServer.flag = PARAM_STATE_COLLECTED;
if (updater != NULL) {
XQueue* jobQueue = (XQueue*)jobQueues->GetItem(j);
/* update the parameters */
updater->AddJobUpdate(jobQueue, server, j, optimizer);
updater->AddJobUpdate(jobQueue, &serverModel, j, optimizer);
updater->AddJobEnqueueFinished(jobQueue);
/* broadcast the new parameter to other models*/
broadcaster->AddJobBroadcastSingle(jobQueue, server, memberAll, j);
broadcaster->AddJobBroadcastSingle(jobQueue, &serverModel, &membersAll, j);
broadcaster->AddJobEnqueueFinished(jobQueue);
}
}
else if (finishedCount[j] > memberActive->count) {
else if (finishedCount[j] > members.count) {
ShowNTErrors("Something is wrong with finishedCount!");
}
}
......@@ -560,10 +565,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
}
/* the collection finishes if all data tensors are processed */
if (finished == server->paramNum * memberActive->count)
if (finished == serverModel.paramNum * members.count)
break;
XSleep(sleepTime);
XSleep(SLEEP_TIME_IN_WAITING_JOB_WORKERS);
}
delete[] finishedCount;
......@@ -576,10 +581,10 @@ void XLeader::RunUpdate(XConfig * config, XOptimizer * optimizer, const int * ac
broadcast the lastest parameters to workers. NOTE that we would update
a worker to the laster model parameters, even if it is not involved
in this run. */
collecter->AddJobUpdateAll(&jobQueues,
&members, &membersAll, &serverModel,
optimizer, updater, broadcaster);
collecter->AddJobEnqueueFinished();
//collecter->AddJobUpdateAll(&jobQueues,
// &members, &membersAll, &serverModel,
// optimizer, updater, broadcaster);
//collecter->AddJobEnqueueFinished();
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -50,6 +50,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MAX_NUM_OF_WORKERS 1024
#define SLEEP_TIME_IN_WAITING_FOR_JOBS 20
#define SLEEP_TIME_IN_WAITING_JOB_WORKERS 5
/*
conmmunication mode of a leader. This offers a way of organizing a hierachy of the work
......
......@@ -112,7 +112,7 @@ bool XModel::RunMe(XList * args)
if (RunSimple(inputs, outputs, golds, losses))
return true;
ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !");
ShowNTErrors("You must overload one of these: XModel::RunSimple ... !");
return false;
}
......
......@@ -177,4 +177,10 @@ void XWorker::AddJobDequeueFinished(XQueue* jobQueue)
}
/* get number of unflaged finished job */
int XWorker::GetFinishedNumInQueue()
{
return finishedQueue.GetItemNum();
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -126,6 +126,9 @@ public:
/* add a job of dequeuing a counting a finished job */
void AddJobDequeueFinished(XQueue* jobQueue = NULL);
/* get number of unflaged finished job */
int GetFinishedNumInQueue();
};
}
......
......@@ -202,6 +202,53 @@ void XWorkerCollect::UpdateAll(XList * args)
}
/*
add a new job of collecting data, update the parameter and
broadcast the new parameter
>> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
models
<< return - successful or not
*/
bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer,
XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
{
CheckNTErrors(memberActive != NULL, "No input (active) member list!");
CheckNTErrors(memberAll != NULL, "No input (all) member list!");
CheckNTErrors(server != NULL, "No input server model!");
CheckNTErrors(optimizer != NULL, "No input optimizer!");
CheckNTErrors(updater != NULL, "No input updater!");
CheckNTErrors(broadcaster != NULL, "No input broadcaster!");
XList args;
args.Add(this);
args.AddInt(jobQueues->count);
args.AddList(jobQueues);
args.AddInt(memberActive->count);
args.AddList(memberActive);
args.AddInt(memberAll->count);
args.AddList(memberAll);
args.Add(server);
args.Add(optimizer);
args.Add(updater);
args.Add(broadcaster);
if (isInstantRun)
XWorkerCollect::UpdateAll(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
return true;
}
/*
P2P data collection
target += source
......@@ -259,49 +306,41 @@ void XWorkerCollect::CollectAllReduce(XList * all)
ShowNTErrors("TODO!");
}
/* wrapper of Collect */
void XWorkerCollect::CollectDataP2P(XList * args)
{
int paramCount = 0;
XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(paramCount++);
XTensor * source = (XTensor*)args->GetItem(paramCount++);
XTensor * target = (XTensor*)args->GetItem(paramCount++);
if(collecter != NULL)
collecter->CollectP2P(source, target);
}
/*
add a new job of collecting data, update the parameter and
broadcast the new parameter
>> jobQueues - the queues that we would use in following jobs
>> memberActive - member models that are active, i.e., have generated gradients
>> memberAll - all member models
>> server - the server model
>> optimizer - the optimizer
>> updater - the worker that updates the parameters
>> broadcaster - the worker that broadcasts the new parameters to all member
models
<< return - successful or not
add a new job of collecting data
>> jobQueue - the queue where we run the job
>> source - where we collect the data from
>> target - where we place the data (on the server end)
*/
bool XWorkerCollect::AddJobUpdateAll(XList * jobQueues,
XList * memberActive, XList * memberAll,
XModel * server,
XOptimizer * optimizer,
XWorkerUpdate * updater, XWorkerBroadcast * broadcaster)
bool XWorkerCollect::AddJobCollectDataP2P(XQueue * jobQueue, XTensor * source, XTensor * target)
{
CheckNTErrors(memberActive != NULL, "No input (active) member list!");
CheckNTErrors(memberAll != NULL, "No input (all) member list!");
CheckNTErrors(server != NULL, "No input server model!");
CheckNTErrors(optimizer != NULL, "No input optimizer!");
CheckNTErrors(updater != NULL, "No input updater!");
CheckNTErrors(broadcaster != NULL, "No input broadcaster!");
CheckNTErrors(source != NULL, "No input soure tensor!");
CheckNTErrors(target != NULL, "No input target tensor!");
XList args;
args.Add(this);
args.AddInt(jobQueues->count);
args.AddList(jobQueues);
args.AddInt(memberActive->count);
args.AddList(memberActive);
args.AddInt(memberAll->count);
args.AddList(memberAll);
args.Add(server);
args.Add(optimizer);
args.Add(updater);
args.Add(broadcaster);
args.Add(source);
args.Add(target);
XQueue& queueRun = jobQueue != NULL ? *jobQueue : queue;
if (isInstantRun)
XWorkerCollect::UpdateAll(&args);
XWorkerCollect::CollectDataP2P(&args);
else
queue.EnqueueJob((void*)(char*)XWorkerCollect::UpdateAll, &args);
queueRun.EnqueueJob((void*)(char*)XWorkerCollect::CollectDataP2P, &args);
return true;
}
......
......@@ -78,6 +78,10 @@ public:
static
void UpdateAll(XList * args);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
/* P2P data collection */
void CollectP2P(XTensor * source, XTensor * target);
......@@ -87,9 +91,14 @@ public:
/* all-reduce */
void CollectAllReduce(XList * all);
/* add a new job of collecting data, update the parameter and broadcast the new parameter */
bool AddJobUpdateAll(XList * jobQueues, XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster);
/* wrapper of Collect */
static
void CollectDataP2P(XList * args);
/* add a new job of collecting data */
bool AddJobCollectDataP2P(XQueue * jobQueue, XTensor * source, XTensor * target);
};
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论