Commit b0f2bbbf by liyinqiao

Bug fixed and add new unit test.

1. Fix the bug in MultiplyMe function which cannot handle the scalar tensor and broadcast case.
2. Add broadcast multiply unit test.
parent d16a087d
......@@ -157,7 +157,35 @@ where i is the index of the item
*/
void MultiplyMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
{
_Multiply(&a, &b, &a, alpha, leadingDim);
if (b.order == 0){
DTYPE scale = b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&a, a.devID, a.mem);
_ScaleAndShift(&a, tmp1, scale, 0.0F);
_ScaleAndShift(&a, tmp2, alpha, 0.0F);
_Sum(tmp2, tmp1, &a);
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Multiply function */
_Multiply(&a, &b, &a, alpha, leadingDim);
}
else if (n >= 0 && n < a.order) {
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &a, n, alpha);
}
else {
ShowNTErrors("Something is wrong!");
}
}
}
/*
......
......@@ -149,6 +149,131 @@ bool TestMultiply1()
#endif // USE_CUDA
}
/*
case 2: element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
In this case, (2, 3, 4) * (2, 1, 1) -> (2, 3, 4), alpha=0.
*/
bool TestMultiply2()
{
/* a source tensor of size (2, 3, 4) */
int sOrder1 = 3;
int * sDimSize1 = new int[sOrder1];
sDimSize1[0] = 2;
sDimSize1[1] = 3;
sDimSize1[2] = 4;
int sUnitNum1 = 1;
for (int i = 0; i < sOrder1; i++)
sUnitNum1 *= sDimSize1[i];
/* a source tensor of size (2, 1, 1) */
int sOrder2 = 3;
int * sDimSize2 = new int[sOrder2];
sDimSize2[0] = 2;
sDimSize2[1] = 1;
sDimSize2[2] = 1;
int sUnitNum2 = 1;
for (int i = 0; i < sOrder2; i++)
sUnitNum2 *= sDimSize2[i];
/* a target tensor of size (2, 3, 4) */
int tOrder = 3;
int * tDimSize = new int[tOrder];
tDimSize[0] = 2;
tDimSize[1] = 3;
tDimSize[2] = 4;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData1[2][3][4] = { { {0.0F, 1.0F, 2.0F, 3.0F},
{3.0F, 2.0F, 1.0F, 0.0F},
{0.0F, 1.0F, 2.0F, 3.0F} },
{ {3.0F, 2.0F, 1.0F, 0.0F},
{0.0F, 1.0F, 2.0F, 3.0F},
{3.0F, 2.0F, 1.0F, 0.0F} } };
DTYPE sData2[2][1][1] = { { {1.0F} },
{ {-1.0F} } };
DTYPE answer[2][3][4] = { { {0.0F, 1.0F, 2.0F, 3.0F},
{3.0F, 2.0F, 1.0F, 0.0F},
{0.0F, 1.0F, 2.0F, 3.0F} },
{ {-3.0F, -2.0F, -1.0F, 0.0F},
{0.0F, -1.0F, -2.0F, -3.0F},
{-3.0F, -2.0F, -1.0F, 0.0F} } };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s1 = NewTensorV2(sOrder1, sDimSize1);
XTensor * s2 = NewTensorV2(sOrder2, sDimSize2);
XTensor * tMe = NewTensorV2(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s1->SetData(sData1, sUnitNum1);
tMe->SetData(sData1, sUnitNum1);
s2->SetData(sData2, sUnitNum2);
/* call Multiply function */
MultiplyMe(*tMe, *s2, 0);
tUser = Multiply(*s1, *s2);
/* check results */
cpuTest = _CheckData(tMe, answer, 1e-4, tUnitNum) &&
_CheckData(&tUser, answer, 1e-4, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * sGPU1 = NewTensorV2(sOrder1, sDimSize1, X_FLOAT, 1.0F, 0);
XTensor * sGPU2 = NewTensorV2(sOrder2, sDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tMeGPU = NewTensorV2(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* Initialize variables */
sGPU1->SetData(sData1, sUnitNum1);
tMeGPU->SetData(sData1, sUnitNum1);
sGPU2->SetData(sData2, sUnitNum2);
/* call Multiply function */
MultiplyMe(*tMeGPU, *sGPU2, 0);
tUserGPU = Multiply(*sGPU1, *sGPU2);
/* check results */
gpuTest = _CheckData(tMeGPU, answer, tUnitNum, 1e-4F) &&
_CheckData(&tUserGPU, answer, tUnitNum, 1e-4F);
/* destroy variables */
delete s1;
delete s2;
delete tMe;
delete sGPU1;
delete sGPU2;
delete tMeGPU;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s1;
delete s2;
delete tMe;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -170,6 +295,16 @@ bool TestMultiply()
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestMultiply2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论