Commit ddbb77b6 by liyinqiao

Reconstruct Range function.

parent c442dbeb
......@@ -857,33 +857,7 @@ void XTensor::Rand(int rNum, int cNum)
*/
void XTensor::Range(DTYPE lower, DTYPE upper, DTYPE step)
{
CheckNTErrors((order == 1), "Tensor must be 1 dimension!");
/* compute the true length according to the (start, end, step) */
DTYPE size = fabs(upper - lower);
int num = ceil(size / fabs(step));
CheckNTErrors((unitNum == num), "Unit number of the tensor is not matched.");
/* init a integer array to store the sequence */
void * data = NULL;
if (dataType == X_INT) {
data = new int[num];
for (int i = 0; i < num; i++)
*((int*)data + i) = lower + i * step;
}
else if (dataType == X_FLOAT) {
data = new float[num];
for (int i = 0; i < num; i++)
*((float*)data + i) = lower + i * step;
}
else {
ShowNTErrors("TODO!");
}
/* set the data from the array */
SetData(data, num);
delete[] data;
_SetDataRange(this, lower, upper, step);
}
/*
......
......@@ -526,6 +526,43 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
}
}
/* generate data items with a range by start, end and the step
>> tensor - the tensor whose data array would be initialized
>> start - the begin of the array
>> end - the end of the array (not included self)
>> step - the step of two items
*/
void _SetDataRange(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE step)
{
CheckNTErrors((tensor->order == 1), "Tensor must be 1 dimension!");
/* compute the true length according to the (start, end, step) */
DTYPE size = fabs(upper - lower);
int num = ceil(size / fabs(step));
CheckNTErrors((tensor->unitNum == num), "Unit number of the tensor is not matched.");
/* init a integer array to store the sequence */
void * data = NULL;
if (tensor->dataType == X_INT) {
data = new int[num];
for (int i = 0; i < num; i++)
*((int*)data + i) = lower + i * step;
}
else if (tensor->dataType == X_FLOAT) {
data = new float[num];
for (int i = 0; i < num; i++)
*((float*)data + i) = lower + i * step;
}
else {
ShowNTErrors("TODO!");
}
/* set the data from the array */
tensor->SetData(data, num);
delete[] data;
}
/*
generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise
......
......@@ -69,6 +69,9 @@ void _SetDataRand(XTensor * tensor, int rNum, int cNum);
/* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a range by start, end and the step */
void _SetDataRange(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE step);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _SetDataRandP(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
......
......@@ -406,6 +406,68 @@ bool TestSetData5()
#endif // USE_CUDA
}
/*
case 6: test SetDataRange function.
generate data items with a range by start, end and the step
*/
bool TestSetData6()
{
/* a input tensor of size (5) */
int order = 1;
int * dimSize = new int[order];
dimSize[0] = 5;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE answer[5] = { 5.2, 3.2, 1.2, -0.8, -2.8 };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(order, dimSize);
/* initialize variables */
s->SetZeroAll();
/* call _SetDataRange function */
_SetDataRange(s, 5.2, -3.2, -2);
/* check results */
cpuTest = s->CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* initialize variables */
sGPU->SetZeroAll();
/* call _SetDataRange function */
_SetDataRange(sGPU, 5.2, -3.2, -2);
gpuTest = sGPU->CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete s;
delete sGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -462,6 +524,15 @@ bool TestSetData()
else
XPRINT(0, stdout, ">> case 5 passed!\n");
/* case 6 test */
caseFlag = TestSetData6();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 6 failed!\n");
}
else
XPRINT(0, stdout, ">> case 6 passed!\n");
/* other cases test */
/*
TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论