Commit fe37006f by xiaotong

improve the implementation of the overload of =

parent c090457f
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "XMem.h" #include "XMem.h"
#include "XHeap.h" #include "XHeap.h"
#include "XBLAS.h" #include "XBLAS.h"
#include "XName.h"
#include "core/shape/MergeBlockLists.h" #include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h" #include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h" #include "core/arithmetic/Sum.h"
...@@ -45,6 +46,7 @@ ...@@ -45,6 +46,7 @@
#include "core/arithmetic/Sub.h" #include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h" #include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h" #include "core/math/ScaleAndShift.h"
#include "function/Identity.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -290,6 +292,7 @@ void XTensor::ShallowCopy(const XTensor &tensor) ...@@ -290,6 +292,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor) XTensor& XTensor::operator= (const XTensor& tensor)
{ {
/* we must make a hard copy of the tensor if it is the input /* we must make a hard copy of the tensor if it is the input
of another node. */ of another node. */
if(outgo.tailNum > 0){ if(outgo.tailNum > 0){
...@@ -312,6 +315,21 @@ XTensor& XTensor::operator= (const XTensor& tensor) ...@@ -312,6 +315,21 @@ XTensor& XTensor::operator= (const XTensor& tensor)
dataHost = NULL; dataHost = NULL;
} }
if(false && !tensor.isTmp){
/* NOTE: this might lead to additional data copy on Mac machines */
/* we make an identity transformation here */
if(outgo.tailNum > 0)
XLink::ClearOutgoing(this);
XLink::ClearIncoming(this);
if(!IsSameShaped(this, &tensor))
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_Identity(&tensor, this);
XLink::MakeLink(&tensor, NULL, this, FUNC_IDENTITY);
}
else{
/* hard copy of the data array */ /* hard copy of the data array */
int size = unitNum * unitSize; int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse && if( isInit && !isSparse && !tensor.isSparse &&
...@@ -344,6 +362,7 @@ XTensor& XTensor::operator= (const XTensor& tensor) ...@@ -344,6 +362,7 @@ XTensor& XTensor::operator= (const XTensor& tensor)
/* create tensor links for the new tensor */ /* create tensor links for the new tensor */
XLink::Replace(&tensor, this); XLink::Replace(&tensor, this);
}
return *this; return *this;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论