Commit 9eda6d83 by xiaotong

bug fixes and a new class XNNRecord

parent 5f345e87
...@@ -276,17 +276,19 @@ run the neural network ...@@ -276,17 +276,19 @@ run the neural network
>> inputs - inputs of the model >> inputs - inputs of the model
>> outputs - outputs of the model >> outputs - outputs of the model
>> golds - gold standards >> golds - gold standards
>> losses - losses of the output with respect to the gold standards
*/ */
bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds) bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
{ {
CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!"); CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!"); CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!"); CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
CheckNTErrors(losses != NULL && losses->count >= 1, "Wrong arguments!");
XTensor * input = (XTensor*)inputs->GetItem(0); XTensor * input = (XTensor*)inputs->GetItem(0);
XTensor * output = (XTensor*)outputs->GetItem(0); XTensor * output = (XTensor*)outputs->GetItem(0);
XTensor * gold = (XTensor*)golds->GetItem(0); XTensor * gold = (XTensor*)golds->GetItem(0);
XTensor loss; XTensor * loss = (XTensor*)losses->GetItem(0);
XTensor goldOneHot; XTensor goldOneHot;
XNet net; XNet net;
...@@ -301,9 +303,9 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds) ...@@ -301,9 +303,9 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
dims[goldOneHot.order - 2] = goldOneHot.GetDim(goldOneHot.order - 1); dims[goldOneHot.order - 2] = goldOneHot.GetDim(goldOneHot.order - 1);
goldOneHot.Reshape(goldOneHot.order - 1, dims); goldOneHot.Reshape(goldOneHot.order - 1, dims);
loss = CrossEntropy(output, goldOneHot); *loss = CrossEntropy(*output, goldOneHot);
net.Backward(loss); net.Backward(*loss);
delete[] dims; delete[] dims;
......
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
XModel * Clone(int devID); XModel * Clone(int devID);
/* run the neural network */ /* run the neural network */
bool RunSimple(XList * inputs, XList * outputs, XList * golds); bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
}; };
/* */ /* */
......
...@@ -65,6 +65,8 @@ void XLeader::Init() ...@@ -65,6 +65,8 @@ void XLeader::Init()
for (int i = 0; i < bworkers.count; i++) for (int i = 0; i < bworkers.count; i++)
delete (XWorkerBroadcast*)bworkers.GetItem(i); delete (XWorkerBroadcast*)bworkers.GetItem(i);
bworkers.Clear(); bworkers.Clear();
serverRecord.Clear();
} }
/* set id */ /* set id */
...@@ -109,6 +111,18 @@ void XLeader::SetServerModel(XConfig * config, XModel * model) ...@@ -109,6 +111,18 @@ void XLeader::SetServerModel(XConfig * config, XModel * model)
SetServerModel(config, model, &members); SetServerModel(config, model, &members);
} }
/*
get the loss of the run
<< return - the loss summed over all samples, as collected into the
            server-side record (serverRecord) by the reduce step
*/
float XLeader::GetLoss()
{
    return serverRecord.lossAll;
}
/*
get the prediction number of the run
<< return - the number of predictions kept in the server-side record
*/
int XLeader::GetPredictNum()
{
    return serverRecord.predictNum;
}
/* /*
set the communication mode set the communication mode
>> myMode - the mode >> myMode - the mode
...@@ -117,6 +131,7 @@ void XLeader::SetMode(XLEADER_MODE myMode) ...@@ -117,6 +131,7 @@ void XLeader::SetMode(XLEADER_MODE myMode)
{ {
mode = myMode; mode = myMode;
} }
/* start the workers */ /* start the workers */
void XLeader::Start() void XLeader::Start()
{ {
...@@ -195,7 +210,7 @@ void XLeader::AddJobBroadcastWorker() ...@@ -195,7 +210,7 @@ void XLeader::AddJobBroadcastWorker()
} }
/* /*
run the model (for one time) run the model (for one time). Basically this is a map-reduce process.
>> config - the configuration >> config - the configuration
>> dataDistributor - data distributor >> dataDistributor - data distributor
>> model - the neural network that we want to run >> model - the neural network that we want to run
...@@ -207,6 +222,10 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -207,6 +222,10 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
{ {
bool isDataOK = true; bool isDataOK = true;
int activeJobCount = 0; int activeJobCount = 0;
int* active = new int[jworkers.count];
for (int i = 0; i < jworkers.count; i++)
active[i] = 0;
/* Feed the input to each worker and generate the output. /* Feed the input to each worker and generate the output.
For each worker, we define a job queue and enqueue jobs For each worker, we define a job queue and enqueue jobs
...@@ -226,55 +245,78 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -226,55 +245,78 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
worker->AddJobRefresh(jmodel); worker->AddJobRefresh(jmodel);
/* job in queue 1: run the model */ /* job in queue 1: run the model */
worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput(), worker->GetGold()); worker->AddJobNeuralNet(jmodel,
worker->GetInput(), worker->GetOutput(),
worker->GetGold(), worker->GetLoss());
/* job in queue 1: make a record of the run */
worker->AddJobRecord();
active[i] = 1;
activeJobCount++; activeJobCount++;
} }
} }
if (activeJobCount == 0) if (activeJobCount >= 0) {
return false; /* member models that are active in this run */
XList members(jworkers.count);
XList members(jworkers.count);
for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i];
members.Add(worker->GetModel());
}
/* job in queue 2: collect the (gradient) data */ /* all member models */
if (cworkers.count > 0) { XList membersAll(jworkers.count);
XWorkerCollect * collecter = (XWorkerCollect*)cworkers.GetItem(0);
collecter->AddJobCollect(&members, &serverModel);
}
else {
ShowNTErrors("No data-collecting workers!");
}
/* job in queue 3: update the model */ /* records of the active member models */
if (uworkers.count > 0) { XList memberRecords(jworkers.count);
XWorkerUpdate * updater = (XWorkerUpdate*)uworkers.GetItem(0);
updater->AddJobUpdate(&serverModel, optimizer);
}
else {
ShowNTErrors("No model-update workers!");
}
/* job in queue 4: broadcast the lastest parameters to workers */ for (int i = 0; i < jworkers.count; i++) {
if (bworkers.count > 0) { XWorkerJob* worker = (XWorkerJob*)jworkers[i];
XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0); membersAll.Add(worker->GetModel());
broadcaster->AddJobBroadcast(&serverModel, &members); if (active[i] == 1) {
} members.Add(worker->GetModel());
else { memberRecords.Add(worker->GetRecord());
ShowNTErrors("No data-broadcasting workers!"); }
} }
WaitForFinishing(); /* jobs in queue 2: collect the (gradient) data and other stuff. This
is a reduce process. */
if (cworkers.count > 0) {
XWorkerCollect* collecter = (XWorkerCollect*)cworkers.GetItem(0);
collecter->AddJobCollect(&members, &serverModel);
collecter->AddJobCollectOther(&memberRecords, &serverRecord);
}
else {
ShowNTErrors("No data-collecting workers!");
}
/* job in queue 3: update the model */
if (uworkers.count > 0) {
XWorkerUpdate* updater = (XWorkerUpdate*)uworkers.GetItem(0);
updater->AddJobUpdate(&serverModel, optimizer);
}
else {
ShowNTErrors("No model-update workers!");
}
/* job in queue 4: broadcast the latest parameters to workers. NOTE that
we would update a worker to the latest model parameters, even if it is
not involved in this run. */
if (bworkers.count > 0) {
XWorkerBroadcast* broadcaster = (XWorkerBroadcast*)bworkers.GetItem(0);
broadcaster->AddJobBroadcast(&serverModel, &membersAll);
}
else {
ShowNTErrors("No data-broadcasting workers!");
}
WaitForFinishing();
}
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers[i]; XWorkerJob * worker = (XWorkerJob*)jworkers[i];
worker->Clear(); worker->Clear();
} }
delete[] active;
return isDataOK; return isDataOK;
} }
......
...@@ -69,6 +69,9 @@ protected: ...@@ -69,6 +69,9 @@ protected:
/* a model that keeps the parameters (as a server) */ /* a model that keeps the parameters (as a server) */
XModel serverModel; XModel serverModel;
/* a record that keeps the information of the run */
XNNRecord serverRecord;
/* communication mode */ /* communication mode */
XLEADER_MODE mode; XLEADER_MODE mode;
...@@ -106,6 +109,12 @@ public: ...@@ -106,6 +109,12 @@ public:
/* set the server model */ /* set the server model */
void SetServerModel(XConfig * config, XModel * model); void SetServerModel(XConfig * config, XModel * model);
/* get loss */
float GetLoss();
/* get prediction number */
int GetPredictNum();
/* start the workers */ /* start the workers */
void Start(); void Start();
......
...@@ -69,8 +69,10 @@ XModel * XModel::Clone(int devID) ...@@ -69,8 +69,10 @@ XModel * XModel::Clone(int devID)
run the neural network run the neural network
>> inputs - inputs of the model >> inputs - inputs of the model
>> outputs - outputs of the model >> outputs - outputs of the model
>> golds - gold standards
>> losses - losses of the input with respect to the gold standards
*/ */
bool XModel::RunSimple(XList * inputs, XList * outputs, XList * golds) bool XModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses)
{ {
return false; return false;
} }
...@@ -86,8 +88,9 @@ bool XModel::RunMe(XList * args) ...@@ -86,8 +88,9 @@ bool XModel::RunMe(XList * args)
XList * inputs = (XList*)args->GetItem(0); XList * inputs = (XList*)args->GetItem(0);
XList * outputs = (XList*)args->GetItem(1); XList * outputs = (XList*)args->GetItem(1);
XList * golds = (XList*)args->GetItem(2); XList * golds = (XList*)args->GetItem(2);
XList* losses = (XList*)args->GetItem(3);
if (RunSimple(inputs, outputs, golds)) if (RunSimple(inputs, outputs, golds, losses))
return true; return true;
ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !"); ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !");
......
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
/* run the neural network */ /* run the neural network */
virtual virtual
bool RunSimple(XList * inputs, XList * outputs, XList * golds); bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
protected: protected:
/* run the neural network */ /* run the neural network */
......
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* A record that keeps some information in running and training neural networks
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-06
* I will climb mountains with my wife and son this afternoon, hahaha :)
*/
#include "XNNRecord.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor - initializes all counters via Clear() */
XNNRecord::XNNRecord()
{
    Clear();
}
/* de-constructor - nothing to release; the record owns no resources */
XNNRecord::~XNNRecord()
{
}
/* clear it - reset the loss sum, the prediction count and mark the
   record as belonging to an unstarted worker */
