Commit a57ad688 by liyinqiao

Merge with the branch of huchi.

Support reserved property in XTensor to indicate whether the data is reserved for backward.
parent 55965160
......@@ -93,6 +93,7 @@ XTensor::XTensor()
isInGlobalMem = false;
isInit = false;
isTmp = false;
reserved = 0;
}
/* constructor */
......@@ -158,6 +159,8 @@ XTensor::XTensor(const XTensor& reference)
devID = reference.devID;
mem = reference.mem;
data = reference.data;
reserved = reference.reserved;
const_cast<XTensor&>(reference).reserved = 0;
signature = reference.signature;
const_cast<XTensor&>(reference).data = NULL;
}
......@@ -202,6 +205,8 @@ XTensor::XTensor(const XTensor&& reference)
isInit = true;
isTmp = reference.isTmp;
reserved = reference.reserved;
const_cast<XTensor&>(reference).reserved = 0;
}
/* de-constructor */
......@@ -218,8 +223,13 @@ XTensor::~XTensor()
XTensor* newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->SetTMPFlag();
newTensor->data = data;
data = NULL;
if (reserved == -1) {
newTensor->data = NULL;
}
else {
newTensor->data = data;
data = NULL;
}
if (enableGrad)
XLink::Replace(this, newTensor);
......@@ -324,6 +334,11 @@ XTensor& XTensor::operator= (const XTensor& tensor)
XTensor* newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->SetTMPFlag();
/* release the data if it won't be used in backward */
if (reserved == -1) {
DestroyData();
}
newTensor->data = data;
newTensor->dataHost = dataHost;
newTensor->signature = tensor.signature;
......@@ -409,6 +424,11 @@ XTensor& XTensor::operator= (const XTensor&& tensor)
XTensor* newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->SetTMPFlag();
/* release the data if it won't be used in backward */
if (reserved == -1) {
DestroyData();
}
newTensor->data = data;
newTensor->dataHost = dataHost;
newTensor->signature = tensor.signature;
......@@ -436,9 +456,11 @@ XTensor& XTensor::operator= (const XTensor&& tensor)
const_cast<XTensor&>(tensor).data = NULL;
if (enableGrad) {
XLink::Copy(&tensor, this);
XLink::Replace(&tensor, this);
}
reserved = tensor.reserved;
const_cast<XTensor&>(tensor).reserved = 0;
return *this;
}
......
......@@ -142,6 +142,9 @@ public:
/* indicates whether the tensor is created temporarily */
bool isTmp;
/* indicates whether the data is reserved for backward. 0: uncertain; 1: reserved; -1: deletable */
int reserved;
/* indicates whether the tensor keeps the gradient when used as model parameters */
bool isGrad;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论