Commit b2be85f1 by xiaotong

bug fixes

parent 4bcf6c54
...@@ -39,6 +39,8 @@ ...@@ -39,6 +39,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor * tmpTT = NULL;
/* genreate the training data file */ /* genreate the training data file */
void GeneateTTrainData(const char * fileName) void GeneateTTrainData(const char * fileName)
{ {
...@@ -83,6 +85,8 @@ void TestTrain() ...@@ -83,6 +85,8 @@ void TestTrain()
TTModel model; TTModel model;
model.Init(config, -1); model.Init(config, -1);
tmpTT = (XTensor*)model.params[0];
XOptimizer optimizer; XOptimizer optimizer;
optimizer.Init(config); optimizer.Init(config);
......
...@@ -52,6 +52,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -52,6 +52,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define TT_EMBEDDING_SIZE 256 #define TT_EMBEDDING_SIZE 256
#define TT_HIDDEN_SIZE 256 #define TT_HIDDEN_SIZE 256
extern XTensor * tmpTT;
/* genreate the training data file */ /* genreate the training data file */
void GeneateTTrainData(const char * fileName); void GeneateTTrainData(const char * fileName);
......
...@@ -87,7 +87,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid) ...@@ -87,7 +87,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
{ {
/* the delta rule /* the delta rule
\theta_new = \theta_old - \grad * \lrate */ \theta_new = \theta_old - \grad * \lrate */
Sum(*param, *grad, *param, -lrate); _Sum(param, grad, param, -lrate);
} }
} }
...@@ -118,7 +118,7 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target) ...@@ -118,7 +118,7 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target)
{ {
CheckNTErrors(source != NULL, "The source tensor should not be NULL!"); CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!"); CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!"); CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
CopyValues(*source, *target); CopyValues(*source, *target);
} }
......
...@@ -200,7 +200,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target) ...@@ -200,7 +200,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
{ {
CheckNTErrors(source != NULL, "The source tensor should not be NULL!"); CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!"); CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!"); CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
/* target += source */ /* target += source */
if(source != target) if(source != target)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论