void XNNRecord::Clear()
{
    lossAll = 0;
    predictNum = 0;
    state = XWORKER_UNSTARTED;
}
/*
update me with another record (accumulate its counters into mine)
>> record - the record whose loss and prediction count are added in
NOTE(review): `state` is deliberately left untouched here — presumably
the reducer manages states separately; confirm against the collecter.
*/
void XNNRecord::Update(XNNRecord & record)
{
    lossAll += record.lossAll;
    predictNum += record.predictNum;
}
}
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* A record that keeps some information in running and training neural networks
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-06
* I will climb mountains with my wife and son this afternoon, hahaha :)
*/
#ifndef __XNNRECORD_H__
#define __XNNRECORD_H__
#include "XWorker.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* a record of keeping some stuff during training. It is used as a
   per-worker scoreboard and is also the unit that reduce (collect)
   steps merge into a server-side summary. */
class XNNRecord
{
public:
    /* loss summed over all samples seen in the run */
    float lossAll;

    /* number of predictions (output entries) made in the run */
    int predictNum;

    /* state of the worker that owns this record (see XWORKER_STATE) */
    XWORKER_STATE state;

public:
    /* constructor */
    XNNRecord();

    /* de-constructor */
    ~XNNRecord();

    /* clear it (reset counters and state) */
    void Clear();

