Commit 7c53210c by xiaotong

add counting-job methods in XWorker

parent fbd915c6
......@@ -38,12 +38,16 @@ XWorker::XWorker()
id = -1;
state = XWORKER_UNSTARTED;
isInstantRun = false;
jobCountFinished = 0;
jobCountExpected = -1;
MUTEX_INIT(countMutex);
}
/* destructor: call Stop() and then delete the count mutex */
XWorker::~XWorker()
{
    this->Stop();
    MUTEX_DELE(countMutex);
}
/* set device id */
......@@ -109,5 +113,55 @@ bool XWorker::IsEmpty()
{
return queue.IsEmpty();
}
/*
reset the job counts
>> myJobCountExpected - expected count
>> myJobCountFinished - current count
*/
void XWorker::ResetJobCount(int myJobCountExpected, int myJobCountFinished)
{
jobCountExpected = myJobCountExpected;
jobCountFinished = myJobCountFinished;
}
/* acquire countMutex (a thin wrapper of MUTEX_LOCK) */
void XWorker::LockCount()
{
    MUTEX_LOCK(countMutex);
}
/*
count a finished job
jobCountFinished is increased by one and then compared with jobCountExpected:
  - jobCountExpected < 0          : countMutex is unlocked on every call
  - finished == expected          : countMutex is unlocked
  - finished >  expected          : an error is reported
  - finished <  expected          : nothing else is done
NOTE(review): the previous condition reported "More jobs than expected!" whenever
finished <= expected and unlocked the mutex only after the expectation had been
exceeded, which contradicts the message. Unlocking exactly when the expected
count is reached is the presumed intent - confirm against the callers of
LockCount().
*/
void XWorker::CountFinishedJob()
{
    jobCountFinished++;

    /* a negative expectation: unlock on every counted job (as before) */
    if (jobCountExpected < 0) {
        MUTEX_UNLOCK(countMutex);
        return;
    }

    if (jobCountFinished > jobCountExpected) {
        ShowNTErrors("More jobs than expected!");
    }
    else if (jobCountFinished == jobCountExpected) {
        /* the last expected job has been counted */
        MUTEX_UNLOCK(countMutex);
    }
}
/*
static wrapper of CountFinishedJob()
>> args - the argument list; its first item is cast to the worker
          on which CountFinishedJob() is called
*/
void XWorker::CountFinished(XList * args)
{
    ((XWorker*)args->GetItem(0))->CountFinishedJob();
}
/* add a new job of counting finished jobs */
void XWorker::AddJobCountFinished()
{
XList args;
args.Add(this);
if (isInstantRun)
XWorker::CountFinished(&args);
else
queue.EnqueueJob((void*)(char*)XWorker::CountFinished, &args);
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -63,6 +63,15 @@ protected:
/* fire the flag of instant run */
bool isInstantRun;
/* the number of jobs that have been counted as finished
(increased by CountFinishedJob() and set by ResetJobCount()) */
int jobCountFinished;
/* the expected number of jobs (set by ResetJobCount(), -1 by default);
a negative value is handled as a special case in CountFinishedJob() */
int jobCountExpected;
/* mutex associated with the job counts
(locked by LockCount() and unlocked in CountFinishedJob()) */
MUTEX_HANDLE countMutex;
public:
/* constructor */
......@@ -100,6 +109,24 @@ public:
/* whether the job queue is empty? */
bool IsEmpty();
/* reset the job counts
>> myJobCountExpected - the new expected count (-1 by default)
>> myJobCountFinished - the new finished count (0 by default) */
void ResetJobCount(int myJobCountExpected = -1, int myJobCountFinished = 0);
/* lock the count mutex (countMutex) */
void LockCount();
/* count a finished job, i.e., increase jobCountFinished by one
and check it against jobCountExpected */
void CountFinishedJob();
/* wrapper of CountFinishedJob()
>> args - the argument list whose first item is the worker */
static
void CountFinished(XList * args);
/* add a new job of counting finished jobs: the job is run immediately
if isInstantRun is set, and is put into the job queue otherwise */
void AddJobCountFinished();
};
}
......
......@@ -77,7 +77,7 @@ broadcast data for a model
>> targetList - the target places where we receive the data
>> sleepTime - the waiting time in broadcasting
*/
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long sleepTime)
void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, int sleepTime)
{
int finished = 0;
int * finishedFlag = new int[source->paramNum];
......
......@@ -64,7 +64,7 @@ public:
void BroadcastDataSingle(XModel * source, XList * targetList, int pid);
/* broadcast data for a model */
void BroadcastData(XModel * source, XList * targetList, long sleepTime);
void BroadcastData(XModel * source, XList * targetList, int sleepTime);
/* wrapper of BroadcastDataSingle */
static
......@@ -96,4 +96,4 @@ public:
}
#endif
\ No newline at end of file
#endif
......@@ -65,7 +65,7 @@ to member models back.
*/
void XWorkerCollect::UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater,
XWorkerBroadcast * broadcaster, long sleepTime)
XWorkerBroadcast * broadcaster, int sleepTime)
{
int finished = 0;
......@@ -359,7 +359,7 @@ collect the data of the run (i.e., loss). This is a reducer.
>> target - the record that we keep the reduce result
>> sleepTime - waiting time in collecting data
*/
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, long sleepTime)
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, int sleepTime)
{
int finished = 0;
int* flags = new int[sourceList->count];
......
......@@ -72,7 +72,7 @@ public:
to member models back. */
void UpdateDataAll(XList * memberActive, XList * memberAll, XModel * server,
XOptimizer * optimizer, XWorkerUpdate * updater, XWorkerBroadcast * broadcaster,
long sleepTime);
int sleepTime);
/* wrapper of UpdateDataAll */
static
......@@ -95,7 +95,7 @@ public:
bool AddJobCollect(XList * sourceList, XModel * target);
/* collect the data of the run (i.e., loss). This is a reducer. */
void CollectOtherData(XList * sourceList, XNNRecord * target, long sleepTime);
void CollectOtherData(XList * sourceList, XNNRecord * target, int sleepTime);
/* wrapper of CollectOtherData */
static
......
......@@ -87,7 +87,7 @@ update the model
>> optimizer - the optimizer
>> sleepTime - waiting time in each update
*/
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime)
void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime)
{
int finished = 0;
......@@ -220,4 +220,4 @@ bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
return true;
}
}
\ No newline at end of file
}
......@@ -61,7 +61,7 @@ public:
XOptimizer * optimizer, XWorkerBroadcast * broadcaster);
/* update the model */
void UpdateModel(XModel * model, XOptimizer * optimizer, long sleepTime);
void UpdateModel(XModel * model, XOptimizer * optimizer, int sleepTime);
/* wrapper of UpdateParameter */
static
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论