Commit 3d31727d by xiaotong

add Get and Set methods to fetch and save int-type items

parent fe2e0d79
...@@ -1115,6 +1115,20 @@ DTYPE XTensor::Get3D(int d0, int d1, int d2) ...@@ -1115,6 +1115,20 @@ DTYPE XTensor::Get3D(int d0, int d1, int d2)
return ToCPU(devID, value); return ToCPU(devID, value);
} }
/*
get the int value of a cell by its offset
>> offset - offset of the item
*/
int XTensor::GetInt(int offset)
{
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
int * value = (int*)data + offset;
return ToCPUInt(devID, value);
}
/* /*
get the value of a cell in a 1d tensor in int type get the value of a cell in a 1d tensor in int type
...@@ -1123,9 +1137,9 @@ get the value of a cell in a 1d tensor in int type ...@@ -1123,9 +1137,9 @@ get the value of a cell in a 1d tensor in int type
*/ */
int XTensor::Get1DInt(int i) int XTensor::Get1DInt(int i)
{ {
CheckNTErrors((order == 1), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 1, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((i >= 0 && i < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(i >= 0 && i < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((dataType == X_INT), "The tensor is not in int type."); CheckNTErrors(dataType == X_INT, "The tensor is not in int type.");
int dimSize[1] = {i}; int dimSize[1] = {i};
void * value = GetCell(dimSize, 1); void * value = GetCell(dimSize, 1);
...@@ -1141,10 +1155,10 @@ get the value of a cell in a 2d tensor in int type ...@@ -1141,10 +1155,10 @@ get the value of a cell in a 2d tensor in int type
*/ */
int XTensor::Get2DInt(int ni, int mi) int XTensor::Get2DInt(int ni, int mi)
{ {
CheckNTErrors((order == 2), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 2, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((ni >= 0 && ni < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(ni >= 0 && ni < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((mi >= 0 && mi < dimSize[1]), "dimension 1 is out of range!"); CheckNTErrors(mi >= 0 && mi < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors((dataType == X_INT), "The tensor is not in default type."); CheckNTErrors(dataType == X_INT, "The tensor is not in default type.");
int dims[2] = {ni, mi}; int dims[2] = {ni, mi};
void * value = GetCell(dims, 2); void * value = GetCell(dims, 2);
...@@ -1161,11 +1175,11 @@ get the value of a cell in a 3d tensor in int type ...@@ -1161,11 +1175,11 @@ get the value of a cell in a 3d tensor in int type
*/ */
int XTensor::Get3DInt(int d0, int d1, int d2) int XTensor::Get3DInt(int d0, int d1, int d2)
{ {
CheckNTErrors((order == 3), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((d0 >= 0 && d0 < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((d1 >= 0 && d1 < dimSize[1]), "dimension 1 is out of range!"); CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[2]), "dimension 2 is out of range!"); CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
CheckNTErrors((dataType == X_INT), "The tensor is not in default type."); CheckNTErrors(dataType == X_INT, "The tensor is not in default type.");
int dims[3] = {d0, d1, d2}; int dims[3] = {d0, d1, d2};
void * value = GetCell(dims, 3); void * value = GetCell(dims, 3);
...@@ -1180,8 +1194,8 @@ get the value of a cell in the sparse tensor ...@@ -1180,8 +1194,8 @@ get the value of a cell in the sparse tensor
*/ */
DTYPE XTensor::GetInSparse(int i) DTYPE XTensor::GetInSparse(int i)
{ {
CheckNTErrors((i >= 0 && i < unitNum), "Index is out of range!"); CheckNTErrors(i >= 0 && i < unitNum, "Index is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
char * d = (char*)data + sizeof(int); char * d = (char*)data + sizeof(int);
DTYPE * value = (DTYPE*)(d + (sizeof(int) + sizeof(DTYPE)) * i + sizeof(int)); DTYPE * value = (DTYPE*)(d + (sizeof(int) + sizeof(DTYPE)) * i + sizeof(int));
...@@ -1196,8 +1210,8 @@ get the key value of a tuple in a sparse tensor ...@@ -1196,8 +1210,8 @@ get the key value of a tuple in a sparse tensor
*/ */
int XTensor::GetKeyInSparse(int i) int XTensor::GetKeyInSparse(int i)
{ {
CheckNTErrors((i >= 0 && i < unitNum), "Index is out of range!"); CheckNTErrors(i >= 0 && i < unitNum, "Index is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
char * d = (char*)data + sizeof(int); char * d = (char*)data + sizeof(int);
int * key = (int*)(d + (sizeof(int) + sizeof(DTYPE)) * i); int * key = (int*)(d + (sizeof(int) + sizeof(DTYPE)) * i);
...@@ -1213,7 +1227,7 @@ set the value of a cell ...@@ -1213,7 +1227,7 @@ set the value of a cell
*/ */
bool XTensor::Set(DTYPE value, int index[], int size) bool XTensor::Set(DTYPE value, int index[], int size)
{ {
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
return SetToDevice(devID, GetCell(index, size), value); return SetToDevice(devID, GetCell(index, size), value);
} }
...@@ -1226,9 +1240,9 @@ set the value of a cell in a 1d tensor ...@@ -1226,9 +1240,9 @@ set the value of a cell in a 1d tensor
*/ */
bool XTensor::Set1D(DTYPE value, int i) bool XTensor::Set1D(DTYPE value, int i)
{ {
CheckNTErrors((order == 1), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 1, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((i >= 0 && i < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(i >= 0 && i < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[1] = {i}; int dims[1] = {i};
...@@ -1244,10 +1258,10 @@ set the value of a cell in a 2d tensor in default type ...@@ -1244,10 +1258,10 @@ set the value of a cell in a 2d tensor in default type
*/ */
bool XTensor::Set2D(DTYPE value, int ni, int mi) bool XTensor::Set2D(DTYPE value, int ni, int mi)
{ {
CheckNTErrors((order == 2), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 2, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((ni >= 0 && ni < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(ni >= 0 && ni < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((mi >= 0 && mi < dimSize[1]), "dimension 1 is out of range!"); CheckNTErrors(mi >= 0 && mi < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[2] = {ni, mi}; int dims[2] = {ni, mi};
...@@ -1274,6 +1288,21 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2) ...@@ -1274,6 +1288,21 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
return SetToDevice(devID, GetCell(dims, 3), value); return SetToDevice(devID, GetCell(dims, 3), value);
} }
/*
set the integer value of a cell by its offset
>> value - value we tend to set to the item
>> offset - offset of the item
*/
bool XTensor::SetInt(int value, int offset)
{
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
int * d = (int*)data + offset;
return SetToDevice(devID, d, value);
}
/* /*
...@@ -1285,7 +1314,7 @@ set the integer value of a cell ...@@ -1285,7 +1314,7 @@ set the integer value of a cell
*/ */
bool XTensor::SetInt(int value, int index[], int size) bool XTensor::SetInt(int value, int index[], int size)
{ {
CheckNTErrors((dataType == X_INT), "The tensor is not in integer type."); CheckNTErrors(dataType == X_INT, "The tensor is not in integer type.");
return SetToDeviceInt(devID, GetCell(index, size), value); return SetToDeviceInt(devID, GetCell(index, size), value);
} }
...@@ -1298,9 +1327,9 @@ set the integer value of a cell in a 1d tensor ...@@ -1298,9 +1327,9 @@ set the integer value of a cell in a 1d tensor
*/ */
bool XTensor::Set1DInt(int value, int i) bool XTensor::Set1DInt(int value, int i)
{ {
CheckNTErrors((order == 1), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 1, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((i >= 0 && i < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(i >= 0 && i < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((dataType == X_INT), "The tensor is not in integer type."); CheckNTErrors(dataType == X_INT, "The tensor is not in integer type.");
int dims[1] = {i}; int dims[1] = {i};
...@@ -1316,10 +1345,10 @@ set the integer value of a cell in a 2d tensor in default type ...@@ -1316,10 +1345,10 @@ set the integer value of a cell in a 2d tensor in default type
*/ */
bool XTensor::Set2DInt(int value, int ni, int mi) bool XTensor::Set2DInt(int value, int ni, int mi)
{ {
CheckNTErrors((order == 2), "Cannot get a 2d cell for a tensor whose order is not 2!"); CheckNTErrors(order == 2, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((ni >= 0 && ni < dimSize[0]), "dimension 0 is out of range!"); CheckNTErrors(ni >= 0 && ni < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors((mi >= 0 && mi < dimSize[1]), "dimension 1 is out of range!"); CheckNTErrors(mi >= 0 && mi < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors((dataType == X_INT), "The tensor is not in integer type."); CheckNTErrors(dataType == X_INT, "The tensor is not in integer type.");
int dims[2] = {ni, mi}; int dims[2] = {ni, mi};
...@@ -1356,10 +1385,10 @@ increase the value of a cell in a 2d tensor ...@@ -1356,10 +1385,10 @@ increase the value of a cell in a 2d tensor
*/ */
bool XTensor::Add2D(DTYPE value, int ni, int mi) bool XTensor::Add2D(DTYPE value, int ni, int mi)
{ {
CheckNTErrors((ni >= 0 && ni < dimSize[0]), "the row index is out of range!"); CheckNTErrors(ni >= 0 && ni < dimSize[0], "the row index is out of range!");
CheckNTErrors((mi >= 0 && mi < dimSize[1]), "the column index is out of range!"); CheckNTErrors(mi >= 0 && mi < dimSize[1], "the column index is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type."); CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
CheckNTErrors((isSparse == false), "TODO!"); CheckNTErrors(isSparse == false, "TODO!");
if(devID < 0){ if(devID < 0){
DTYPE * p = (DTYPE*)data + ni * dimSize[1] + mi; DTYPE * p = (DTYPE*)data + ni * dimSize[1] + mi;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论