Commit c3b9f35a by xiaotong

copy constructor and reload of =

parent 4acc236a
...@@ -41,6 +41,21 @@ using namespace samplefnnlm; ...@@ -41,6 +41,21 @@ using namespace samplefnnlm;
int main( int argc, const char ** argv ) int main( int argc, const char ** argv )
{ {
//_CrtSetBreakAlloc(78);
{
XTensor a;
XTensor b;
InitTensor2D(&a, 2, 2);
a.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(1.0F, 1, 1);
b = Sum(a, a);
b.Dump(stderr, "b: ");
}
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm")) else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "XHeap.h" #include "XHeap.h"
#include "XBLAS.h" #include "XBLAS.h"
#include "core/shape/MergeBlockLists.h" #include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -55,6 +56,23 @@ ...@@ -55,6 +56,23 @@
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts{ namespace nts{
int tensorIDGlobal = 0;
MUTEX_HANDLE tensorMutex;
XTensor firstTensor;
/* generate a tensor id */
int MakeTensorID()
{
if(tensorIDGlobal == 0)
MUTEX_INIT(tensorMutex);
MUTEX_LOCK(tensorMutex);
int id = tensorIDGlobal++;
MUTEX_UNLOCK(tensorMutex);
return id;
}
/* /*
constructor constructor
>> myOrder - order of the tensor >> myOrder - order of the tensor
...@@ -64,6 +82,7 @@ XTensor::XTensor() ...@@ -64,6 +82,7 @@ XTensor::XTensor()
{ {
memset(this, 0, sizeof(XTensor)); memset(this, 0, sizeof(XTensor));
id = MakeTensorID();
order = -1; order = -1;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
...@@ -89,6 +108,7 @@ XTensor::XTensor(XTensor * reference) ...@@ -89,6 +108,7 @@ XTensor::XTensor(XTensor * reference)
{ {
memset(this, 0, sizeof(XTensor)); memset(this, 0, sizeof(XTensor));
id = MakeTensorID();
dataType = DEFAULT_DTYPE; dataType = DEFAULT_DTYPE;
devID = -1; devID = -1;
denseRatio = 1.0F; denseRatio = 1.0F;
...@@ -109,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem) ...@@ -109,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
{ {
CheckNTErrors((myOrder > 0), "Illegal tensor order1"); CheckNTErrors((myOrder > 0), "Illegal tensor order1");
id = MakeTensorID();
order = myOrder; order = myOrder;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
...@@ -145,6 +166,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP ...@@ -145,6 +166,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
{ {
CheckNTErrors((myOrder > 0), "Illegal tensor order1"); CheckNTErrors((myOrder > 0), "Illegal tensor order1");
id = MakeTensorID();
order = myOrder; order = myOrder;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM); memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
...@@ -165,6 +187,33 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP ...@@ -165,6 +187,33 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
Resize(myOrder, myDimSize, myDataType, myDenseRatio); Resize(myOrder, myDimSize, myDataType, myDenseRatio);
} }
/* copy constructor */
XTensor::XTensor(XTensor &reference)
{
id = MakeTensorID();
ShallowCopy(reference);
isInit = false;
isTmp = false;
dataHost = NULL;
if(reference.isTmp){
devID = reference.devID;
mem = reference.mem;
data = reference.data;
reference.data = NULL;
}
else{
DestroyData();
if(isInit){
devID = reference.devID;
mem = reference.mem;
}
InitTensor(this, &reference);
CopyValues(&reference, this);
}
}
/* de-constructor */ /* de-constructor */
XTensor::~XTensor() XTensor::~XTensor()
{ {
...@@ -214,7 +263,7 @@ void XTensor::ShallowCopy(const XTensor &tensor) ...@@ -214,7 +263,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
} }
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& XTensor::operator = (XTensor& tensor) XTensor& XTensor::operator= (XTensor& tensor)
{ {
/* hard copy of data array */ /* hard copy of data array */
int size = unitNum * unitSize; int size = unitNum * unitSize;
...@@ -229,20 +278,14 @@ XTensor& XTensor::operator = (XTensor& tensor) ...@@ -229,20 +278,14 @@ XTensor& XTensor::operator = (XTensor& tensor)
} }
else{ else{
DestroyData(); DestroyData();
if(!isInit){ if(isInit){
devID = tensor.devID; devID = tensor.devID;
mem = tensor.mem; mem = tensor.mem;
} }
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio); Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
if(tensor.isSparse) { CopyValues(&tensor, this);
int num = int(tensor.unitNum * tensor.denseRatio + 1);
int tupleSize = sizeof(int)+sizeof(DTYPE);
size = sizeof(int) + tupleSize * num;
}
XMemCopy(data, devID, tensor.data, tensor.devID, size);
} }
/* copy member variables */ /* copy member variables */
...@@ -952,6 +995,15 @@ int XTensor::GetNonzeroSize() ...@@ -952,6 +995,15 @@ int XTensor::GetNonzeroSize()
} }
/* /*
set the tensor as "temporary"
>> myIsTMP - flag
*/
void XTensor::SetTMP(bool myIsTmp)
{
isTmp = myIsTmp;
}
/*
resize a tensor with a specified tensor size resize a tensor with a specified tensor size
>> myOrder - order of the tensor >> myOrder - order of the tensor
>> myDimSize - the size of each dimension >> myDimSize - the size of each dimension
......
...@@ -61,6 +61,9 @@ is the parent class of XMatrix. ...@@ -61,6 +61,9 @@ is the parent class of XMatrix.
*/ */
struct XTensor struct XTensor
{ {
/* id */
int id;
/* memory pool */ /* memory pool */
XMem * mem; XMem * mem;
...@@ -164,6 +167,9 @@ struct XTensor ...@@ -164,6 +167,9 @@ struct XTensor
XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType, XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem); const float myDenseRatio, XMem * myMem);
/* copy constructor */
XTensor(XTensor &reference);
/* de-constructor */ /* de-constructor */
~XTensor(); ~XTensor();
...@@ -174,7 +180,7 @@ struct XTensor ...@@ -174,7 +180,7 @@ struct XTensor
void ShallowCopy(const XTensor &tensor); void ShallowCopy(const XTensor &tensor);
/* overloading of the equal-sign */ /* overloading of the equal-sign */
XTensor& operator = (XTensor &tensor); XTensor& operator= (XTensor &tensor);
/* judge whether the two matrices are in the same type and size */ /* judge whether the two matrices are in the same type and size */
static static
...@@ -271,6 +277,9 @@ struct XTensor ...@@ -271,6 +277,9 @@ struct XTensor
/* get the number of non-zero elements (in a sparse tensor) */ /* get the number of non-zero elements (in a sparse tensor) */
int GetNonzeroSize(); int GetNonzeroSize();
/* set the tensor as "temporary" */
void SetTMP(bool myIsTmp = true);
/* resize a matrix with a specified matrix size */ /* resize a matrix with a specified matrix size */
bool Resize(const int myOrder, const int * myDimSize, bool Resize(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = DEFAULT_DTYPE, const TENSOR_DATA_TYPE myDataType = DEFAULT_DTYPE,
...@@ -305,6 +314,12 @@ struct XTensor ...@@ -305,6 +314,12 @@ struct XTensor
void FreeData(XTensor * matrix, XMem * myMem = NULL, bool useBuf = false); void FreeData(XTensor * matrix, XMem * myMem = NULL, bool useBuf = false);
}; };
/* we make a unique id for every tensor */
extern int tensorIDGlobal;
extern MUTEX_HANDLE tensorMutex;
extern XTensor firstTensor;
extern int MakeTensorID();
/************************************************ /************************************************
* we define the "new and delete" functions below * we define the "new and delete" functions below
*/ */
......
...@@ -136,13 +136,14 @@ return a XTensor structure ...@@ -136,13 +136,14 @@ return a XTensor structure
XTensor Sum(XTensor &a, XTensor &b, DTYPE beta) XTensor Sum(XTensor &a, XTensor &b, DTYPE beta)
{ {
XTensor c(&a); XTensor c(&a);
c.SetTMP();
/* computation */ /* computation */
_Sum(&a, &b, &c, beta); _Sum(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
//XLink::MakeLink(&a, &b, &c, MATH_SUM); XLink::MakeLink(&a, &b, &c, MATH_SUM);
//XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
return c; return c;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论