Commit 450872fc by xiaotong

bug fixes

parent e31eed94
......@@ -40,8 +40,11 @@ using namespace nmt;
int main( int argc, const char ** argv )
{
XConfig config;
config.Create(argc - 1, argv + 1);
verboseLevel = config.GetInt("verbose", 1);
if(argc > 1){
config.Create(argc - 1, argv + 1);
verboseLevel = config.GetInt("verbose", 1);
}
if (argc > 1 && !strcmp(argv[1], "-test"))
Test();
......
......@@ -109,7 +109,7 @@ void XNet::Backward(TensorList &roots)
XTensor * node = (XTensor*)nodes.Get(i);
if(node->mem != NULL){
CheckNTErrors(node->mem->bufUsed < BUF_PITCH, "Illegal access of buffer!");
//CheckNTErrors(node->mem->bufUsed < BUF_PITCH, "Illegal access of buffer!");
}
if(node->visitMark != NODE_FINISHED)
......@@ -128,7 +128,20 @@ void XNet::Backward(TensorList &roots)
delete node;
}
}
}
}
for (int i = 0; i < nodes.count; i++) {
XTensor* node = (XTensor*)nodes.Get(i);
if (node->income.tailNum >= 100 || node->outgo.tailNum >= 100) {
XPRINT(1, stderr, "Are you sure that the node should connect so many (100) nodes?\n");
}
if (node->grad != NULL) {
XTensor* grad = node->grad;
if (grad->income.tailNum >= 100 || grad->outgo.tailNum >= 100) {
XPRINT(1, stderr, "Are you sure that the grad node should connect so many (100) nodes?\n");
}
}
}
}
......
......@@ -76,12 +76,15 @@ void TestTrain()
GeneateTTrainData("ttrain.txt");
XConfig config;
config.Add("dev", -1);
//config.Add("dev", -1);
config.Add("lrate", 0.001F);
config.Add("nstep", 10000);
config.Add("nstep", 100000);
config.Add("nepoch", 5);
//config.Add("jobdev0", -1);
config.Add("jobdev0", -1);
//config.Add("jobdev1", -1);
//config.Add("jobdev2", -1);
//config.Add("jobdev3", -1);
//config.Add("jobdev4", -1);
TTDataLoader loader;
loader.SetFileName("ttrain.txt");
......@@ -300,6 +303,10 @@ XModel * TTModel::Clone(int devID)
model->SetConfig(config);
model->Init(config, devID);
CopyValues(embeddingW, model->embeddingW);
CopyValues(hiddenW, model->hiddenW);
CopyValues(outputW, model->outputW);
return model;
}
......
......@@ -135,7 +135,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
XPRINT5(1, stderr, "[INFO] elapsed=%.1fs epoch:%d step:%d sample:%d loss:%f\n",
GetClockSec() - startT, epoch + 1, step + 1, leader.GetSampleNum(), loss);
if (step++ >= optimizer->nstep)
if (++step >= optimizer->nstep)
break;
}
......
......@@ -196,7 +196,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
/* target += source */
if(source != target)
Sum(*source, *target, *source);
_Sum(source, target, source);
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论