Commit c3b9f35a by xiaotong

copy constructor and reload of =

parent 4acc236a
......@@ -41,6 +41,21 @@ using namespace samplefnnlm;
int main( int argc, const char ** argv )
{
//_CrtSetBreakAlloc(78);
{
XTensor a;
XTensor b;
InitTensor2D(&a, 2, 2);
a.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(1.0F, 1, 1);
b = Sum(a, a);
b.Dump(stderr, "b: ");
}
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
......
......@@ -39,6 +39,7 @@
#include "XHeap.h"
#include "XBLAS.h"
#include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h"
#ifdef USE_CUDA
......@@ -55,6 +56,23 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
int tensorIDGlobal = 0;
MUTEX_HANDLE tensorMutex;
XTensor firstTensor;
/* generate a tensor id */
int MakeTensorID()
{
if(tensorIDGlobal == 0)
MUTEX_INIT(tensorMutex);
MUTEX_LOCK(tensorMutex);
int id = tensorIDGlobal++;
MUTEX_UNLOCK(tensorMutex);
return id;
}
/*
constructor
>> myOrder - order of the tensor
......@@ -64,6 +82,7 @@ XTensor::XTensor()
{
memset(this, 0, sizeof(XTensor));
id = MakeTensorID();
order = -1;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
......@@ -89,6 +108,7 @@ XTensor::XTensor(XTensor * reference)
{
memset(this, 0, sizeof(XTensor));
id = MakeTensorID();
dataType = DEFAULT_DTYPE;
devID = -1;
denseRatio = 1.0F;
......@@ -109,6 +129,7 @@ XTensor::XTensor(const int myOrder, int myDevID, XMem * myMem)
{
CheckNTErrors((myOrder > 0), "Illegal tensor order1");
id = MakeTensorID();
order = myOrder;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
......@@ -145,6 +166,7 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
{
CheckNTErrors((myOrder > 0), "Illegal tensor order1");
id = MakeTensorID();
order = myOrder;
memset(dimSize, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
memset(dimSizeRDI, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
......@@ -165,6 +187,33 @@ XTensor::XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYP
Resize(myOrder, myDimSize, myDataType, myDenseRatio);
}
/* copy constructor */
XTensor::XTensor(XTensor &reference)
{
id = MakeTensorID();
ShallowCopy(reference);
isInit = false;
isTmp = false;
dataHost = NULL;
if(reference.isTmp){
devID = reference.devID;
mem = reference.mem;
data = reference.data;
reference.data = NULL;
}
else{
DestroyData();
if(isInit){
devID = reference.devID;
mem = reference.mem;
}
InitTensor(this, &reference);
CopyValues(&reference, this);
}
}
/* de-constructor */
XTensor::~XTensor()
{
......@@ -214,7 +263,7 @@ void XTensor::ShallowCopy(const XTensor &tensor)
}
/* overloading of the equal-sign */
XTensor& XTensor::operator = (XTensor& tensor)
XTensor& XTensor::operator= (XTensor& tensor)
{
/* hard copy of data array */
int size = unitNum * unitSize;
......@@ -229,20 +278,14 @@ XTensor& XTensor::operator = (XTensor& tensor)
}
else{
DestroyData();
if(!isInit){
if(isInit){
devID = tensor.devID;
mem = tensor.mem;
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
if(tensor.isSparse) {
int num = int(tensor.unitNum * tensor.denseRatio + 1);
int tupleSize = sizeof(int)+sizeof(DTYPE);
size = sizeof(int) + tupleSize * num;
}
XMemCopy(data, devID, tensor.data, tensor.devID, size);
CopyValues(&tensor, this);
}
/* copy member variables */
......@@ -952,6 +995,15 @@ int XTensor::GetNonzeroSize()
}
/*
set the tensor as "temporary"
>> myIsTMP - flag
*/
void XTensor::SetTMP(bool myIsTmp)
{
isTmp = myIsTmp;
}
/*
resize a tensor with a specified tensor size
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
......
......@@ -61,6 +61,9 @@ is the parent class of XMatrix.
*/
struct XTensor
{
/* id */
int id;
/* memory pool */
XMem * mem;
......@@ -164,6 +167,9 @@ struct XTensor
XTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const float myDenseRatio, XMem * myMem);
/* copy constructor */
XTensor(XTensor &reference);
/* de-constructor */
~XTensor();
......@@ -174,7 +180,7 @@ struct XTensor
void ShallowCopy(const XTensor &tensor);
/* overloading of the equal-sign */
XTensor& operator = (XTensor &tensor);
XTensor& operator= (XTensor &tensor);
/* judge whether the two matrices are in the same type and size */
static
......@@ -271,6 +277,9 @@ struct XTensor
/* get the number of non-zero elements (in a sparse tensor) */
int GetNonzeroSize();
/* set the tensor as "temporary" */
void SetTMP(bool myIsTmp = true);
/* resize a matrix with a specified matrix size */
bool Resize(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = DEFAULT_DTYPE,
......@@ -305,6 +314,12 @@ struct XTensor
void FreeData(XTensor * matrix, XMem * myMem = NULL, bool useBuf = false);
};
/* we make a unique id for every tensor */
extern int tensorIDGlobal;
extern MUTEX_HANDLE tensorMutex;
extern XTensor firstTensor;
extern int MakeTensorID();
/************************************************
* we define the "new and delete" functions below
*/
......
......@@ -136,13 +136,14 @@ return a XTensor structure
XTensor Sum(XTensor &a, XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* computation */
_Sum(&a, &b, &c, beta);
/* tensor connections */
//XLink::MakeLink(&a, &b, &c, MATH_SUM);
//XLink::AddParamToHead(&c, beta);
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
return c;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论