Commit 51fe2255 by xiaotong

misuse of dimSize and dimSizeRDI

parent 385b7f62
......@@ -50,8 +50,8 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
>> parallelRunner - parallel processing module
*/
void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
......@@ -210,9 +210,9 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((&a != &NULLTensor && &b != &NULLTensor), "Empty input tensors!");
CheckNTErrors((a.dataType == b.dataType), "Input tensors should have the same data type!");
CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!");
CheckNTErrors(&a != &NULLTensor && &b != &NULLTensor, "Empty input tensors!");
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1];
int am = transposedA == X_TRANS ? a.dimSizeRDI[1] : a.dimSizeRDI[0];
......@@ -224,15 +224,15 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor
int order = a.order + b.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[i];
for (int i = 2; i < b.order; i++)
dimSize[sub++] = b.dimSizeRDI[i];
dimSize[sub++] = b.dimSizeRDI[b.order + 1 - i];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[a.order + 1 - i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
XTensor c = NewTensor(order, dimSize, a.dataType, a.denseRatio, a.devID, a.mem);
c.SetZeroAll();
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
/* call _MatrixMul function */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论