Commit b69e10f6 by xiaotong

updates

parent b2be85f1
@@ -306,6 +306,7 @@ run the neural network
 */
 bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
 {
+    //fprintf(stderr, "run simple 0\n");
     CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
@@ -338,6 +339,8 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
     net.Backward(*loss);

     delete[] dims;

+    //fprintf(stderr, "run simple 1\n");
+
     return true;
 }
...
@@ -119,6 +119,12 @@ float XLeader::GetLoss()
 {
     return serverRecord.lossAll;
 }

+/* get sample number */
+int XLeader::GetSampleNum()
+{
+    return serverRecord.sampleNum;
+}
+
 /* get prediction number */
 int XLeader::GetPredictNum()
...
@@ -111,6 +111,9 @@ public:
     /* get loss */
     float GetLoss();

+    /* get sample number */
+    int GetSampleNum();
+
     /* get prediction number */
     int GetPredictNum();
@@ -143,4 +146,4 @@ public:
 }

-#endif // __XLEADER_H__
\ No newline at end of file
+#endif // __XLEADER_H__
@@ -45,6 +45,7 @@ XNNRecord::~XNNRecord()
 void XNNRecord::Clear()
 {
     lossAll = 0;
+    sampleNum = 0;
     predictNum = 0;
     state = XWORKER_UNSTARTED;
 }
@@ -53,6 +54,7 @@ void XNNRecord::Clear()
 void XNNRecord::Update(XNNRecord & record)
 {
     lossAll += record.lossAll;
+    sampleNum += record.sampleNum;
     predictNum += record.predictNum;
 }
...
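For context on how these fields travel: each record is cleared at the start of a step, and the leader folds worker records into its server-side record with Update(), so sampleNum aggregates across workers exactly as lossAll and predictNum already did. Below is a minimal standalone sketch of that pattern; MockRecord and the per-worker numbers are assumptions for illustration, not part of the commit.

```cpp
// Hedged sketch of the Clear()/Update() aggregation pattern in this commit;
// MockRecord mirrors only the XNNRecord fields visible in the diff.
#include <cstdio>

struct MockRecord {
    float lossAll;    /* loss over all samples */
    int   sampleNum;  /* sample number (the field this commit adds) */
    int   predictNum; /* prediction number */

    void Clear() { lossAll = 0; sampleNum = 0; predictNum = 0; }

    /* fold one worker's record into this aggregate, as XNNRecord::Update does */
    void Update(const MockRecord & r) {
        lossAll    += r.lossAll;
        sampleNum  += r.sampleNum;
        predictNum += r.predictNum;
    }
};

int main() {
    MockRecord server;
    server.Clear();

    MockRecord worker0 = { 12.5f, 32, 32 };  // assumed per-worker counts
    MockRecord worker1 = {  9.0f, 32, 32 };

    server.Update(worker0);
    server.Update(worker1);

    printf("loss:%f samples:%d\n", server.lossAll / server.sampleNum, server.sampleNum);
    return 0;
}
```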
@@ -39,6 +39,9 @@ class XNNRecord
 public:
     /* loss over all samples */
     float lossAll;

+    /* sample number */
+    int sampleNum;
+
     /* prediction number */
     int predictNum;
@@ -61,4 +64,4 @@ public:
 };
 }

-#endif
\ No newline at end of file
+#endif
@@ -103,9 +103,6 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
     int * ids = new int[MAX_DEVICE_NUM_TRAINING];
     GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);

-    float lossAll = 0;
-    int predictNum = 0;
-
     /* create the server and workers */
     XLeader leader;
     leader.Init();
@@ -127,11 +124,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
             /* one step of update */
             ok = leader.Run(config, dataDistributor, model, optimizer);

-            float loss = leader.GetLoss() / leader.GetPredictNum();
+            float loss = leader.GetLoss() / leader.GetSampleNum();

-            if ((step + 1) % 100 == 0)
-                fprintf(stderr, "epoch:%d step:%d loss:%f predict:%d\n",
-                        epoch + 1, step + 1, loss, leader.GetPredictNum());
+            if ((step + 1) % 1 == 0)
+                fprintf(stderr, "epoch:%d step:%d sample:%d loss:%f predict:%d\n",
+                        epoch + 1, step + 1, leader.GetSampleNum(), loss, leader.GetPredictNum());

             if (step++ >= nstep)
                 break;
...
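The substantive change in XTrainer is the normalizer: the reported loss is now the total loss divided by the number of samples rather than by the number of predictions (and the `% 1 == 0` condition makes the log line print every step instead of every 100, likely for debugging). A standalone sketch of the difference follows; all numbers are made up, and the exact semantics of predictNum depend on code outside this diff.

```cpp
// Illustrates the effect of the new normalizer; all values are assumed.
#include <cstdio>

int main() {
    float lossAll    = 44.8f; // summed loss over one step
    int   sampleNum  = 64;    // elements across the step's loss tensors
    int   predictNum = 2;     // prediction count; semantics set by elided code

    printf("old report: %f (lossAll / predictNum)\n", lossAll / predictNum);
    printf("new report: %f (lossAll / sampleNum)\n",  lossAll / sampleNum);
    return 0;
}
```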
@@ -144,4 +144,4 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
     return true;
 }

-}
\ No newline at end of file
+}
@@ -134,13 +134,16 @@ XNNRecord * XWorkerJob::GetRecord()
 void XWorkerJob::RecordMe()
 {
     float lossAll = 0;
+    int sampleNum = 0;

     for (int i = 0; i < losses.count; i++) {
         XTensor* loss = (XTensor*)losses[i];
         lossAll += ReduceSumAllValue(*loss);
+        sampleNum += loss->GetSize();
     }

     record.lossAll = lossAll;
+    record.sampleNum = sampleNum;

     int predictNum = 0;
@@ -157,6 +160,12 @@ float XWorkerJob::GetLossAll()
 {
     return record.lossAll;
 }

+/* get the number of samples */
+int XWorkerJob::GetSampleNum()
+{
+    return record.sampleNum;
+}
+
 /* get the number of outputs (predictions) */
 int XWorkerJob::GetPredictNum()
...
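RecordMe derives the sample count from the loss tensors themselves: loss->GetSize() reports a tensor's element count, so summing it over the job's losses counts one sample per loss value (this assumes each loss tensor holds one entry per sample, which keeps the averaging in XTrainer consistent). A minimal standalone sketch of that counting loop, with std::vector standing in for XTensor:

```cpp
// Hedged sketch of the RecordMe counting pattern; std::vector<float> stands in
// for an XTensor, and vector::size() for XTensor::GetSize().
#include <cstdio>
#include <vector>

int main() {
    // two per-batch loss "tensors", one value per sample (assumed shapes)
    std::vector<float> loss0 = { 0.7f, 0.3f, 0.9f };
    std::vector<float> loss1 = { 0.5f, 0.4f };

    float lossAll = 0;
    int sampleNum = 0;
    for (const std::vector<float> * loss : { &loss0, &loss1 }) {
        for (float v : *loss) lossAll += v;  // mirrors ReduceSumAllValue(*loss)
        sampleNum += (int)loss->size();      // mirrors loss->GetSize()
    }

    printf("lossAll:%f sampleNum:%d avg:%f\n",
           lossAll, sampleNum, lossAll / sampleNum);
    return 0;
}
```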
@@ -99,6 +99,9 @@ public:
     /* get the sum of losses over samples */
     float GetLossAll();

+    /* get the number of samples */
+    int GetSampleNum();
+
     /* get the number of outputs (predictions) */
     int GetPredictNum();
...