Commit b2be85f1 by xiaotong

bug fixes

parent 4bcf6c54
......@@ -39,6 +39,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor * tmpTT = NULL;
/* genreate the training data file */
void GeneateTTrainData(const char * fileName)
{
......@@ -83,6 +85,8 @@ void TestTrain()
TTModel model;
model.Init(config, -1);
tmpTT = (XTensor*)model.params[0];
XOptimizer optimizer;
optimizer.Init(config);
......
......@@ -52,6 +52,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define TT_EMBEDDING_SIZE 256
#define TT_HIDDEN_SIZE 256
extern XTensor * tmpTT;
/* genreate the training data file */
void GeneateTTrainData(const char * fileName);
......
......@@ -87,7 +87,7 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
{
/* the delta rule
\theta_new = \theta_old - \grad * \lrate */
Sum(*param, *grad, *param, -lrate);
_Sum(param, grad, param, -lrate);
}
}
......@@ -118,7 +118,7 @@ void XWorkerBroadcast::BroadcastP2P(XTensor * source, XTensor * target)
{
CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!");
CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
CopyValues(*source, *target);
}
......
......@@ -200,7 +200,7 @@ void XWorkerCollect::CollectP2P(XTensor * source, XTensor * target)
{
CheckNTErrors(source != NULL, "The source tensor should not be NULL!");
CheckNTErrors(target != NULL, "The target tensor should not be NULL!");
CheckNTErrors(IsSameShaped(source, target), "The two tensors should be of the same shape!");
CheckNTErrors(IsSameShaped(*source, *target), "The two tensors should be of the same shape!");
/* target += source */
if(source != target)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论