Commit 2dadc66a by xiaotong

check parameter in XModel

parent d5269725
...@@ -138,8 +138,11 @@ void XLeader::SetMode(XLEADER_MODE myMode) ...@@ -138,8 +138,11 @@ void XLeader::SetMode(XLEADER_MODE myMode)
/* start the workers */ /* start the workers */
void XLeader::Start() void XLeader::Start()
{ {
serverModel.CheckParam();
for (int i = 0; i < jworkers.count; i++) { for (int i = 0; i < jworkers.count; i++) {
XWorkerJob * worker = (XWorkerJob*)jworkers.GetItem(i); XWorkerJob * worker = (XWorkerJob*)jworkers.GetItem(i);
worker->GetModel()->CheckParam();
worker->Start(); worker->Start();
} }
......
...@@ -103,7 +103,7 @@ add a parameter tensor ...@@ -103,7 +103,7 @@ add a parameter tensor
*/ */
void XModel::AddParam(XTensor* param) void XModel::AddParam(XTensor* param)
{ {
//param->SetVarFlag(); param->SetVarFlag();
params.Add(param); params.Add(param);
PARAM_STATE * newFlags = new PARAM_STATE[params.count]; PARAM_STATE * newFlags = new PARAM_STATE[params.count];
...@@ -115,6 +115,18 @@ void XModel::AddParam(XTensor* param) ...@@ -115,6 +115,18 @@ void XModel::AddParam(XTensor* param)
flags = newFlags; flags = newFlags;
} }
/* check if the parameters are well-defined for training */
bool XModel::CheckParam()
{
for (int i = 0; i < params.count; i++) {
XTensor * param = (XTensor*)params[i];
if (!param->isGrad)
return false;
}
return true;
}
/* refresh the model */ /* refresh the model */
void XModel::RefreshMe() void XModel::RefreshMe()
{ {
......
...@@ -90,7 +90,10 @@ protected: ...@@ -90,7 +90,10 @@ protected:
public: public:
/* add a parameter tensor */ /* add a parameter tensor */
void AddParam(XTensor* param); void AddParam(XTensor * param);
/* check if the parameters are well-defined for training */
bool CheckParam();
/* refresh the model */ /* refresh the model */
void RefreshMe(); void RefreshMe();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论