Commit 7d4ab222 by liyinqiao

Fix the mistakes in the manual and clean up the code.

1. Fix the mistakes in the manual. By the way, I have to say there are a lot of mistakes in the manual. I'm shocked that it has been checked many times, yet the mistakes are still there. Does no one care about that? Really???
2. Clean up the code.
parent 8e21537a
...@@ -103,7 +103,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, ...@@ -103,7 +103,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DelTensorBuf(tmp); DelTensorBuf(tmp);
return c; return c;
} }
/* /*
...@@ -114,17 +113,17 @@ operation c = x * w + b MulAndShift ...@@ -114,17 +113,17 @@ operation c = x * w + b MulAndShift
>> parallelRunner - parallel processing module >> parallelRunner - parallel processing module
<< return - the result of matrix multiplication << return - the result of matrix multiplication
*/ */
XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA, XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedX,
const XTensor& w, MATRIX_TRANS_TYPE transposedB, const XTensor& w, MATRIX_TRANS_TYPE transposedW,
const XTensor& b, DTYPE alpha, XPRunner* parallelRunner) const XTensor& b, DTYPE alpha, XPRunner* parallelRunner)
{ {
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!"); CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!"); CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
int xn = transposedA == X_TRANS ? x.dimSize[x.order - 1] : x.dimSize[x.order - 2]; int xn = transposedX == X_TRANS ? x.dimSize[x.order - 1] : x.dimSize[x.order - 2];
int xm = transposedA == X_TRANS ? x.dimSize[x.order - 2] : x.dimSize[x.order - 1]; int xm = transposedX == X_TRANS ? x.dimSize[x.order - 2] : x.dimSize[x.order - 1];
int wn = transposedB == X_TRANS ? w.dimSize[w.order - 1] : w.dimSize[w.order - 2]; int wn = transposedW == X_TRANS ? w.dimSize[w.order - 1] : w.dimSize[w.order - 2];
int wm = transposedB == X_TRANS ? w.dimSize[w.order - 2] : w.dimSize[w.order - 1]; int wm = transposedW == X_TRANS ? w.dimSize[w.order - 2] : w.dimSize[w.order - 1];
int order = x.order + w.order - 2; int order = x.order + w.order - 2;
int sub = 0; int sub = 0;
...@@ -141,7 +140,7 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA, ...@@ -141,7 +140,7 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
XTensor * tmp = NewTensorBufV2(order, dimSize, x.dataType, dr, x.devID, x.mem); XTensor * tmp = NewTensorBufV2(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */ /* call _MatrixMul function */
_MatrixMul(&x, transposedA, &w, transposedB, tmp, alpha, 0, parallelRunner); _MatrixMul(&x, transposedX, &w, transposedW, tmp, alpha, 0, parallelRunner);
XTensor c(tmp); XTensor c(tmp);
c.SetTMPFlag(); c.SetTMPFlag();
...@@ -169,8 +168,8 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA, ...@@ -169,8 +168,8 @@ XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
if (w.enableGrad && b.enableGrad) { if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT); XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, transposedA); XLink::AddParamToHeadTrans(&c, transposedX);
XLink::AddParamToHeadTrans(&c, transposedB); XLink::AddParamToHeadTrans(&c, transposedW);
} }
/* destroy variables */ /* destroy variables */
......
...@@ -31,8 +31,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -31,8 +31,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL); DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedA, XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedX,
const XTensor &w, MATRIX_TRANS_TYPE transposedB, const XTensor &w, MATRIX_TRANS_TYPE transposedW,
const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL); const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -27,7 +27,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -27,7 +27,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* /*
case 1: element-wise division of two tensors case 1: element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i) c(i) = a(i)/b(i) + \alpha * c(i)
In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=0, alpha=0. In this case, (2, 2) / (2, 2) -> (2, 2), leadingDim=0, alpha=0.
*/ */
bool TestDiv1() bool TestDiv1()
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论