Commit 51b4da42 by xiaotong

bug fixes in back propagation of matrix multiplication

parent 4e8872e9
@@ -43,6 +43,8 @@ void XMathGrad::MakeGrad(XTensor * node)
         GradMultiply(node);
     else if(operID == MATH_MATRIXMUL)
         GradMatrixMul(node);
+    else if(operID == MATH_MATRIXMULBATCHED)
+        GradMatrixMulBatched(node);
     else if (operID == MATH_LOG)
         GradLog(node);
     else if (operID == MATH_POWER)
@@ -273,13 +275,14 @@ void XMathGrad::GradMatrixMul(XTensor * node)
     int dimsBackupC[MAX_TENSOR_DIM_NUM];
     memcpy(dimsBackupA, a->dimSize, sizeof(int) * a->order);
     memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);
-    int dimsA[2] = {a->unitNum/a->GetDim(-1), a->GetDim(-1)};
-    int dimsC[2] = {c->unitNum/c->GetDim(-1), c->GetDim(-1)};
-    a->Reshape(2, dimsA);
-    c->Reshape(2, dimsC);
-    deda->Reshape(2, dimsA);
-    dedc->Reshape(2, dimsC);
+    a->Reshape(a->unitNum/a->GetDim(-1), a->GetDim(-1));
+    c->Reshape(c->unitNum/c->GetDim(-1), c->GetDim(-1));
+    deda->Reshape(a->unitNum/a->GetDim(-1), a->GetDim(-1));
+    dedc->Reshape(c->unitNum/c->GetDim(-1), c->GetDim(-1));
     GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);
     a->Reshape(orderBackupA, dimsBackupA);
     c->Reshape(orderBackupC, dimsBackupC);
     deda->Reshape(orderBackupA, dimsBackupA);
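The rewritten block uses the two-argument Reshape overload introduced later in this commit. The idiom is: back up the order and dimension sizes, collapse all leading dimensions into the row dimension so an order-n tensor can be fed to the 2-D gradient routine, then restore the shapes. A minimal sketch of the same idiom on a hypothetical order-3 tensor x of shape {8, 16, 32} (the tensor and its sizes are illustrative, not from the commit):

    int orderBackup = x->order;                     /* 3 */
    int dimsBackup[MAX_TENSOR_DIM_NUM];
    memcpy(dimsBackup, x->dimSize, sizeof(int) * x->order);

    /* collapse the leading dimensions: {8, 16, 32} -> {8 * 16, 32} = {128, 32} */
    x->Reshape(x->unitNum / x->GetDim(-1), x->GetDim(-1));

    /* ... run a 2-D routine on x here ... */

    x->Reshape(orderBackup, dimsBackup);            /* back to {8, 16, 32} */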
@@ -318,8 +321,9 @@ void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
     /* c = a^T * b * \alpha */
     else if(transA == X_TRANS && transB == X_NOTRANS){
-        /* dE/da = dE/dc * b^T * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
+        /* dE/da = (dE/dc * b^T)^T * \alpha
+                 = b * dE/dc^T * \alpha */
+        _MatrixMul(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
         /* dE/db = a * dE/dc * \alpha */
         _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
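This hunk is the heart of the fix: for c = a^T * b * \alpha, the old code computed dE/da = dE/dc * b^T, which has the shape of a^T rather than a (and is silently wrong whenever a is square). A quick index-notation derivation of the corrected formula, with E the scalar loss (a sketch, not part of the commit):

    % shapes: a is m-by-k, b is m-by-l, so c = \alpha\, a^T b is k-by-l
    c_{kl} = \alpha \sum_{p} a_{pk}\, b_{pl}
    \qquad\Longrightarrow\qquad
    \frac{\partial E}{\partial a_{ij}}
      = \sum_{k,l} \frac{\partial E}{\partial c_{kl}}\,\frac{\partial c_{kl}}{\partial a_{ij}}
      = \alpha \sum_{l} \frac{\partial E}{\partial c_{jl}}\, b_{il}
      = \alpha \Bigl(b\,\bigl(\tfrac{\partial E}{\partial c}\bigr)^{T}\Bigr)_{ij}

This is exactly _MatrixMul(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F). The dE/db corrections in the next hunk follow from the same calculation.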
@@ -331,19 +335,98 @@ void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
         /* dE/da = dE/dc * b * \alpha */
         _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
-        /* dE/db = a^T * dE/dc * \alpha */
-        _MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
+        /* dE/db = (a^T * dE/dc)^T * \alpha
+                 = dE/dc^T * a * \alpha */
+        _MatrixMul(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
     }
     /* c = a^T * b^T * \alpha */
     else if(transA == X_TRANS && transB == X_TRANS){
-        /* dE/da = dE/dc * b * \alpha */
-        _MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
-        /* dE/db = a * dE/dc * \alpha */
-        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
+        /* dE/da = (dE/dc * b)^T * \alpha
+                 = b^T * dE/dc^T * \alpha */
+        _MatrixMul(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
+        /* dE/db = (a * dE/dc)^T * \alpha
+                 = dE/dc^T * a^T * \alpha */
+        _MatrixMul(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
     }
 }
+
+/*
+gradient for matrix multiply in batch mode.
+for each batch: c_i = matmul(a_i, b_i) * \alpha
+we have
+dE/da_i = dE/dc_i * b_i^T * \alpha
+dE/db_i = a_i^T * dE/dc_i * \alpha
+>> node - the node (c) for backward computation
+*/
+void XMathGrad::GradMatrixMulBatched(XTensor * node)
+{
+    XLink &income = node->income;
+    CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for MATRIXMULBATCHED!");
+    CheckNTErrors(income.paramNum == 3, "Wrong parameter number for MATRIXMULBATCHED!");
+
+    XTensor * a = income.tails[0];
+    XTensor * b = income.tails[1];
+    MATRIX_TRANS_TYPE transA = income.GetParamTrans(0);
+    MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
+    DTYPE alpha = income.GetParam(2);
+
+    XNoder::MakeGrad(a);
+    XNoder::MakeGrad(b);
+
+    XTensor * c = node;
+    XTensor * dedc = node->grad;
+    XTensor * deda = a->grad;
+    XTensor * dedb = b->grad;
+
+    /* c = a * b * \alpha */
+    if(transA == X_NOTRANS && transB == X_NOTRANS){
+        /* dE/da = dE/dc * b^T * \alpha */
+        _MatrixMulBatched(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
+        /* dE/db = a^T * dE/dc * \alpha */
+        _MatrixMulBatched(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
+    }
+    /* c = a^T * b * \alpha */
+    else if(transA == X_TRANS && transB == X_NOTRANS){
+        /* dE/da = (dE/dc * b^T)^T * \alpha
+                 = b * dE/dc^T * \alpha */
+        _MatrixMulBatched(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
+        /* dE/db = a * dE/dc * \alpha */
+        _MatrixMulBatched(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
+    }
+    /* c = a * b^T * \alpha */
+    else if(transA == X_NOTRANS && transB == X_TRANS){
+        /* dE/da = dE/dc * b * \alpha */
+        _MatrixMulBatched(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
+        /* dE/db = (a^T * dE/dc)^T * \alpha
+                 = dE/dc^T * a * \alpha */
+        _MatrixMulBatched(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
+    }
+    /* c = a^T * b^T * \alpha */
+    else if(transA == X_TRANS && transB == X_TRANS){
+        /* dE/da = (dE/dc * b)^T * \alpha
+                 = b^T * dE/dc^T * \alpha */
+        _MatrixMulBatched(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
+        /* dE/db = (a * dE/dc)^T * \alpha
+                 = dE/dc^T * a^T * \alpha */
+        _MatrixMulBatched(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
+    }
+
+    node->visitMark = NODE_FINISHED;
+}
 /*
...
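A standalone way to convince oneself that the corrected formulas are right is a finite-difference check. The sketch below is independent of the library: plain arrays, alpha fixed to 1, and E defined as the sum of all entries of c, so dE/dc is all ones. It tests the transA case the commit fixes, dE/da = b * dE/dc^T for c = a^T * b; the sizes and the helper name lossOf are invented for the example:

    #include <cstdio>

    const int M = 2, K = 3, L = 4;   /* a: M x K, b: M x L, c = a^T * b: K x L */

    /* E = sum of all entries of c, with c = a^T * b */
    double lossOf(const double a[M][K], const double b[M][L])
    {
        double e = 0;
        for(int k = 0; k < K; k++)
            for(int l = 0; l < L; l++)
                for(int m = 0; m < M; m++)
                    e += a[m][k] * b[m][l];
        return e;
    }

    int main()
    {
        double a[M][K], b[M][L];
        for(int m = 0; m < M; m++){
            for(int k = 0; k < K; k++) a[m][k] = 0.1 * (m + 1) + 0.2 * k;
            for(int l = 0; l < L; l++) b[m][l] = 0.3 * (m + 1) - 0.1 * l;
        }

        const double eps = 1e-6;
        for(int m = 0; m < M; m++){
            for(int k = 0; k < K; k++){
                /* analytic: dE/da = b * (dE/dc)^T; dE/dc is all ones here,
                   so dE/da[m][k] = sum over l of b[m][l] */
                double analytic = 0;
                for(int l = 0; l < L; l++) analytic += b[m][l];

                /* numeric: central difference on the loss */
                double saved = a[m][k];
                a[m][k] = saved + eps; double ep = lossOf(a, b);
                a[m][k] = saved - eps; double em = lossOf(a, b);
                a[m][k] = saved;
                double numeric = (ep - em) / (2 * eps);

                printf("dE/da[%d][%d]: analytic = %8.5f, numeric = %8.5f\n",
                       m, k, analytic, numeric);
            }
        }
        return 0;
    }

Looping the same check over a batch index, with each slice multiplied independently, covers the GradMatrixMulBatched cases as well.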
@@ -49,20 +49,25 @@ private:
     static
     void GradSumDim(XTensor * node);

-    /* gradient for multiply (dot production): c = a * b */
+    /* gradient for multiply (dot production): c = a * b * \alpha */
     static
     void GradMultiply(XTensor * node);

-    /* gradient for matrix multiply: c = matmul(a, b) */
+    /* gradient for matrix multiply: c = matmul(a, b) * \alpha */
     static
     void GradMatrixMul(XTensor * node);

-    /* gradient for matrix multiply: c = matmul(a, b) */
+    /* gradient for matrix multiply: c = matmul(a, b) * \alpha */
     static
     void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
                        XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
                        XTensor * dedc, DTYPE alpha);

+    /* gradient for matrix multiply in batch mode.
+       for each batch: c_i = matmul(a_i, b_i) * \alpha */
+    static
+    void GradMatrixMulBatched(XTensor * node);
+
     /* gradient for log: c = log(a) */
     static
     void GradLog(XTensor * node);
...
@@ -472,6 +472,27 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
     memcpy(dimSizeRDI, dimsRDI, sizeof(int) * order);
 }

+/*
+reshape the tensor to a vector
+>> num - number of elements
+*/
+void XTensor::Reshape(const int num)
+{
+    int dim = num;
+    Reshape(1, &dim);
+}
+
+/*
+reshape the tensor to a matrix
+>> rowNum - number of rows
+>> colNum - number of columns
+*/
+void XTensor::Reshape(const int rowNum, const int colNum)
+{
+    int dims[2] = {rowNum, colNum};
+    Reshape(2, dims);
+}
+
 /* get the number of items in the data array */
 int XTensor::GetSize() const
 {
...
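Both overloads forward to the original array-based Reshape, so call sites like the one in GradMatrixMul above can drop their temporary dim arrays. A usage sketch (the tensor x is hypothetical):

    /* before this commit: build the dimension array by hand */
    int dims[2] = {x->unitNum / x->GetDim(-1), x->GetDim(-1)};
    x->Reshape(2, dims);

    /* with the matrix overload: same effect, no temporary array */
    x->Reshape(x->unitNum / x->GetDim(-1), x->GetDim(-1));

    /* with the vector overload: flatten the whole tensor */
    x->Reshape(x->unitNum);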
@@ -229,6 +229,12 @@ public:
     /* reshape the tensor */
     void Reshape(const int order, const int * myDimSize);

+    /* reshape the tensor to a vector */
+    void Reshape(const int num);
+
+    /* reshape the tensor to a matrix */
+    void Reshape(const int rowNum, const int colNum);
+
     /* get the number of items in the data array */
     int GetSize() const;
...
@@ -150,11 +150,11 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
                           XTensor * c, DTYPE alpha, DTYPE beta)
 {
     CheckNTErrors((a && b && c), "Empty input tensors!");
-    CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
+    CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
                   "Input tensors should have the same data type!");
-    CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
+    CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
                   "Input tensors must have a order >= 2!");
-    CheckNTErrors((a->order == b->order && a->order == c->order),
+    CheckNTErrors(a->order == b->order && a->order == c->order,
                   "Input tensor and output tensor must have same order!");
@@ -165,7 +165,7 @@ CheckNTErrors((a && b && c), "Empty input tensors!");
     int cn = c->dimSizeRDI[1];
     int cm = c->dimSizeRDI[0];

-    CheckNTErrors((am == bn && an == cn && bm == cm), "Unmatched tensors in multiplication!");
+    CheckNTErrors(am == bn && an == cn && bm == cm, "Unmatched tensors in multiplication!");

     int aBlockSize = a->dimSizeRDI[0] * a->dimSizeRDI[1];
     int bBlockSize = b->dimSizeRDI[0] * b->dimSizeRDI[1];
...
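Dropping the outer parentheses in these checks is cosmetic and safe: the whole condition is a single macro argument to CheckNTErrors, and operators like && and == never introduce a top-level comma. Extra parentheses are only required when a condition does contain an unparenthesized comma; this is a general preprocessor rule, illustrated below with a stand-in macro (CHECK and the example condition are hypothetical, not from the library):

    #include <cstdio>
    #include <type_traits>

    /* stand-in for a two-argument check macro such as CheckNTErrors */
    #define CHECK(cond, msg) \
        do { if(!(cond)) std::fprintf(stderr, "check failed: %s\n", msg); } while(0)

    int main()
    {
        int order = 3;

        /* fine without extra parentheses: && contains no top-level comma */
        CHECK(order >= 2 && order <= 4, "order out of range");

        /* the comma in the template argument list would split the macro
           arguments, so this condition still needs its own parentheses */
        CHECK((std::is_same<int, long>::value == false), "types should differ");

        return 0;
    }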
@@ -185,8 +185,8 @@ void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
                       int leadDim,
                       LOSS_FUNCTION_NAME lossName)
 {
-    CheckNTErrors((dedx->isSparse == false), "The gradient tensor must be dense!");
-    CheckNTErrors((gold != NULL), "Incorrect x gold standard tensor!");
+    CheckNTErrors(dedx->isSparse == false, "The gradient tensor must be dense!");
+    CheckNTErrors(gold != NULL || lossName == NOLOSS, "Gold standard is required for computing loss!");

     if(leadDim < 0)
         leadDim = y->order - 1;
...