Commit ad3025f0 by xiaotong

buf fix in MatrixMul

parent 515af68a
...@@ -264,10 +264,10 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b, ...@@ -264,10 +264,10 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!"); CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!"); CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
int an = a.dimSizeRDI[0]; int an = a.dimSizeRDI[1];
int am = a.dimSizeRDI[1]; int am = a.dimSizeRDI[0];
int bn = b.dimSizeRDI[0]; int bn = b.dimSizeRDI[1];
int bm = b.dimSizeRDI[1]; int bm = b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!"); CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论