Commit fe37006f by xiaotong

Improve the implementation of the overloaded assignment operator (XTensor::operator=)

parent c090457f
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "XMem.h" #include "XMem.h"
#include "XHeap.h" #include "XHeap.h"
#include "XBLAS.h" #include "XBLAS.h"
#include "XName.h"
#include "core/shape/MergeBlockLists.h" #include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h" #include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h" #include "core/arithmetic/Sum.h"
...@@ -45,6 +46,7 @@ ...@@ -45,6 +46,7 @@
#include "core/arithmetic/Sub.h" #include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h" #include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h" #include "core/math/ScaleAndShift.h"
#include "function/Identity.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -290,7 +292,8 @@ void XTensor::ShallowCopy(const XTensor &tensor) ...@@ -290,7 +292,8 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor) XTensor& XTensor::operator= (const XTensor& tensor)
{ {
/* we must make a hard copy of the tensor if it is the input
/* we must make a hard copy of the tensor if it is the input
of another node. */ of another node. */
if(outgo.tailNum > 0){ if(outgo.tailNum > 0){
int dims[MAX_TENSOR_DIM_NUM]; int dims[MAX_TENSOR_DIM_NUM];
...@@ -312,38 +315,54 @@ XTensor& XTensor::operator= (const XTensor& tensor) ...@@ -312,38 +315,54 @@ XTensor& XTensor::operator= (const XTensor& tensor)
dataHost = NULL; dataHost = NULL;
} }
/* hard copy of the data array */ if(false && !tensor.isTmp){
int size = unitNum * unitSize; /* NOTE: this might lead to additional data copy on Mac machines */
if( isInit && !isSparse && !tensor.isSparse && /* we make an identity transformation here */
size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) && if(outgo.tailNum > 0)
data != NULL) XLink::ClearOutgoing(this);
{ XLink::ClearIncoming(this);
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL) if(!IsSameShaped(this, &tensor))
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size); Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_Identity(&tensor, this);
XLink::MakeLink(&tensor, NULL, this, FUNC_IDENTITY);
} }
else{ else{
DestroyData(); /* hard copy of the data array */
if(!isInit){ int size = unitNum * unitSize;
devID = tensor.devID; if( isInit && !isSparse && !tensor.isSparse &&
mem = tensor.mem; size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) &&
data != NULL)
{
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL)
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size);
} }
else{
DestroyData();
if(!isInit){
devID = tensor.devID;
mem = tensor.mem;
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio); Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_CopyValues(&tensor, this); _CopyValues(&tensor, this);
} }
/* copy member variables */ /* copy member variables */
ShallowCopy(tensor); ShallowCopy(tensor);
isInit = true; isInit = true;
isTmp = false; isTmp = false;
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!"); CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
/* create tensor links for the new tensor */ /* create tensor links for the new tensor */
XLink::Replace(&tensor, this); XLink::Replace(&tensor, this);
}
return *this; return *this;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论