Commit 718ab43f by liyinqiao

Merge with XU Chen's branch (Don't use this! It's an incomplete version.)

Clean up the code.
parent 3a515f68
...@@ -62,7 +62,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -62,7 +62,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
/* we transform a higher order tensor to a matrix to kill the number /* we transform a higher order tensor to a matrix to kill the number
of calls of matrix multiplication */ of calls of matrix multiplication */
if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){ if (transposedA == X_NOTRANS && a->order > 2 && b->order == 2) {
int ncolA = a->dimSize[a->order - 1]; int ncolA = a->dimSize[a->order - 1];
int ncolC = c->dimSize[c->order - 1]; int ncolC = c->dimSize[c->order - 1];
XTensor * a2 = NewTensor2DV2(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem); XTensor * a2 = NewTensor2DV2(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
...@@ -345,7 +345,6 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, ...@@ -345,7 +345,6 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
} }
/* call _MatrixMul function */ /* call _MatrixMul function */
......
...@@ -90,29 +90,9 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) ...@@ -90,29 +90,9 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
/* when c != a, OpenBLAS needs to copy a to c first. This operation /* when c != a, OpenBLAS needs to copy a to c first. This operation
slow down the speed, so just use OpenBLAS when c == a */ slow down the speed, so just use OpenBLAS when c == a */
#if defined(USE_BLAS) #if defined(USE_BLAS)
if (c == a) if (c == a) {
AXPY(a->unitNum,beta,bp,1,cp,1); AXPY(a->unitNum,beta,bp,1,cp,1);
else { return;
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] + bp[i] * beta;
}
}
} }
#else #else
/* unrolling */ /* unrolling */
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论