Commit 2bb37f15 by xiaotong

more methods of accessing tensor entries and creating tensors

parent 9619384f
...@@ -1020,10 +1020,27 @@ get the value of a cell with the index ...@@ -1020,10 +1020,27 @@ get the value of a cell with the index
*/ */
DTYPE XTensor::Get(int index[], int size) DTYPE XTensor::Get(int index[], int size)
{ {
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in the default type.");
return ToCPU(devID, GetCell(index, size)); return ToCPU(devID, GetCell(index, size));
} }
/*
get the value of a cell with the offset
>> offset - offset in the array
<< return - cell value
*/
DTYPE XTensor::Get(int offset)
{
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in the default type.");
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
CheckNTErrors(denseRatio == 1.0F, "Only dense tensors are supported in Get(offset).");
DTYPE * address = (DTYPE*)data + offset;
return ToCPU(devID, address);
}
/* /*
get the pointer to a cell get the pointer to a cell
...@@ -1122,12 +1139,14 @@ get the int value of a cell by its offset ...@@ -1122,12 +1139,14 @@ get the int value of a cell by its offset
*/ */
int XTensor::GetInt(int offset) int XTensor::GetInt(int offset)
{ {
CheckNTErrors(dataType == X_INT, "The tensor is not in the integer type.");
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!"); CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!"); CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
CheckNTErrors(denseRatio == 1.0F, "Only dense tensors are supported in Get(offset).");
int * value = (int*)data + offset; int * address = (int*)data + offset;
return ToCPUInt(devID, value); return ToCPUInt(devID, address);
} }
/* /*
...@@ -2190,6 +2209,21 @@ void InitTensor(XTensor * tensor, const XTensor * reference) ...@@ -2190,6 +2209,21 @@ void InitTensor(XTensor * tensor, const XTensor * reference)
reference->devID, reference->mem); reference->devID, reference->mem);
} }
/*
initialize a tensor on the CPU with a reference tensor
>> tensor - the tensor we intend to initialize
>> reference - the reference tensor
*/
void InitTensorOnCPU(XTensor * tensor, const XTensor * reference)
{
if(reference->order < 0)
return;
InitTensor(tensor, reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
-1);
}
/* generate a XTensor with no initialization */ /* generate a XTensor with no initialization */
XTensor * NewTensor() XTensor * NewTensor()
{ {
......
...@@ -318,6 +318,9 @@ public: ...@@ -318,6 +318,9 @@ public:
/* get the value of a cell with the index */ /* get the value of a cell with the index */
DTYPE Get(int index[], int size = -1); DTYPE Get(int index[], int size = -1);
/* get the value of a cell with the offset */
DTYPE Get(int offset);
/* get the pointer to a cell */ /* get the pointer to a cell */
void * GetCell(int index[], int size = -1) const; void * GetCell(int index[], int size = -1) const;
...@@ -462,6 +465,9 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co ...@@ -462,6 +465,9 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co
/* initialize a tensor with a reference tensor */ /* initialize a tensor with a reference tensor */
void InitTensor(XTensor * tensor, const XTensor * reference); void InitTensor(XTensor * tensor, const XTensor * reference);
/* initialize a tensor on the CPU with a reference tensor */
void InitTensorOnCPU(XTensor * tensor, const XTensor * reference);
/* generate a XTensor with no initialization */ /* generate a XTensor with no initialization */
XTensor * NewTensor(); XTensor * NewTensor();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论