Commit e9d68683 by xiaotong

bug fixes

parent fbb4331c
...@@ -42,7 +42,7 @@ int main( int argc, const char ** argv ) ...@@ -42,7 +42,7 @@ int main( int argc, const char ** argv )
if (argc > 1 && !strcmp(argv[1], "-test")) if (argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
else if (argc > 1 && !strcmp(argv[1], "-testtrain")) else if (argc > 1 && !strcmp(argv[1], "-testtrain"))
TestTrain(argc - 1, argv + 1); TestTrain();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm")) else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1); FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t")) else if(argc > 1 && !strcmp(argv[1], "-t2t"))
......
...@@ -68,12 +68,12 @@ void GeneateTTrainData(const char * fileName) ...@@ -68,12 +68,12 @@ void GeneateTTrainData(const char * fileName)
} }
/* run the test */ /* run the test */
void TestTrain(int argc, const char ** argv) void TestTrain()
{ {
GeneateTTrainData("ttrain.txt"); GeneateTTrainData("ttrain.txt");
XConfig config; XConfig config;
config.Create(argc, argv); config.Add("dev", -1);
TTDataLoader loader; TTDataLoader loader;
loader.SetFileName("ttrain.txt"); loader.SetFileName("ttrain.txt");
...@@ -141,35 +141,19 @@ bool TTDataLoader::End() ...@@ -141,35 +141,19 @@ bool TTDataLoader::End()
return true; return true;
} }
/* get a batch of samples */ /*
bool TTDataLoader::GetBatch(XList * args) get a batch of samples
>> inputs - inputs of the model
>> golds - gold standards
*/
bool TTDataLoader::GetBatchSimple(XList * inputs, XList * golds)
{ {
CheckNTErrors(file != NULL, "No input file specificed!"); CheckNTErrors(file != NULL, "No input file specificed!");
CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong argument!");
CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong argument!");
XTensor * input = NULL; XTensor * input = (XTensor*)inputs->GetItem(0);
XTensor * gold = NULL; XTensor * gold = (XTensor*)golds->GetItem(0);
XTensor * output = NULL;
if (args->count == 0) {
input = new XTensor();
args->Add(input);
}
else
input = (XTensor*)args->GetItem(0);
if (args->count == 1) {
output = new XTensor();
args->Add(output);
}
if (args->count == 2) {
gold = new XTensor();
args->Add(gold);
}
else
gold = (XTensor*)args->GetItem(1);
int count = 0; int count = 0;
int sampleSize = MAX_SAMPLE_SIZE; int sampleSize = MAX_SAMPLE_SIZE;
...@@ -249,9 +233,16 @@ void TTModel::Forward(int devID, XTensor * input, XTensor * output) ...@@ -249,9 +233,16 @@ void TTModel::Forward(int devID, XTensor * input, XTensor * output)
XTensor embeddingCat; XTensor embeddingCat;
XTensor hidden; XTensor hidden;
/* [e_0, e_1, e_2] = w_e * input(one-hot) */
embedding = Gather(embeddingW, *input); embedding = Gather(embeddingW, *input);
/* e = merge(e_0, e_1, e_2) */
embeddingCat = Merge(embedding, 0, 1); embeddingCat = Merge(embedding, 0, 1);
/* h = e * w_h */
hidden = MMul(embeddingCat, hiddenW); hidden = MMul(embeddingCat, hiddenW);
/* output = Softmax(h) */
*output = Softmax(hidden, 0); *output = Softmax(hidden, 0);
} }
...@@ -271,14 +262,21 @@ XModel * TTModel::Clone(int devID) ...@@ -271,14 +262,21 @@ XModel * TTModel::Clone(int devID)
return model; return model;
} }
/* run the neural network */ /*
bool TTModel::RunMe(XList * args) run the neural network
>> inputs - inputs of the model
>> outputs - outputs of the model
>> golds - gold standards
*/
bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
{ {
CheckNTErrors(args != NULL && args->count >= 3, "Illegal input arguments!"); CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
XTensor * input = (XTensor*)args->GetItem(0); XTensor * input = (XTensor*)inputs->GetItem(0);
XTensor * output = (XTensor*)args->GetItem(1); XTensor * output = (XTensor*)outputs->GetItem(0);
XTensor * gold = (XTensor*)args->GetItem(2); XTensor * gold = (XTensor*)golds->GetItem(0);
XTensor loss; XTensor loss;
XNet net; XNet net;
......
...@@ -57,7 +57,7 @@ void GeneateTTrainData(const char * fileName); ...@@ -57,7 +57,7 @@ void GeneateTTrainData(const char * fileName);
/* run the test */ /* run the test */
extern extern
void TestTrain(int argc, const char ** argv); void TestTrain();
/* data loader */ /* data loader */
class TTDataLoader : public DataDistributeBase class TTDataLoader : public DataDistributeBase
...@@ -92,7 +92,7 @@ public: ...@@ -92,7 +92,7 @@ public:
bool End(); bool End();
/* get a batch of samples */ /* get a batch of samples */
bool GetBatch(XList * args); bool GetBatchSimple(XList * inputs, XList * golds);
}; };
/* the model */ /* the model */
...@@ -134,7 +134,7 @@ public: ...@@ -134,7 +134,7 @@ public:
XModel * Clone(int devID); XModel * Clone(int devID);
/* run the neural network */ /* run the neural network */
bool RunMe(XList * args); bool RunSimple(XList * inputs, XList * outputs, XList * golds);
}; };
/* */ /* */
......
...@@ -60,11 +60,29 @@ bool DataDistributeBase::End() ...@@ -60,11 +60,29 @@ bool DataDistributeBase::End()
return true; return true;
} }
/*
default batch fetcher of the base class (it fetches nothing)
>> inputs - list meant to receive the model inputs (not touched here)
>> golds - list meant to receive the gold standards (not touched here)
<< return - always false in this base-class version
*/
bool DataDistributeBase::GetBatchSimple(XList * inputs, XList * golds)
{
    /* neither list is read or modified by the default version */
    (void)inputs;
    (void)golds;

    const bool fetched = false;
    return fetched;
}
/* get a batch of samples */ /* get a batch of samples */
/*
NOTE(review): side-by-side diff rendering - on lines that hold two statements the
left one is the old code and the right one is the new code.
new version: reads two XList pointers from args (item 0 and item 1) and passes
them to GetBatchSimple as the input list and the gold-standard list
>> args - the argument list; the new version requires at least 2 items
<< return - true if GetBatchSimple returns true; otherwise ShowNTErrors is
            called and false is returned
*/
bool DataDistributeBase::GetBatch(XList * args) bool DataDistributeBase::GetBatch(XList * args)
{ {
/* NOTE(review): the new check reads args->count without testing args against NULL first;
   the new checks in TTModel::RunSimple (same diff) do test for NULL before reading count */
ShowNTErrors("DataDistributeBase::GetBatch must be overloaded!"); CheckNTErrors(args->count >= 2, "More input arguments are required!");
return true;
XList * input = (XList*)args->GetItem(0);
XList * gold = (XList*)args->GetItem(1);
if (GetBatchSimple(input, gold))
return true;
/* reached whenever GetBatchSimple returns false, which the base-class version always does.
   NOTE(review): presumably an overriding loader that returns false for another reason
   (e.g. no more data) would also end up here - confirm this is intended */
ShowNTErrors("You must be overload one of these: DataDistributeBase::GetBatchSimple ... !");
return false;
} }
/* get a batch of samples (for multi-threading) */ /* get a batch of samples (for multi-threading) */
......
...@@ -69,9 +69,13 @@ public: ...@@ -69,9 +69,13 @@ public:
/* get a batch of samples */ /* get a batch of samples */
virtual virtual
bool GetBatchSimple(XList * inputs, XList * golds);
public:
/* get a batch of samples */
bool GetBatch(XList * args); bool GetBatch(XList * args);
protected:
/* get a batch of samples (for multi-threading) */ /* get a batch of samples (for multi-threading) */
bool GetBatchSafe(XList * args); bool GetBatchSafe(XList * args);
}; };
......
...@@ -216,13 +216,13 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -216,13 +216,13 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * jmodel = worker->GetModel(); XModel * jmodel = worker->GetModel();
/* get a batch of samples */ /* get a batch of samples */
bool fetched = dataDistributor->GetBatch(worker->GetInput()); bool fetched = dataDistributor->GetBatchSimple(worker->GetInput(), worker->GetGold());
/* job in queue 1: refresh the model */ /* job in queue 1: refresh the model */
worker->AddJobRefresh(jmodel); worker->AddJobRefresh(jmodel);
/* job in queue 1: run the model */ /* job in queue 1: run the model */
worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput()); worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput(), worker->GetGold());
/* clear it */ /* clear it */
worker->Clear(); worker->Clear();
......
...@@ -67,12 +67,31 @@ XModel * XModel::Clone(int devID) ...@@ -67,12 +67,31 @@ XModel * XModel::Clone(int devID)
/* /*
run the neural network run the neural network
>> inputs - inputs of the model
>> outputs - outputs of the model
*/
bool XModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
{
    /* default version: none of the three lists (inputs, outputs, gold
       standards) is read or modified, and the result is always false */
    (void)inputs;
    (void)outputs;
    (void)golds;

    const bool done = false;
    return done;
}
/*
run the neural network
>> args - the arguments >> args - the arguments
*/ */
/*
NOTE(review): side-by-side diff rendering - on lines that hold two statements the
left one is the old code and the right one is the new code.
new version: reads three XList pointers from args (items 0, 1 and 2) and passes
them to RunSimple as the input, output and gold-standard lists
>> args - the argument list; the new version requires at least 3 items
<< return - true if RunSimple returns true; otherwise ShowNTErrors is called
            and false is returned
*/
bool XModel::RunMe(XList * args) bool XModel::RunMe(XList * args)
{ {
/* NOTE(review): the new check reads args->count without testing args against NULL first */
ShowNTErrors("NetBase::Run must be overloaded!"); CheckNTErrors(args->count >= 3, "More arguments are required!");
return true;
XList * inputs = (XList*)args->GetItem(0);
XList * outputs = (XList*)args->GetItem(1);
XList * golds = (XList*)args->GetItem(2);
if (RunSimple(inputs, outputs, golds))
return true;
/* reached whenever RunSimple returns false, which the XModel version always does */
ShowNTErrors("You must be overload one of these: XModel::RunSimple ... !");
return false;
} }
/* refresh the model */ /* refresh the model */
...@@ -103,8 +122,12 @@ bool XModel::Run(XList * args) ...@@ -103,8 +122,12 @@ bool XModel::Run(XList * args)
{ {
CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh"); CheckNTErrors(args != NULL || args->count == 0, "no arguments for XModel::Refresh");
XModel * model = (XModel*)args->GetItem(0); XModel * model = (XModel*)args->GetItem(0);
XList newArgs;
for (int i = 1; i < args->count; i++)
newArgs.Add(args->GetItem(i));
return model->Run(args); return model->RunMe(&newArgs);
} }
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -80,8 +80,12 @@ public: ...@@ -80,8 +80,12 @@ public:
virtual virtual
XModel * Clone(int devID); XModel * Clone(int devID);
/* run the neural network (would be overloaded) */ /* run the neural network */
virtual virtual
bool RunSimple(XList * inputs, XList * outputs, XList * golds);
protected:
/* run the neural network */
bool RunMe(XList * args); bool RunMe(XList * args);
public: public:
......
...@@ -84,7 +84,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid) ...@@ -84,7 +84,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
{ {
/* the delta rule /* the delta rule
\theta_new = \theta_old - \grad * \lrate */ \theta_new = \theta_old - \grad * \lrate */
Sum(param, grad, param, -lrate); Sum(*param, *grad, *param, -lrate);
} }
} }
...@@ -33,7 +33,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -33,7 +33,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */ /* constructor */
/*
NOTE(review): side-by-side diff rendering - the call to Clear() below exists only
in the new code. According to the Clear() hunk of this same diff, Clear() deletes
the tensors held in inputs, outputs and golds, empties the three lists and adds
one new XTensor to each of them.
*/
XWorkerJob::XWorkerJob() XWorkerJob::XWorkerJob()
{ {
Clear();
} }
/* de-constructor */ /* de-constructor */
...@@ -44,6 +44,9 @@ XWorkerJob::~XWorkerJob() ...@@ -44,6 +44,9 @@ XWorkerJob::~XWorkerJob()
for (int i = 0; i < outputs.count; i++) for (int i = 0; i < outputs.count; i++)
delete (XTensor*)outputs[i]; delete (XTensor*)outputs[i];
for (int i = 0; i < golds.count; i++)
delete (XTensor*)golds[i];
} }
/* set the model */ /* set the model */
...@@ -64,10 +67,17 @@ void XWorkerJob::Clear() ...@@ -64,10 +67,17 @@ void XWorkerJob::Clear()
for (int i = 0; i < inputs.count; i++) for (int i = 0; i < inputs.count; i++)
delete (XTensor*)inputs[i]; delete (XTensor*)inputs[i];
inputs.Clear(); inputs.Clear();
inputs.Add(new XTensor());
for (int i = 0; i < outputs.count; i++) for (int i = 0; i < outputs.count; i++)
delete (XTensor*)outputs[i]; delete (XTensor*)outputs[i];
outputs.Clear(); outputs.Clear();
outputs.Add(new XTensor());
for (int i = 0; i < golds.count; i++)
delete (XTensor*)golds[i];
golds.Clear();
golds.Add(new XTensor());
} }
/* get the input list */ /* get the input list */
...@@ -82,6 +92,12 @@ XList * XWorkerJob::GetOutput() ...@@ -82,6 +92,12 @@ XList * XWorkerJob::GetOutput()
return &outputs; return &outputs;
} }
/*
access the list of gold standards kept by this worker
<< return - address of the member list "golds" (no copy is made)
*/
XList * XWorkerJob::GetGold()
{
    XList * goldList = &golds;
    return goldList;
}
/* /*
add a new job of model refreshment add a new job of model refreshment
>> myModel - the model >> myModel - the model
...@@ -104,9 +120,10 @@ add a new job of neural network forward and backward computation (with the input ...@@ -104,9 +120,10 @@ add a new job of neural network forward and backward computation (with the input
>> myModel - the model >> myModel - the model
>> inputs - inputs of the neural network >> inputs - inputs of the neural network
>> outputs - outputs of the neural network >> outputs - outputs of the neural network
>> golds - gold standards
<< return - succeeded or not << return - succeeded or not
*/ */
bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs) bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds)
{ {
CheckNTErrors(myModel != NULL, "no input neural network!"); CheckNTErrors(myModel != NULL, "no input neural network!");
CheckNTErrors(inputs != NULL, "no inputs of the model!"); CheckNTErrors(inputs != NULL, "no inputs of the model!");
...@@ -114,8 +131,9 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu ...@@ -114,8 +131,9 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outpu
XList args; XList args;
args.Add(myModel); args.Add(myModel);
args.AddList(inputs); args.Add(inputs);
args.AddList(outputs); args.Add(outputs);
args.Add(golds);
queue.EnqueueJob((void*)(char*)XModel::Run, &args); queue.EnqueueJob((void*)(char*)XModel::Run, &args);
......
...@@ -50,7 +50,7 @@ protected: ...@@ -50,7 +50,7 @@ protected:
XList outputs; XList outputs;
/* the gold standard */ /* the gold standard */
XList gold; XList golds;
public: public:
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
bool AddJobRefresh(XModel * myModel); bool AddJobRefresh(XModel * myModel);
/* add a new job of neural network forward and backward computation (with the input) */ /* add a new job of neural network forward and backward computation (with the input) */
bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs); bool AddJobNeuralNet(XModel * myModel, XList * inputs, XList * outputs, XList * golds);
}; };
} }
......
...@@ -101,7 +101,7 @@ wrapper of UpdateModel ...@@ -101,7 +101,7 @@ wrapper of UpdateModel
*/ */
void XWorkerUpdate::Update(XList * args) void XWorkerUpdate::Update(XList * args)
{ {
CheckNTErrors(args != NULL && args->count > 3, "Illegal argument list!"); CheckNTErrors(args != NULL && args->count >= 3, "Illegal argument list!");
XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0); XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0);
XModel * model = (XModel*)args->GetItem(1); XModel * model = (XModel*)args->GetItem(1);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论