Commit fe37006f by xiaotong

improve the implementation of the overload of =

parent c090457f
......@@ -38,6 +38,7 @@
#include "XMem.h"
#include "XHeap.h"
#include "XBLAS.h"
#include "XName.h"
#include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
......@@ -45,6 +46,7 @@
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#include "function/Identity.h"
#ifdef USE_CUDA
......@@ -290,7 +292,8 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor)
{
/* we must make a hard copy of the tensor if it is the input
/* we must make a hard copy of the tensor if it is the input
of another node. */
if(outgo.tailNum > 0){
int dims[MAX_TENSOR_DIM_NUM];
......@@ -312,38 +315,54 @@ XTensor& XTensor::operator= (const XTensor& tensor)
dataHost = NULL;
}
/* hard copy of the data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) &&
data != NULL)
{
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL)
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size);
if(false && !tensor.isTmp){
/* NOTE: this might lead to additional data copy on Mac machines */
/* we make an identity transformation here */
if(outgo.tailNum > 0)
XLink::ClearOutgoing(this);
XLink::ClearIncoming(this);
if(!IsSameShaped(this, &tensor))
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_Identity(&tensor, this);
XLink::MakeLink(&tensor, NULL, this, FUNC_IDENTITY);
}
else{
DestroyData();
if(!isInit){
devID = tensor.devID;
mem = tensor.mem;
/* hard copy of the data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) &&
data != NULL)
{
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL)
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size);
}
else{
DestroyData();
if(!isInit){
devID = tensor.devID;
mem = tensor.mem;
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_CopyValues(&tensor, this);
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_CopyValues(&tensor, this);
}
/* copy member variables */
ShallowCopy(tensor);
/* copy member variables */
ShallowCopy(tensor);
isInit = true;
isTmp = false;
isInit = true;
isTmp = false;
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
}
return *this;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论