Commit 11a57e99 by xiaotong

backward propagation for MatrixMul

parent e18e6358
...@@ -55,9 +55,9 @@ int main( int argc, const char ** argv ) ...@@ -55,9 +55,9 @@ int main( int argc, const char ** argv )
b.SetZeroAll(); b.SetZeroAll();
c.SetZeroAll(); c.SetZeroAll();
SetDataFixed(a, 1.0F); SetDataFixed(a, 0.1F);
a.Set2D(3.0F, 1, 0); a.Set2D(0.3F, 1, 0);
a.Set2D(4.0F, 1, 1); a.Set2D(0.4F, 1, 1);
b = a + a; b = a + a;
c = HTanH(MMul(a, b)); c = HTanH(MMul(a, b));
......
...@@ -116,10 +116,13 @@ void XMathGrad::GradMatrixMul(XTensor * node) ...@@ -116,10 +116,13 @@ void XMathGrad::GradMatrixMul(XTensor * node)
{ {
XLink &income = node->income; XLink &income = node->income;
CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for MULTIPLY!"); CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for MULTIPLY!");
CheckNTErrors(income.paramNum == 3, "Wrong parameter number for MULTIPLY!");
XTensor * a = income.tails[0]; XTensor * a = income.tails[0];
XTensor * b = income.tails[1]; XTensor * b = income.tails[1];
DTYPE alpha = income.GetParam(0); MATRIX_TRANS_TYPE transA = income.GetParamTrans(0);
MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
DTYPE alpha = income.GetParam(2);
XNoder::MakeGrad(a); XNoder::MakeGrad(a);
XNoder::MakeGrad(b); XNoder::MakeGrad(b);
...@@ -128,11 +131,45 @@ void XMathGrad::GradMatrixMul(XTensor * node) ...@@ -128,11 +131,45 @@ void XMathGrad::GradMatrixMul(XTensor * node)
XTensor * deda = a->grad; XTensor * deda = a->grad;
XTensor * dedb = b->grad; XTensor * dedb = b->grad;
/* c = a * b * \alpha */
if(transA == X_NOTRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */ /* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha); _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a * b^T * \alpha */
else if(transA == X_NOTRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
/* dE/db = a^T * dE/dc * \alpha */ /* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha); _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
}
} }
} }
...@@ -240,7 +240,8 @@ void XLink::AddParam(void * param, int size) ...@@ -240,7 +240,8 @@ void XLink::AddParam(void * param, int size)
/* /*
get a parameter in default type get a parameter in default type
>> i - id the of the parameter >> i - id of the parameter
<< return - the parameter in default type
*/ */
DTYPE XLink::GetParam(int i) DTYPE XLink::GetParam(int i)
{ {
...@@ -251,7 +252,8 @@ DTYPE XLink::GetParam(int i) ...@@ -251,7 +252,8 @@ DTYPE XLink::GetParam(int i)
/* /*
get a parameter in integer get a parameter in integer
>> i - id the of the parameter >> i - id of the parameter
<< return - the parameter in integer
*/ */
int XLink::GetParamInt(int i) int XLink::GetParamInt(int i)
{ {
...@@ -261,6 +263,18 @@ int XLink::GetParamInt(int i) ...@@ -261,6 +263,18 @@ int XLink::GetParamInt(int i)
} }
/* /*
get a parameter in MATRIX_TRANS_TYPE
>> i - id of the parameter
<< return - the parameter in MATRIX_TRANS_TYPE
*/
MATRIX_TRANS_TYPE XLink::GetParamTrans(int i)
{
    CheckNTErrors(params != NULL, "parameter array cannot be empty!");

    /* reject ids outside [0, paramNum) so that the read below never leaves
       the parameter buffer (paramNum is assumed to be the number of stored
       parameters, as it is checked as a count in GradMatrixMul) */
    CheckNTErrors(i >= 0 && i < paramNum, "parameter id is out of range!");

    /* the i-th parameter starts i * paramSize bytes into the buffer;
       its bytes are reinterpreted as a MATRIX_TRANS_TYPE value */
    char * p = (char*)params + i * paramSize;
    return *(MATRIX_TRANS_TYPE*)p;
}
/*
create a hyperedge with two input tensors and a output tensor create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor >> t1 - a tail tensor
>> t2 - another tail tensor >> t2 - another tail tensor
......
...@@ -128,6 +128,9 @@ struct XLink ...@@ -128,6 +128,9 @@ struct XLink
/* get a parameter in integer */ /* get a parameter in integer */
int GetParamInt(int i); int GetParamInt(int i);
/* get a parameter in MATRIX_TRANS_TYPE */
MATRIX_TRANS_TYPE GetParamTrans(int i);
/* create a hyper edge with two input tensors and a output tensor */ /* create a hyper edge with two input tensors and a output tensor */
static static
void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id); void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id);
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论