Commit 11a57e99 by xiaotong

backward propagation for MatrixMul

parent e18e6358
......@@ -55,9 +55,9 @@ int main( int argc, const char ** argv )
b.SetZeroAll();
c.SetZeroAll();
SetDataFixed(a, 1.0F);
a.Set2D(3.0F, 1, 0);
a.Set2D(4.0F, 1, 1);
SetDataFixed(a, 0.1F);
a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1);
b = a + a;
c = HTanH(MMul(a, b));
......
......@@ -116,10 +116,13 @@ void XMathGrad::GradMatrixMul(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for MULTIPLY!");
CheckNTErrors(income.paramNum == 3, "Wrong parameter number for MULTIPLY!");
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
DTYPE alpha = income.GetParam(0);
MATRIX_TRANS_TYPE transA = income.GetParamTrans(0);
MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
DTYPE alpha = income.GetParam(2);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
......@@ -127,12 +130,46 @@ void XMathGrad::GradMatrixMul(XTensor * node)
XTensor * dedc = node->grad;
XTensor * deda = a->grad;
XTensor * dedb = b->grad;
/* c = a * b * \alpha */
if(transA == X_NOTRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a * b^T * \alpha */
else if(transA == X_NOTRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha);
}
}
}
\ No newline at end of file
}
......@@ -240,7 +240,8 @@ void XLink::AddParam(void * param, int size)
/*
get a parameter in default type
>> i - id the of the parameter
>> i - id of the parameter
<< return - the parameter in default type
*/
DTYPE XLink::GetParam(int i)
{
......@@ -251,7 +252,8 @@ DTYPE XLink::GetParam(int i)
/*
get a parameter in integer
>> i - id the of the parameter
>> i - id of the parameter
<< return - the parameter in integer
*/
int XLink::GetParamInt(int i)
{
......@@ -259,6 +261,18 @@ int XLink::GetParamInt(int i)
char * p = (char*)params + i * paramSize;
return *(int*)p;
}
/*
get a parameter in MATRIX_TRANS_TYPE
>> i - id of the parameter (0-based slot in the parameter array)
<< return - the parameter in MATRIX_TRANS_TYPE
*/
MATRIX_TRANS_TYPE XLink::GetParamTrans(int i)
{
    CheckNTErrors(params != NULL, "parameter array cannot be empty!");
    /* reject ids that would address memory outside the parameter array */
    CheckNTErrors(i >= 0 && i < paramNum, "Illegal parameter id!");

    /* every parameter occupies one slot of paramSize bytes */
    char * p = (char*)params + i * paramSize;
    return *(MATRIX_TRANS_TYPE*)p;
}
/*
create a hyperedge with two input tensors and an output tensor
......
......@@ -127,6 +127,9 @@ struct XLink
/* get a parameter in integer */
int GetParamInt(int i);
/* get a parameter in MATRIX_TRANS_TYPE */
MATRIX_TRANS_TYPE GetParamTrans(int i);
/* create a hyper edge with two input tensors and an output tensor */
static
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论