Commit a40cd0f8 by xiaotong

updates of XQueue, XNet and XTensor

parent 4e4f27f5
......@@ -162,6 +162,7 @@ void XNet::BackwardNode(XTensor * node, bool isEfficent)
}
else{
node->visitMark = NODE_FINISHED;
node->isGradFinished = true;
}
}
......
......@@ -245,4 +245,10 @@ bool XQueue::GetJobBreak()
return jobDequeuerBreak;
}
/* get the number of jobs */
int XQueue::GetJobNum()
{
return runningJobCount;
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -143,6 +143,9 @@ public:
/* get the break flag */
bool GetJobBreak();
/* get the number of jobs */
int GetJobNum();
};
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -89,10 +89,6 @@ XTensor::XTensor()
Init();
id = MakeTensorID();
isDefaultDType = true;
isInGlobalMem = false;
isInit = false;
isTmp = false;
reserved = 0;
}
......@@ -277,6 +273,7 @@ void XTensor::Init()
isTmp = false;
isGrad = false;
isVar = false;
isGradFinished = false;
enableGrad = X_ENABLE_GRAD;
visitMark = 0;
grad = NULL;
......
......@@ -156,6 +156,11 @@ public:
/* mark for traversing the gragh */
unsigned int visitMark;
/* indicates whether the gradient of the tensor has been computed (in the backward process)
Note that the indicator could be modified by XNet (in back propagation) and be accessed
in XTrainer (and related classes). */
bool isGradFinished;
/* gradient (for back-propagation) */
XTensor * grad;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论