Commit 4e8872e9 by xiaotong

bug fixes in matrix multiplication

parent f21e1b48
...@@ -259,16 +259,58 @@ void XMathGrad::GradMatrixMul(XTensor * node)
    XNoder::MakeGrad(a);
    XNoder::MakeGrad(b);

    XTensor * c = node;
    XTensor * dedc = node->grad;
    XTensor * deda = a->grad;
    XTensor * dedb = b->grad;
    if(deda->order == 2 && dedb->order == 2)
        GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);
    else if(transA == X_NOTRANS && deda->order > 2 && dedb->order == 2){
        /* a is a higher-order tensor while b is a matrix: flatten a, c and their
           gradients into matrices, reuse the 2-D backward routine, then restore
           the original shapes */
        int orderBackupA = a->order;
        int orderBackupC = c->order;
        int dimsBackupA[MAX_TENSOR_DIM_NUM];
        int dimsBackupC[MAX_TENSOR_DIM_NUM];
        memcpy(dimsBackupA, a->dimSize, sizeof(int) * a->order);
        memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);

        /* merge all leading dimensions into the row dimension */
        int dimsA[2] = {a->unitNum/a->GetDim(-1), a->GetDim(-1)};
        int dimsC[2] = {c->unitNum/c->GetDim(-1), c->GetDim(-1)};

        a->Reshape(2, dimsA);
        c->Reshape(2, dimsC);
        deda->Reshape(2, dimsA);
        dedc->Reshape(2, dimsC);

        GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);

        /* restore the original shapes */
        a->Reshape(orderBackupA, dimsBackupA);
        c->Reshape(orderBackupC, dimsBackupC);
        deda->Reshape(orderBackupA, dimsBackupA);
        dedc->Reshape(orderBackupC, dimsBackupC);
    }
    else{
        ShowNTErrors("TODO!");
    }

    node->visitMark = NODE_FINISHED;
}
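Why only a, c, deda and dedc are reshaped in the branch above (shapes here are illustrative, not
taken from the commit): for a of shape d_1 x ... x d_{n-1} x k, b of shape k x m and c = \alpha * a * b,
flattening the leading dimensions into r = d_1 * ... * d_{n-1} gives c_flat = \alpha * a_flat * b, so

    dE/da_flat = \alpha * (dE/dc_flat) * b^T
    dE/db      = \alpha * a_flat^T * (dE/dc_flat)

The second formula already sums over every leading dimension, so dedb can stay in its original
2-D shape while the flattened views feed the 2-D backward routine.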
/*
gradient for matrix multiply: c = matmul(a, b) * \alpha
>> a - as it is
>> deda - dE/da
>> transA - indicates whether a is transposed
>> b - as it is
>> dedb - dE/db
>> transB - indicates whether b is transposed
>> dedc - dE/dc
>> alpha - the scalar
*/
void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
                              XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
                              XTensor * dedc, DTYPE alpha)
{
    /* c = a * b * \alpha */
    if(transA == X_NOTRANS && transB == X_NOTRANS){
        /* dE/da = dE/dc * b^T * \alpha */
        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
        /* dE/db = a^T * dE/dc * \alpha */
        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
    }
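For reference, the two comments above follow from c_{ij} = \alpha * sum_k a_{ik} * b_{kj}
(a worked derivation, not part of the commit):

    dE/da_{ik} = sum_j (dE/dc_{ij}) * \alpha * b_{kj} = \alpha * (dE/dc * b^T)_{ik}
    dE/db_{kj} = sum_i (dE/dc_{ij}) * \alpha * a_{ik} = \alpha * (a^T * dE/dc)_{kj}

The beta argument of 1.0F in both _MatrixMul calls accumulates these terms into the existing
gradients rather than overwriting them.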
...@@ -302,8 +344,6 @@ void XMathGrad::GradMatrixMul(XTensor * node)
        /* dE/db = a * dE/dc * \alpha */
        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
    }
}
/*
......
...@@ -56,6 +56,12 @@ private:
    /* gradient for matrix multiply: c = matmul(a, b) */
    static
    void GradMatrixMul(XTensor * node);
    /* gradient for matrix multiply (the 2-D case): c = matmul(a, b) */
    static
    void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
                       XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
                       XTensor * dedc, DTYPE alpha);
    /* gradient for log: c = log(a) */
    static
...@@ -128,4 +134,4 @@ private:
}
#endif
\ No newline at end of file
...@@ -26,7 +26,7 @@
namespace transformer
{
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
    char vname[128];
    vname[0] = '-';
...@@ -108,4 +108,4 @@ void ShowParams(int argc, const char ** argv)
    fprintf(stderr, "\n");
}
}
\ No newline at end of file
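The signature fix above (defaultP becomes const char *) matters because default values are
normally passed as string literals. A minimal illustration of the issue (hypothetical functions,
not from the repository):

    #include <cstdio>

    void Bad(char * defaultP)        { printf("%s\n", defaultP); }
    void Good(const char * defaultP) { printf("%s\n", defaultP); }

    int main()
    {
        /* Bad("base");  invalid conversion from 'const char*' to 'char*',
                         deprecated in old C++ and rejected since C++11 */
        Good("base");    /* fine: the default value is only read */
        return 0;
    }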
...@@ -28,7 +28,7 @@ namespace transformer
{
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
...@@ -38,4 +38,4 @@ void ShowParams(int argc, const char ** argv);
}
#endif
\ No newline at end of file
...@@ -53,11 +53,29 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
                const XTensor * b, MATRIX_TRANS_TYPE transposedB,
                XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
    CheckNTErrors(a && b && c, "Empty input tensors!");
    CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
                  "Input tensors should have the same data type!");
    CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
                  "Input tensors must have an order >= 2!");
    CheckNTErrors(c->order == a->order + b->order - 2, "wrong tensor order");

    /* we transform a higher-order tensor into a matrix to reduce the number
       of matrix multiplication calls */
    if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
        int ncolA = a->dimSize[a->order - 1];
        int ncolC = c->dimSize[c->order - 1];

        /* 2-D shells that reuse the data of a and c */
        XTensor * a2 = NewTensor2D(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
        XTensor * c2 = NewTensor2D(c->unitNum/ncolC, -ncolC, c->dataType, c->devID, c->mem);
        a2->data = a->data;
        c2->data = c->data;

        _MatrixMul2D(a2, transposedA, b, transposedB, c2, alpha, beta, parallelRunner);

        /* detach the shared buffers before deleting the shells so that
           the data of a and c is not freed */
        a2->data = NULL;
        c2->data = NULL;
        delete a2;
        delete c2;

        return;
    }
    int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
    int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......
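The optimization above relies on row-major contiguity: a tensor of shape (d, m, k) viewed as a
(d*m, k) matrix multiplies the same k-by-n matrix once instead of d times. A small standalone
check of that equivalence (plain C++ with illustrative sizes, independent of the NiuTensor API):

    #include <cstdio>
    #include <vector>

    /* multiply a (rows x inner) matrix by an (inner x cols) matrix, both row-major */
    static void MatMul(const float * a, const float * b, float * c,
                       int rows, int inner, int cols)
    {
        for(int i = 0; i < rows; i++)
            for(int j = 0; j < cols; j++){
                float s = 0;
                for(int k = 0; k < inner; k++)
                    s += a[i * inner + k] * b[k * cols + j];
                c[i * cols + j] = s;
            }
    }

    int main()
    {
        const int d = 3, m = 4, k = 5, n = 2;
        std::vector<float> a(d * m * k), b(k * n);
        for(size_t i = 0; i < a.size(); i++) a[i] = 0.01f * (float)i;
        for(size_t i = 0; i < b.size(); i++) b[i] = 0.1f * (float)i;

        /* one call on the flattened (d*m) x k view, as in the new branch */
        std::vector<float> flat(d * m * n);
        MatMul(a.data(), b.data(), flat.data(), d * m, k, n);

        /* d separate m x k by k x n multiplications, one per leading slice */
        std::vector<float> batched(d * m * n);
        for(int s = 0; s < d; s++)
            MatMul(a.data() + s * m * k, b.data(), batched.data() + s * m * n, m, k, n);

        float maxDiff = 0;
        for(size_t i = 0; i < flat.size(); i++){
            float diff = flat[i] > batched[i] ? flat[i] - batched[i] : batched[i] - flat[i];
            if(diff > maxDiff) maxDiff = diff;
        }
        printf("max difference: %g\n", maxDiff);   /* expected: 0 */
        return 0;
    }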