Commit 450872fc by xiaotong

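Bug fixes: only parse the command-line config when arguments are given; copy embeddingW/hiddenW/outputW values in TTModel::Clone; fix the off-by-one in the XTrainer::Run step-limit check (step++ -> ++step); call _Sum on tensor pointers in XWorkerCollect::CollectP2P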

parent e31eed94
...@@ -40,8 +40,11 @@ using namespace nmt; ...@@ -40,8 +40,11 @@ using namespace nmt;
int main( int argc, const char ** argv ) int main( int argc, const char ** argv )
{ {
XConfig config; XConfig config;
if(argc > 1){
config.Create(argc - 1, argv + 1); config.Create(argc - 1, argv + 1);
verboseLevel = config.GetInt("verbose", 1); verboseLevel = config.GetInt("verbose", 1);
}
if (argc > 1 && !strcmp(argv[1], "-test")) if (argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
......
...@@ -109,7 +109,7 @@ void XNet::Backward(TensorList &roots) ...@@ -109,7 +109,7 @@ void XNet::Backward(TensorList &roots)
XTensor * node = (XTensor*)nodes.Get(i); XTensor * node = (XTensor*)nodes.Get(i);
if(node->mem != NULL){ if(node->mem != NULL){
CheckNTErrors(node->mem->bufUsed < BUF_PITCH, "Illegal access of buffer!"); //CheckNTErrors(node->mem->bufUsed < BUF_PITCH, "Illegal access of buffer!");
} }
if(node->visitMark != NODE_FINISHED) if(node->visitMark != NODE_FINISHED)
...@@ -128,7 +128,20 @@ void XNet::Backward(TensorList &roots) ...@@ -128,7 +128,20 @@ void XNet::Backward(TensorList &roots)
delete node; delete node;
} }
} }
}
}
for (int i = 0; i < nodes.count; i++) {
XTensor* node = (XTensor*)nodes.Get(i);
if (node->income.tailNum >= 100 || node->outgo.tailNum >= 100) {
XPRINT(1, stderr, "Are you sure that the node should connect so many (100) nodes?\n");
}
if (node->grad != NULL) {
XTensor* grad = node->grad;
if (grad->income.tailNum >= 100 || grad->outgo.tailNum >= 100) {
XPRINT(1, stderr, "Are you sure that the grad node should connect so many (100) nodes?\n");
}
} }
} }
} }
......
...@@ -76,12 +76,15 @@ void TestTrain() ...@@ -76,12 +76,15 @@ void TestTrain()
GeneateTTrainData("ttrain.txt"); GeneateTTrainData("ttrain.txt");
XConfig config; XConfig config;
config.Add("dev", -1); //config.Add("dev", -1);
config.Add("lrate", 0.001F); config.Add("lrate", 0.001F);
config.Add("nstep", 10000); config.Add("nstep", 100000);
config.Add("nepoch", 5); config.Add("nepoch", 5);
//config.Add("jobdev0", -1); config.Add("jobdev0", -1);
//config.Add("jobdev1", -1); //config.Add("jobdev1", -1);
//config.Add("jobdev2", -1);
//config.Add("jobdev3", -1);
//config.Add("jobdev4", -1);
TTDataLoader loader; TTDataLoader loader;
loader.SetFileName("ttrain.txt"); loader.SetFileName("ttrain.txt");
...@@ -300,6 +303,10 @@ XModel * TTModel::Clone(int devID) ...@@ -300,6 +303,10 @@ XModel * TTModel::Clone(int devID)
model->SetConfig(config); model->SetConfig(config);
model->Init(config, devID); model->Init(config, devID);
CopyValues(embeddingW, model->embeddingW);
CopyValues(hiddenW, model->hiddenW);
CopyValues(outputW, model->outputW);
return model; return model;
} }
......
...@@ -135,7 +135,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor, ...@@ -135,7 +135,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n", XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss); GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss);
if (step++ >= optimizer->nstep) if (++step >= optimizer->nstep)
break; break;
} }
......
...@@ -196,7 +196,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target) ...@@ -196,7 +196,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
/* target += source */ /* target += source */
if(source != target) if(source != target)
Sum(*source, *target, *source); _Sum(source, target, source);
} }
/* /*
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论