Commit 2bb37f15 by xiaotong

more methods of accessing tensor entries and creating tensors

parent 9619384f
......@@ -1020,10 +1020,27 @@ get the value of a cell with the index
*/
DTYPE XTensor::Get(int index[], int size)
{
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in the default type.");
return ToCPU(devID, GetCell(index, size));
}
/*
get the value of a cell with the offset
>> offset - offset in the array
<< return - cell value
*/
DTYPE XTensor::Get(int offset)
{
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in the default type.");
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
CheckNTErrors(denseRatio == 1.0F, "Only dense tensors are supported in Get(offset).");
DTYPE * address = (DTYPE*)data + offset;
return ToCPU(devID, address);
}
/*
get the pointer to a cell
......@@ -1122,12 +1139,14 @@ get the int value of a cell by its offset
*/
int XTensor::GetInt(int offset)
{
CheckNTErrors(dataType == X_INT, "The tensor is not in the integer type.");
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
CheckNTErrors(denseRatio == 1.0F, "Only dense tensors are supported in Get(offset).");
int * value = (int*)data + offset;
int * address = (int*)data + offset;
return ToCPUInt(devID, value);
return ToCPUInt(devID, address);
}
/*
......@@ -2190,6 +2209,21 @@ void InitTensor(XTensor * tensor, const XTensor * reference)
reference->devID, reference->mem);
}
/*
initialize a tensor on the CPU with a reference tensor
>> tensor - the tensor we intend to initialize
>> reference - the reference tensor
*/
void InitTensorOnCPU(XTensor * tensor, const XTensor * reference)
{
if(reference->order < 0)
return;
InitTensor(tensor, reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
-1);
}
/* generate a XTensor with no initialization */
XTensor * NewTensor()
{
......
......@@ -318,6 +318,9 @@ public:
/* get the value of a cell with the index */
DTYPE Get(int index[], int size = -1);
/* get the value of a cell with the offset */
DTYPE Get(int offset);
/* get the pointer to a cell */
void * GetCell(int index[], int size = -1) const;
......@@ -462,6 +465,9 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co
/* initialize a tensor with a reference tensor */
void InitTensor(XTensor * tensor, const XTensor * reference);
/* initialize a tensor on the CPU with a reference tensor */
void InitTensorOnCPU(XTensor * tensor, const XTensor * reference);
/* generate a XTensor with no initialization */
XTensor * NewTensor();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论