Commit 7c53210c by xiaotong

add counting-job methods in XWorker

parent fbd915c6
...@@ -38,12 +38,16 @@ XWorker::XWorker() ...@@ -38,12 +38,16 @@ XWorker::XWorker()
id = -1; id = -1;
state = XWORKER_UNSTARTED; state = XWORKER_UNSTARTED;
isInstantRun = false; isInstantRun = false;
jobCountFinished = 0;
jobCountExpected = -1;
MUTEX_INIT(countMutex);
} }
/* de-constructor */ /* de-constructor */
XWorker::~XWorker() XWorker::~XWorker()
{ {
Stop(); Stop();
MUTEX_DELE(countMutex);
} }
/* set device id */ /* set device id */
...@@ -109,5 +113,55 @@ bool XWorker::IsEmpty() ...@@ -109,5 +113,55 @@ bool XWorker::IsEmpty()
{ {
return queue.IsEmpty(); return queue.IsEmpty();
} }
/*
reset the job counts
>> myJobCountExpected - expected count
>> myJobCountFinished - current count
*/
void XWorker::ResetJobCount(int myJobCountExpected, int myJobCountFinished)
{
jobCountExpected = myJobCountExpected;
jobCountFinished = myJobCountFinished;
}
/* lock the count mutex */
void XWorker::LockCount()
{
MUTEX_LOCK(countMutex);
}
/* count a finished job */
void XWorker::CountFinishedJob()
{
jobCountFinished++;
if(jobCountExpected < 0 || jobCountFinished > jobCountExpected){
MUTEX_UNLOCK(countMutex);
}
else{
ShowNTErrors("More jobs than expected!");
}
}
/* wrapper of CountFinishedJob() */
void XWorker::CountFinished(XList * args)
{
XWorker * worker = (XWorker*)args->GetItem(0);
worker->CountFinishedJob();
}
/* add a new job of counting finished jobs */
void XWorker::AddJobCountFinished()
{
XList args;
args.Add(this);
if (isInstantRun)
XWorker::CountFinished(&args);
else
queue.EnqueueJob((void*)(char*)XWorker::CountFinished, &args);
}
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -63,6 +63,15 @@ protected: ...@@ -63,6 +63,15 @@ protected:
/* fire the flag of instant run */ /* fire the flag of instant run */
bool isInstantRun; bool isInstantRun;
/* count how many jobs have been finished */
int jobCountFinished;
/* the expected number of jobs */
int jobCountExpected;
/* mutex for accessing the job counts */
MUTEX_HANDLE countMutex;
public: public:
/* constructor */ /* constructor */
...@@ -100,6 +109,24 @@ public: ...@@ -100,6 +109,24 @@ public:
/* whether the job queue is empty? */ /* whether the job queue is empty? */
bool IsEmpty(); bool IsEmpty();
/* reset the job counts */
void ResetJobCount(int myJobCountExpected = -1, int myJobCountFinished = 0);
/* lock the count mutex */
void LockCount();
/* count a finished job */
void CountFinishedJob();
/* wrapper of CountFinishedJob() */
static
void CountFinished(XList * args);
/* add a new job of counting finished jobs */
void AddJobCountFinished();
}; };
} }
......
...@@ -77,7 +77,7 @@ broadcast data for a model ...@@ -77,7 +77,7 @@ broadcast data for a model
>> targetList - the target places that we recieve the data >> targetList - the target places that we recieve the data
>> sleepTime - the waiting time in broadcasting >> sleepTime - the waiting time in broadcasting
*/ */
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long sleepTime) void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int sleepTime)
{ {
int finished = 0; int finished = 0;
int * finishedFlag = new int[source->paramNum]; int * finishedFlag = new int[source->paramNum];
......
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
void BroadcastDataSingle(XModel * source, XList * targetList, int pid); void BroadcastDataSingle(XModel * source, XList * targetList, int pid);
/* broadcast data for a model */ /* broadcast data for a model */
void BroadcastData(XModel * source, XList * targetList, long sleepTime); void BroadcastData(XModel * source, XList * targetList, int sleepTime);
/* wrapper of BroadcastDataSingle */ /* wrapper of BroadcastDataSingle */
static static
...@@ -96,4 +96,4 @@ public: ...@@ -96,4 +96,4 @@ public:
} }
#endif #endif
\ No newline at end of file
...@@ -65,7 +65,7 @@ to member models back. ...@@ -65,7 +65,7 @@ to member models back.
*/ */
void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server, void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XOptimizer * optimizer, XWorkerUpdate * updater,
XWorkerBroadcast * broadcaster, long sleepTime) XWorkerBroadcast * broadcaster, int sleepTime)
{ {
int finished = 0; int finished = 0;
...@@ -359,7 +359,7 @@ collect the data of the run (i.e., loss). This is a reducer. ...@@ -359,7 +359,7 @@ collect the data of the run (i.e., loss). This is a reducer.
>> target - the record that we keep the reduce result >> target - the record that we keep the reduce result
>> sleepTime - waiting time in collecting data >> sleepTime - waiting time in collecting data
*/ */
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, long sleepTime) void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, int sleepTime)
{ {
int finished = 0; int finished = 0;
int* flags = new int[sourceList->count]; int* flags = new int[sourceList->count];
......
...@@ -72,7 +72,7 @@ public: ...@@ -72,7 +72,7 @@ public:
to member models back. */ to member models back. */
void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server, void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster, XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster,
long sleepTime); int sleepTime);
/* wrapper of UpdateDataAll */ /* wrapper of UpdateDataAll */
static static
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
bool AddJobCollect(XList * sourceList, XModel * target); bool AddJobCollect(XList * sourceList, XModel * target);
/* collect the data of the run (i.e., loss). This is a reducer. */ /* collect the data of the run (i.e., loss). This is a reducer. */
void CollectOtherData(XList * sourceList, XNNRecord * target, long sleepTime); void CollectOtherData(XList * sourceList, XNNRecord * target, int sleepTime);
/* wrapper of CollectOtherData */ /* wrapper of CollectOtherData */
static static
......
...@@ -87,7 +87,7 @@ update the model ...@@ -87,7 +87,7 @@ update the model
>> optimizer - the optimizer >> optimizer - the optimizer
>> sleepTime - waiting time in each update >> sleepTime - waiting time in each update
*/ */
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime) void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime)
{ {
int finished = 0; int finished = 0;
...@@ -220,4 +220,4 @@ bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer) ...@@ -220,4 +220,4 @@ bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
return true; return true;
} }
} }
\ No newline at end of file
...@@ -61,7 +61,7 @@ public: ...@@ -61,7 +61,7 @@ public:
XOptimizer * optimizer, XWorkerBroadcast * broadcaster); XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* update the model */ /* update the model */
void UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime); void UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime);
/* wrapper of UpdateParameter */ /* wrapper of UpdateParameter */
static static
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论