Commit db6c4449 by xiaotong

updates

parent 2736ad16
......@@ -28,6 +28,7 @@
#include "Utility.h"
#include "../../tensor/XGlobal.h"
#include "../../tensor/XConfig.h"
using namespace nts;
using namespace std;
......@@ -157,90 +158,6 @@ int Config::LoadFromFile(const char* configFN, char** args) {
return argsNum;
}
void LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
strcpy(p, argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
strcpy(p, defaultP);
}
void LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*(int*)p = atoi(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname)) {
*(bool*)p = true;
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*p = (float)atof(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void ShowParams(int argc, char** argv)
{
fprintf(stderr, "args:\n");
for (int i = 0; i < argc; i++) {
if (argv[i][1] == 0)
continue;
if (argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')) {
if (i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
#define MAX_WORD_NUM 120
/*
......
......@@ -33,17 +33,6 @@ using namespace nts;
namespace nmt
{
#define MAX_PARAM_NUM 100
/* load arguments */
void LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP);
void LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP);
void LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP);
void LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP);
/* show arguments */
void ShowParams(int argc, char** argv);
/* split string */
IntList SplitInt(const string& s, const string& delimiter);
FloatList SplitFloat(const string& s, const string& delimiter);
......
......@@ -29,4 +29,334 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XConfig::XConfig()
{
n = 0;
args = NULL;
nReal = 0;
}
/* de-constructor */
XConfig::~XConfig()
{
for (int i = 0; i < n; i++) {
delete[] args[i];
}
delete[] args;
}
/*
create a config
>> myN - number of the input arguments
>> myArgs - the input arguments
*/
void XConfig::Create(const int myN, const char ** myArgs)
{
CheckNTErrors(myN > 0, "No input parameters to XConfig!");
for (int i = 0; i < n; i++) {
delete[] args[i];
}
delete[] args;
args = NULL;
n = myN;
nReal = n * 2;
args = new char*[nReal];
for (int i = 0; i < nReal; i++) {
args[i] = NULL;
}
for (int i = 0; i < n; i++) {
CheckNTErrors(myArgs[i] != NULL, "Illegal parameter input!");
args[i] = new char[strlen(myArgs[i]) + 1];
strcpy(args[i], myArgs[i]);
}
}
/*
add an argument
>> myArg - the argument
>> myValue - the value of the argument
*/
void XConfig::Add(const char * myArg, const char * myValue)
{
CheckNTErrors(myArg != NULL, "No argument!");
if (n + 2 > nReal) {
nReal = MAX(n * 2 + 1, 128);
char ** newArgs = new char*[nReal];
memset(newArgs, 0, sizeof(char*) * n);
memcpy(newArgs, args, sizeof(char*) * n);
delete[] args;
args = newArgs;
}
args[n] = new char[strlen(myArg) + 2];
args[n][0] = '-';
strcpy(args[n] + 1, myArg);
n++;
if (myValue != NULL) {
args[n] = new char[strlen(myValue) + 1];
strcpy(args[n], myValue);
n++;
}
}
/*
add an argument (in integer)
>> myArg - the argument
>> myValue - the value of the argument
*/
void XConfig::Add(const char * myArg, int myValue)
{
char value[MAX_WORD_LENGTH_IN_CONFIG];
sprintf(value, "%d", myValue);
Add(myArg, value);
}
/*
add an argument (in bool)
>> myArg - the argument
>> myValue - the value of the argument
*/
void XConfig::Add(const char * myArg, bool myValue)
{
char value[2];
if (myValue)
value[0] = '1';
else
value[0] = '0';
value[1] = 0;
Add(myArg, value);
}
/*
add an argument (in float)
>> myArg - the argument
>> myValue - the value of the argument
*/
void XConfig::Add(const char * myArg, float myValue)
{
char value[MAX_WORD_LENGTH_IN_CONFIG];
sprintf(value, "%f", myValue);
Add(myArg, value);
}
/*
load the value of an argument (in integer)
>> name - the name of the argument
>> p - where we place the loaded value
>> defaultP - the default value (used only if no argument is hit in the list)
*/
void XConfig::LoadInt(const char * name, int * p, int defaultP)
{
LoadParamInt(n, args, name, p, defaultP);
}
/*
load the value of an argument (in boolean)
>> name - the name of the argument
>> p - where we place the loaded value
>> defaultP - the default value (used only if no argument is hit in the list)
*/
void XConfig::LoadBool(const char * name, bool * p, bool defaultP)
{
LoadParamBool(n, args, name, p, defaultP);
}
/*
load the value of an argument (in float)
>> name - the name of the argument
>> p - where we place the loaded value
>> defaultP - the default value (used only if no argument is hit in the list)
*/void XConfig::LoadFloat(const char * name, float * p, float defaultP)
{
LoadParamFloat(n, args, name, p, defaultP);
}
/*
load the value of an argument (in char string)
>> name - the name of the argument
>> p - where we place the loaded value
>> defaultP - the default value (used only if no argument is hit in the list)
*/
void XConfig::LoadString(const char * name, char * p, const char* defaultP)
{
LoadParamString(n, args, name, p, defaultP);
}
/*
get the value of an argument (in integer)
>> name - the name of the argument
>> defaultP - the default value (used only if no argument is hit in the list)
*/
int XConfig::GetInt(const char * name, int defaultP)
{
int r;
LoadInt(name, &r, defaultP);
return r;
}
/*
get the value of an argument (in bool)
>> name - the name of the argument
>> defaultP - the default value (used only if no argument is hit in the list)
*/
bool XConfig::GetBool(const char * name, bool defaultP)
{
bool r;
LoadBool(name, &r, defaultP);
return r;
}
/*
get the value of an argument (in float)
>> name - the name of the argument
>> defaultP - the default value (used only if no argument is hit in the list)
*/
float XConfig::GetFloat(const char * name, float defaultP)
{
float r;
LoadFloat(name, &r, defaultP);
return r;
}
/*
load the value of an argument (in integer)
>> argc - number of arguments
>> argv - arguments
>> name - the argument we search for
>> p - the pointer to the target variable where we want to place the value
>> defaultP - the default value we use if no argument is found
*/
void LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*(int*)p = atoi(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
/*
load the value of an argument (in boolean)
>> argc - number of arguments
>> argv - arguments
>> name - the argument we search for
>> p - the pointer to the target variable where we want to place the value
>> defaultP - the default value we use if no argument is found
*/
void LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname)) {
*(bool*)p = true;
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
/*
load the value of an argument (in float)
>> argc - number of arguments
>> argv - arguments
>> name - the argument we search for
>> p - the pointer to the target variable where we want to place the value
>> defaultP - the default value we use if no argument is found
*/
void LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*p = (float)atof(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
/*
load the value of an argument (in char string)
>> argc - number of arguments
>> argv - arguments
>> name - the argument we search for
>> p - the pointer to the target variable where we want to place the value
>> defaultP - the default value we use if no argument is found
*/
void LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
strcpy(p, argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
strcpy(p, defaultP);
}
/*
show the argument list
>> argc - number of arguments
>> argv - arguments
*/
void ShowParams(int argc, char** argv)
{
fprintf(stderr, "args:\n");
for (int i = 0; i < argc; i++) {
if (argv[i][1] == 0)
continue;
if (argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')) {
if (i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -34,6 +34,77 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MAX_WORD_LENGTH_IN_CONFIG 256
/* the parameter keeper */
class XConfig
{
private:
/* number of arguments */
int n;
/* argument list (in char*) */
char ** args;
/* number of items we rellocate for these arguments */
int nReal;
public:
/* constructor */
XConfig();
/* de-constructor */
~XConfig();
/* create a config */
void Create(const int myN, const char ** myArgs);
/* add an argument */
void Add(const char * myArg, const char * myValue);
/* add an argument (in integer) */
void Add(const char * myArg, int myValue);
/* add an argument (in bool) */
void Add(const char * myArg, bool myValue);
/* add an argument (in float) */
void Add(const char * myArg, float myValue);
/* load the value of an argument to a variable (in integer) */
void LoadInt(const char * name, int * p, int defaultP);
/* load the value of an argument to a variable (in boolean) */
void LoadBool(const char * name, bool * p, bool defaultP);
/* load the value of an argument to a variable (in float) */
void LoadFloat(const char * name, float * p, float defaultP);
/* load the value of an argument to a variable (in char string) */
void LoadString(const char * name, char * p, const char* defaultP);
/* get the value of an argument (in integer) */
int GetInt(const char * name, int defaultP);
/* get the value of an argument (in boolean) */
bool GetBool(const char * name, bool defaultP);
/* get the value of an argument (in float) */
float GetFloat(const char * name, float defaultP);
};
#define MAX_PARAM_NUM 100
/* load arguments */
void extern LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP);
void extern LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP);
void extern LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP);
void extern LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP);
/* show arguments */
void extern ShowParams(int argc, char** argv);
} // namespace nts(NiuTrans.Tensor)
#endif
\ No newline at end of file
......@@ -40,7 +40,6 @@ namespace nts {
XLeader::XLeader()
{
id = -1;
dataDistributor = NULL;
}
/* de-constructor */
......@@ -69,17 +68,4 @@ void XLeader::SetMode(XLEADER_MODE myMode)
mode = myMode;
}
/* set the data loader */
void XLeader::SetDataDistributor(DataDistributeBase* myDataDistributor)
{
CheckNTErrors(myDataDistributor != NULL,
"The input of XLeader::SetDataLoader should not be NULL!");
dataDistributor = myDataDistributor;
}
/* run the leader (this is the core process) */
void XLeader::Run()
{
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -35,7 +35,9 @@
#ifndef __XLEADER_H__
#define __XLEADER_H__
#include "XModel.h"
#include "XNetTemplate.h"
#include "../tensor/XConfig.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -58,10 +60,6 @@ protected:
/* communication mode */
XLEADER_MODE mode;
/* data loader */
DataDistributeBase* dataDistributor;
public:
/* constructor */
XLeader();
......@@ -78,13 +76,8 @@ public:
/* set the communication mode */
void SetMode(XLEADER_MODE myMode);
/* set the data loader */
void SetDataDistributor(DataDistributeBase * myDataLoader);
/* run the leader (this is the core process) */
virtual
void Run();
/* run the model (for one time) */
void Run(XConfig * config, DataDistributeBase * dataDistributor, XModel * modelParams, NetBase * net);
};
}
......
......@@ -47,14 +47,14 @@ DataDistributeBase::~DataDistributeBase()
}
/* * start the job (e.g., open the file) */
bool DataDistributeBase::Start(XList * args)
bool DataDistributeBase::Start()
{
ShowNTErrors("DataDistributeBase::Start must be overloaded!");
return true;
}
/* end the job (e.g., close the file) */
bool DataDistributeBase::End(XList * args)
bool DataDistributeBase::End()
{
ShowNTErrors("DataDistributeBase::End must be overloaded!");
return true;
......
......@@ -57,13 +57,15 @@ public:
/* de-constructor */
~DataDistributeBase();
/* start the job (e.g., open the file) */
/* start the job (e.g., open the file).
NOTE THAT before calling Start() one should initialize
the distributor if neccessary */
virtual
bool Start(XList * args);
bool Start();
/* end the job (e.g., close the file) */
virtual
bool End(XList * args);
bool End();
/* get a batch of samples */
virtual
......
......@@ -40,4 +40,30 @@ XTrainer::~XTrainer()
{
}
/*
run the trainer (this is the core process)
>> config - configuration
>> dataDistributor - the data distributor that generates an input for the net each time
>> modelParams - the parameter keeper
>> net - the neural network
*/
void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
XModel * modelParams, NetBase * net)
{
CheckNTErrors(config != NULL, "No input config!");
CheckNTErrors(dataDistributor != NULL, "No input data distributor!");
CheckNTErrors(modelParams != NULL, "No input model parameter keeper!");
CheckNTErrors(net != NULL, "No input neural network!");
int nepoch = config->GetInt("nepoch", 50);
int nstep = config->GetInt("nstep", 100000);
int epoch = 0;
int step = 0;
for (int epoch = 0; epoch < nepoch; epoch++) {
if (step++ >= nstep)
break;
}
}
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -31,9 +31,10 @@
#ifndef __XTRAINER_H__
#define __XTRAINER_H__
#include "XLeader.h"
#include "../network/XNet.h"
#include "../tensor/XQueue.h"
#include "XWorker.h"
#include "../tensor/XConfig.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -66,6 +67,10 @@ public:
/* de-constructor */
~XTrainer();
/* run the leader (this is the core process) */
virtual
void Run(XConfig * config, DataDistributeBase * dataDistributor, XModel * modelParams, NetBase * net);
};
}
#endif // __XTRAINER_H__
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论