    /* update me with another record (accumulate its counters) */
    void Update(XNNRecord & record);
};
}
#endif
\ No newline at end of file
...@@ -103,6 +103,9 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -103,6 +103,9 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int * ids = new int[MAX_DEVICE_NUM_TRAINING]; int * ids = new int[MAX_DEVICE_NUM_TRAINING];
GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING); GetDevIDs(config, ids, jobNum, MAX_DEVICE_NUM_TRAINING);
float lossAll = 0;
int predictNum = 0;
/* create the server and workers */ /* create the server and workers */
XLeader leader; XLeader leader;
leader.Init(); leader.Init();
...@@ -124,8 +127,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -124,8 +127,11 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
/* one step of update */ /* one step of update */
ok = leader.Run(config, dataDistributor, model, optimizer); ok = leader.Run(config, dataDistributor, model, optimizer);
float loss = leader.GetLoss() / leader.GetPredictNum();
if ((step + 1) % 100 == 0) if ((step + 1) % 100 == 0)
fprintf(stderr, "epoch:%d step:%d\n", epoch + 1, step + 1); fprintf(stderr, "epoch:%d step:%d loss:%f predict:%d\n",
epoch + 1, step + 1, loss, leader.GetPredictNum());
if (step++ >= nstep) if (step++ >= nstep)
break; break;
......
...@@ -34,6 +34,9 @@ namespace nts { ...@@ -34,6 +34,9 @@ namespace nts {
/* constructor */ /* constructor */
XWorker::XWorker() XWorker::XWorker()
{ {
devID = -1;
id = -1;
state = XWORKER_UNSTARTED;
} }
/* de-constructor */ /* de-constructor */
......
...@@ -35,6 +35,14 @@ ...@@ -35,6 +35,14 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/*
state of a worker
1) unstarted - no job has been enqueued yet
2) started   - a job is enqueued/running
3) finished  - the job is done and its results can be collected
*/
enum XWORKER_STATE { XWORKER_UNSTARTED, XWORKER_STARTED, XWORKER_FINISHED };
/* the worker class */ /* the worker class */
class XWorker class XWorker
{ {
...@@ -49,6 +57,9 @@ protected: ...@@ -49,6 +57,9 @@ protected:
/* the queue */ /* the queue */
XQueue queue; XQueue queue;
/* state of the worker */
XWORKER_STATE state;
public: public:
/* constructor */ /* constructor */
XWorker(); XWorker();
......
...@@ -100,7 +100,7 @@ void XWorkerBroadcast::Broadcast(XList * args) ...@@ -100,7 +100,7 @@ void XWorkerBroadcast::Broadcast(XList * args)
/* target models */ /* target models */
int targetNum = args->GetItemInt(2); int targetNum = args->GetItemInt(2);
XList target(targetNum); XList target;
for (int i = 0; i < targetNum; i++) { for (int i = 0; i < targetNum; i++) {
XModel * model = (XModel*)args->GetItem(3 + i); XModel * model = (XModel*)args->GetItem(3 + i);
target.Add(model); target.Add(model);
......
...@@ -177,7 +177,7 @@ void XWorkerCollect::Collect(XList * args) ...@@ -177,7 +177,7 @@ void XWorkerCollect::Collect(XList * args)
int sourceNum = args->GetItemInt(1); int sourceNum = args->GetItemInt(1);
/* the source models */ /* the source models */
XList source(sourceNum); XList source;
for (int i = 0; i < sourceNum; i++) { for (int i = 0; i < sourceNum; i++) {
XModel * model = (XModel*)args->GetItem(2 + i); XModel * model = (XModel*)args->GetItem(2 + i);
source.Add(model); source.Add(model);
...@@ -257,4 +257,86 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target) ...@@ -257,4 +257,86 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
return true; return true;
} }
/*
collect the data of the run (i.e., loss). This is a reducer.
>> sourceList - the list of record
>> target - the record that we keep the reduce result
>> sleepTime - waiting time in collecting data
*/
/*
collect the data of the run (i.e., loss). This is a reducer: it polls
the source records until each one reports XWORKER_FINISHED, merging a
record into the target exactly once (tracked via the flags array).
>> sourceList - the list of records (XNNRecord*)
>> target - the record that we keep the reduce result
>> sleepTime - waiting time (milliseconds) between polling passes
*/
void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, long sleepTime)
{
    int finished = 0;
    /* flags[i] == 1 once record i has been merged into target */
    int* flags = new int[sourceList->count];

    for (int i = 0; i < sourceList->count; i++)
        flags[i] = 0;

    while (1) {
        for (int i = 0; i < sourceList->count; i++) {
            if (flags[i] != 0)
                continue;
            XNNRecord* source = (XNNRecord*)sourceList->GetItem(i);
            /* a worker marks its record finished in RecordMeStatic */
            if (source->state == XWORKER_FINISHED) {
                /* guard against reducing a record into itself */
                if(target != source)
                    target->Update(*source);
                flags[i] = 1;
                finished++;
            }
        }

        /* stop once every source record has been merged */
        if (finished == sourceList->count)
            break;

#ifdef _WIN32
        Sleep((DWORD)sleepTime);
#else
        /* NOTE(review): sleep() takes seconds, so sleepTime < 1000 ms
           truncates to sleep(0) and this loop busy-waits on POSIX —
           consider usleep()/nanosleep(); confirm intended behavior */
        sleep((unsigned)sleepTime / 1000);
#endif
    }

    delete[] flags;
}
/*
wrapper of CollectOtherData. It unpacks the flat argument list
(collecter, record count, records..., target record) that was built
by AddJobCollectOther and invokes the reducer.
>> args - the packed argument list
*/
void XWorkerCollect::CollectOther(XList* args)
{
    XWorkerCollect* me = (XWorkerCollect*)args->GetItem(0);
    int recordNum = args->GetItemInt(1);

    /* rebuild the list of source records */
    XList records;
    int i = 0;
    while (i < recordNum) {
        records.Add((XNNRecord*)args->GetItem(2 + i));
        i++;
    }

    /* the record that keeps the reduce result (stored after the sources) */
    XNNRecord* result = (XNNRecord*)args->GetItem(2 + recordNum);

    me->CollectOtherData(&records, result, SLEEP_TIME_IN_COLLECTING_OTHER);
}
/*
add a new job of collecting the data of the run (i.e., loss).
The job is enqueued with a flat argument list understood by
CollectOther: this worker, the record count, the records themselves,
and finally the target record. This is a reducer.
>> sourceList - the list of records
>> target - the record that we keep the reduce result
<< return - succeeded or not
*/
bool XWorkerCollect::AddJobCollectOther(XList* sourceList, XNNRecord* target)
{
    CheckNTErrors(sourceList != NULL, "no input source record list!");
    CheckNTErrors(target != NULL, "no input target record!");

    /* pack: collecter, count, sources..., target */
    XList jobArgs;
    jobArgs.Add(this);
    jobArgs.AddInt(sourceList->count);
    jobArgs.AddList(sourceList);
    jobArgs.Add(target);

    queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &jobArgs);

    return true;
}
} }
...@@ -31,10 +31,12 @@ ...@@ -31,10 +31,12 @@
#include "XWorker.h" #include "XWorker.h"
#include "XModel.h" #include "XModel.h"
#include "XWorkerJob.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define SLEEP_TIME_IN_COLLECTING 10 #define SLEEP_TIME_IN_COLLECTING 10
#define SLEEP_TIME_IN_COLLECTING_OTHER 10
/* /*
data collection method data collection method
...@@ -61,7 +63,7 @@ public: ...@@ -61,7 +63,7 @@ public:
/* set the collection type */ /* set the collection type */
void SetCollectMode(DATA_COLLECT_TYPE myMode); void SetCollectMode(DATA_COLLECT_TYPE myMode);
/* collect data */ /* collect the gradient data (i.e., a reducer) */
void CollectData(XList * sourceList, XModel * target, long sleepTime); void CollectData(XList * sourceList, XModel * target, long sleepTime);
/* wrapper of CollectData */ /* wrapper of CollectData */
...@@ -79,6 +81,16 @@ public: ...@@ -79,6 +81,16 @@ public:
/* add a new job of collecting data */ /* add a new job of collecting data */
bool AddJobCollect(XList * sourceList, XModel * target); bool AddJobCollect(XList * sourceList, XModel * target);
/* collect the data of the run (i.e., loss). This is a reducer. */
void CollectOtherData(XList * sourceList, XNNRecord * target, long sleepTime);
/* wrapper of CollectOtherData */
static
void CollectOther(XList * args);
/* add a new job of collecting data of the run (i.e., loss) */
bool AddJobCollectOther(XList * sourceList, XNNRecord * target);
}; };
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "XWorkerJob.h" #include "XWorkerJob.h"
#include "../tensor/XList.h" #include "../tensor/XList.h"
#include "../tensor/core/CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -47,6 +48,9 @@ XWorkerJob::~XWorkerJob() ...@@ -47,6 +48,9 @@ XWorkerJob::~XWorkerJob()
for (int i = 0; i < golds.count; i++) for (int i = 0; i < golds.count; i++)
delete (XTensor*)golds[i]; delete (XTensor*)golds[i];
for (int i = 0; i < losses.count; i++)
delete (XTensor*)losses[i];
} }
/* set the model */ /* set the model */
...@@ -61,6 +65,13 @@ XModel * XWorkerJob::GetModel() ...@@ -61,6 +65,13 @@ XModel * XWorkerJob::GetModel()
return model; return model;
} }
/*
set the state of the worker
>> myState - the new state; it is mirrored into the worker's record so
             that reducers polling the record observe the same state
*/
void XWorkerJob::SetState(XWORKER_STATE myState)
{
    state = myState;
    record.state = myState;
}
/* clear the worker */ /* clear the worker */
void XWorkerJob::Clear() void XWorkerJob::Clear()
{ {
...@@ -78,6 +89,15 @@ void XWorkerJob::Clear() ...@@ -78,6 +89,15 @@ void XWorkerJob::Clear()
delete (XTensor*)golds[i]; delete (XTensor*)golds[i];
golds.Clear(); golds.Clear();
golds.Add(new XTensor()); golds.Add(new XTensor());
for (int i = 0; i < losses.count; i++)
delete (XTensor*)losses[i];
losses.Clear();
losses.Add(new XTensor());
record.Clear();
SetState(XWORKER_UNSTARTED);
} }
/* get the input list */ /* get the input list */
...@@ -98,6 +118,52 @@ XList * XWorkerJob::GetGold() ...@@ -98,6 +118,52 @@ XList * XWorkerJob::GetGold()
return &golds; return &golds;
} }
/*
get the loss
<< return - the list of loss tensors owned by this worker
*/
XList * XWorkerJob::GetLoss()
{
    return &losses;
}
/*
get the record of the run
<< return - the worker's record (loss sum, prediction count, state)
*/
XNNRecord * XWorkerJob::GetRecord()
{
    return &record;
}
/* record some stuff */
void XWorkerJob::RecordMe()
{
float lossAll = 0;
for (int i = 0; i < losses.count; i++) {
XTensor* loss = (XTensor*)losses[i];
lossAll += ReduceSumAllValue(*loss);
}
record.lossAll = lossAll;
int predictNum = 0;
for (int i = 0; i < outputs.count; i++) {
XTensor* output = (XTensor*)outputs[i];
predictNum += output->GetSize();
}
record.predictNum = predictNum;
}
/*
get the sum of losses over samples
<< return - the loss sum kept in the record (filled by RecordMe)
*/
float XWorkerJob::GetLossAll()
{
    return record.lossAll;
}
/*
get the number of outputs (predictions)
<< return - the prediction count kept in the record (filled by RecordMe)
*/
int XWorkerJob::GetPredictNum()
{
    return record.predictNum;
}
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myModel - the model >> myModel - the model
...@@ -121,9 +187,11 @@ add a new job of neural network forward and backward computation (with the input ...@@ -121,9 +187,11 @@ add a new job of neural network forward and backward computation (with the input
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
>> golds - gold standards >> golds - gold standards
>> losses - losses of the outputs respect to the gold standards
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds) bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
XList * inputs, XList * outputs, XList * golds, XList * losses)
{ {
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
CheckNTErrors(inputs != NULL, "no inputs of the model!"); CheckNTErrors(inputs != NULL, "no inputs of the model!");
...@@ -134,11 +202,37 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu ...@@ -134,11 +202,37 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu
args.Add(inputs); args.Add(inputs);
args.Add(outputs); args.Add(outputs);
args.Add(golds); args.Add(golds);
args.Add(losses);
queue.EnqueueJob((void*)(char*)XModel::Run, &args); queue.EnqueueJob((void*)(char*)XModel::Run, &args);
SetState(XWORKER_STARTED);
return true;
}
/*
add a new job of recording the running of the neural network. The job
calls RecordMeStatic, which fills the record and marks this worker
finished.
<< return - succeeded or not
*/
bool XWorkerJob::AddJobRecord()
{
    XList args;
    args.Add(this);

    queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);

    return true;
}
/*
wrapper of RecordMe: records the run, then flips the worker (and its
record) to XWORKER_FINISHED so reducers know the record is ready
>> args - the argument list; item 0 is the worker itself
*/
void XWorkerJob::RecordMeStatic(XList* args)
{
    CheckNTErrors(args != NULL && args->count > 0, "Illegal arguments!");

    XWorkerJob * worker = (XWorkerJob*)args->GetItem(0);

    worker->RecordMe();
    worker->SetState(XWORKER_FINISHED);
}
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "XWorker.h" #include "XWorker.h"
#include "XModel.h" #include "XModel.h"
#include "XNNRecord.h"
#include "XBaseTemplate.h" #include "XBaseTemplate.h"
#include "../tensor/XList.h" #include "../tensor/XList.h"
...@@ -51,6 +52,12 @@ protected: ...@@ -51,6 +52,12 @@ protected:
/* the gold standard */ /* the gold standard */
XList golds; XList golds;
/* the loss */
XList losses;
/* record the information in running the neural network */
XNNRecord record;
public: public:
...@@ -66,6 +73,9 @@ public: ...@@ -66,6 +73,9 @@ public:
/* get the parameter keeper */ /* get the parameter keeper */
XModel * GetModel(); XModel * GetModel();
/* set the state of the worker */
void SetState(XWORKER_STATE myState);
/* clear the worker */ /* clear the worker */
void Clear(); void Clear();
...@@ -78,11 +88,34 @@ public: ...@@ -78,11 +88,34 @@ public:
/* get the gold standard */ /* get the gold standard */
XList * GetGold(); XList * GetGold();
/* get the loss */
XList * GetLoss();
/* get the record of the run */
XNNRecord * GetRecord();
/* record some stuff */
void RecordMe();
/* get the sum of losses over samples */
float GetLossAll();
/* get the number of outputs (predictions) */
int GetPredictNum();
/* add a new job of model refreshment */ /* add a new job of model refreshment */
bool AddJobRefresh(XModel * myModel); bool AddJobRefresh(XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds); bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds, XList * losses);
/* add a new job of recording the running of the neural network */
bool AddJobRecord();
private:
/* wrapper of RecordMe */
static
void RecordMeStatic(XList * args);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论