Commit b69e10f6 by xiaotong

updates

parent b2be85f1
......@@ -306,6 +306,7 @@ run the neural network
*/
bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
{
//fprintf(stderr, "run simple 0\n");
CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
......@@ -338,6 +339,8 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* l
net.Backward(*loss);
delete[] dims;
//fprintf(stderr, "run simple 1\n");
return true;
}
......
......@@ -119,6 +119,12 @@ float XLeader::GetLoss()
{
return serverRecord.lossAll;
}
/* get sample number */
int XLeader::GetSampleNum()
{
return serverRecord.sampleNum;
}
/* get prediction number */
int XLeader::GetPredictNum()
......
......@@ -111,6 +111,9 @@ public:
/* get loss */
float GetLoss();
/* get sample number */
int GetSampleNum();
/* get prediction number */
int GetPredictNum();
......@@ -143,4 +146,4 @@ public:
}
#endif // __XLEADER_H__
\ No newline at end of file
#endif // __XLEADER_H__
......@@ -45,6 +45,7 @@ XNNRecord::~XNNRecord()
void XNNRecord::Clear()
{
lossAll = 0;
sampleNum = 0;
predictNum = 0;
state = XWORKER_UNSTARTED;
}
......@@ -53,6 +54,7 @@ void XNNRecord::Clear()
void XNNRecord::Update(XNNRecord & record)
{
lossAll += record.lossAll;
sampleNum += record.sampleNum;
predictNum += record.predictNum;
}
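
The two hunks above give sampleNum the same clear-and-merge treatment that lossAll and predictNum already had: each worker fills a record locally, and the server record sums them via Update(). A minimal, self-contained sketch of that pattern (Record is a simplified stand-in for the real XNNRecord class, and the worker values are invented for illustration):

#include <cstdio>

/* simplified stand-in for XNNRecord: loss/sample/prediction counters */
struct Record {
    float lossAll;   /* loss summed over all samples */
    int sampleNum;   /* number of samples seen */
    int predictNum;  /* number of predictions made */

    void Clear() { lossAll = 0; sampleNum = 0; predictNum = 0; }

    /* merge a worker's record into this (server-side) record */
    void Update(const Record & r) {
        lossAll += r.lossAll;
        sampleNum += r.sampleNum;
        predictNum += r.predictNum;
    }
};

int main()
{
    Record server;
    server.Clear();

    /* pretend two workers reported their local records */
    Record w1{ 12.5f, 8, 8 };
    Record w2{ 7.5f, 4, 4 };
    server.Update(w1);
    server.Update(w2);

    /* average loss per sample, as XTrainer::Run now computes it */
    printf("avg loss = %f over %d samples\n",
           server.lossAll / server.sampleNum, server.sampleNum);
    return 0;
}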
......
......@@ -39,6 +39,9 @@ class XNNRecord
public:
/* loss over all samples */
float lossAll;
/* sample number */
int sampleNum;
/* prediction number */
int predictNum;
......@@ -61,4 +64,4 @@ public:
};
}
#endif
\ No newline at end of file
#endif
......@@ -103,9 +103,6 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int * ids = new int[MAX_DEVICE_NUM_TRAINING];
GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);
float lossAll = 0;
int predictNum = 0;
/* create the server and workers */
XLeader leader;
leader.Init();
......@@ -127,11 +124,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* one step of update */
ok = leader.Run(config, dataDistributor, model, optimizer);
float loss = leader.GetLoss() / leader.GetPredictNum();
float loss = leader.GetLoss() / leader.GetSampleNum();
if ((step + 1) % 100 == 0)
fprintf(stderr, "epoch:%d step:%d loss:%f predict:%d\n",
epoch + 1, step + 1, loss, leader.GetPredictNum());
if ((step + 1) % 1 == 0)
fprintf(stderr, "epoch:%d step:%d sample:%d loss:%f predict:%d\n",
epoch + 1, step + 1, leader.GetSampleNum(), loss, leader.GetPredictNum());
if (step++ >= nstep)
break;
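
Note the change in the reported loss above: the normalizer switches from GetPredictNum() to GetSampleNum(). For a model that emits several predictions per sample (e.g., one per token of a sequence), the two averages differ by the predictions-per-sample factor. A hedged sketch, with invented numbers:

#include <cstdio>

int main()
{
    float lossAll = 96.0f; /* loss summed over everything predicted */
    int sampleNum = 8;     /* e.g., 8 sequences in the batch */
    int predictNum = 32;   /* e.g., 4 predictions per sequence */

    /* old report: average loss per prediction */
    printf("per-prediction loss: %f\n", lossAll / predictNum); /* 3.0 */

    /* new report: average loss per sample */
    printf("per-sample loss:     %f\n", lossAll / sampleNum);  /* 12.0 */
    return 0;
}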
......
......@@ -144,4 +144,4 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
return true;
}
}
\ No newline at end of file
}
......@@ -134,13 +134,16 @@ XNNRecord * XWorkerJob::GetRecord()
void XWorkerJob::RecordMe()
{
float lossAll = 0;
int sampleNum = 0;
for (int i = 0; i < losses.count; i++) {
XTensor* loss = (XTensor*)losses[i];
lossAll += ReduceSumAllValue(*loss);
sampleNum += loss->GetSize();
}
record.lossAll = lossAll;
record.sampleNum = sampleNum;
int predictNum = 0;
......@@ -157,6 +160,12 @@ float XWorkerJob::GetLossAll()
{
return record.lossAll;
}
/* get the number of samples */
int XWorkerJob::GetSampleNum()
{
return record.sampleNum;
}
/* get the number of outputs (predictions) */
int XWorkerJob::GetPredictNum()
......
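
RecordMe() above is where sampleNum originates: each loss tensor contributes its element count, on the assumption that the model produces one loss value per sample. A self-contained sketch of that bookkeeping (Tensor and ReduceSumAll are simplified stand-ins for XTensor and ReduceSumAllValue, which are not shown in this diff):

#include <cstdio>
#include <vector>

/* simplified stand-in for XTensor, just enough to mirror RecordMe() */
struct Tensor {
    std::vector<float> data;
    int GetSize() const { return (int)data.size(); } /* element count */
    float ReduceSumAll() const {
        float s = 0;
        for (float v : data)
            s += v;
        return s;
    }
};

int main()
{
    /* two loss tensors, e.g., from two batches handled by this job */
    std::vector<Tensor> losses = { { {0.5f, 1.5f, 1.0f} }, { {2.0f, 1.0f} } };

    float lossAll = 0;
    int sampleNum = 0;
    for (const Tensor & loss : losses) {
        lossAll += loss.ReduceSumAll(); /* mirrors ReduceSumAllValue(*loss) */
        sampleNum += loss.GetSize();    /* one loss entry per sample */
    }

    printf("lossAll = %f  sampleNum = %d  mean = %f\n",
           lossAll, sampleNum, lossAll / sampleNum);
    return 0;
}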
......@@ -99,6 +99,9 @@ public:
/* get the sum of losses over samples */
float GetLossAll();
/* get the number of samples */
int GetSampleNum();
/* get the number of outputs (predictions) */
int GetPredictNum();
......