Commit 51b4da42 by xiaotong

bug fixes in back propagation of matrix multiplication

parent 4e8872e9
......@@ -43,6 +43,8 @@ void XMathGrad::MakeGrad(XTensor * node)
GradMultiply(node);
else if(operID == MATH_MATRIXMUL)
GradMatrixMul(node);
else if(operID == MATH_MATRIXMULBATCHED)
GradMatrixMulBatched(node);
else if (operID == MATH_LOG)
GradLog(node);
else if (operID == MATH_POWER)
......@@ -273,13 +275,14 @@ void XMathGrad::GradMatrixMul(XTensor * node)
int dimsBackupC[MAX_TENSOR_DIM_NUM];
memcpy(dimsBackupA, a->dimSize, sizeof(int) * a->order);
memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);
int dimsA[2] = {a->unitNum/a->GetDim(-1), a->GetDim(-1)};
int dimsC[2] = {c->unitNum/c->GetDim(-1), c->GetDim(-1)};
a->Reshape(2, dimsA);
c->Reshape(2, dimsC);
deda->Reshape(2, dimsA);
dedc->Reshape(2, dimsC);
a->Reshape(a->unitNum/a->GetDim(-1), a->GetDim(-1));
c->Reshape(c->unitNum/c->GetDim(-1), c->GetDim(-1));
deda->Reshape(a->unitNum/a->GetDim(-1), a->GetDim(-1));
dedc->Reshape(c->unitNum/c->GetDim(-1), c->GetDim(-1));
GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha);
a->Reshape(orderBackupA, dimsBackupA);
c->Reshape(orderBackupC, dimsBackupC);
deda->Reshape(orderBackupA, dimsBackupA);
......@@ -318,8 +321,9 @@ void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE tra
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
/* dE/da = (dE/dc * b^T)^T * \alpha
= b * dE/dc^T * \alpha */
_MatrixMul(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
......@@ -331,19 +335,98 @@ void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE tra
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
/* dE/db = (a^T * dE/dc)^T * \alpha
= dE/dc^T * a * \alpha */
_MatrixMul(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
/* dE/da = (dE/dc * b)^T * \alpha
= b^T * dE/dc^T * \alpha */
_MatrixMul(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = (a * dE/dc)^T * \alpha
= dE/dc^T * a^T * \alpha */
_MatrixMul(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
}
}
/*
gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha
for c_i = matmul(a_i, b_i) * \alpha
we have
dE/da_i = dE/dc_i * b_i^T * \alpha
dE/db_i = a_i^T * dE/dc_i * \alpha
>> node - the node (c) for backward computation
*/
void XMathGrad::GradMatrixMulBatched(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for MULTIPLY!");
CheckNTErrors(income.paramNum == 3, "Wrong parameter number for MULTIPLY!");
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
MATRIX_TRANS_TYPE transA = income.GetParamTrans(0);
MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
DTYPE alpha = income.GetParam(2);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
XTensor * c = node;
XTensor * dedc = node->grad;
XTensor * deda = a->grad;
XTensor * dedb = b->grad;
/* c = a * b * \alpha */
if(transA == X_NOTRANS && transB == X_NOTRANS){
/* dE/da = dE/dc * b^T * \alpha */
_MatrixMulBatched(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a^T * dE/dc * \alpha */
_MatrixMulBatched(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
/* dE/da = (dE/dc * b^T)^T * \alpha
= b * dE/dc^T * \alpha */
_MatrixMulBatched(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a * dE/dc * \alpha */
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
_MatrixMulBatched(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a * b^T * \alpha */
else if(transA == X_NOTRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
_MatrixMulBatched(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
/* dE/db = (a^T * dE/dc)^T * \alpha
= dE/dc^T * a * \alpha */
_MatrixMulBatched(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
/* dE/da = (dE/dc * b)^T * \alpha
= b^T * dE/dc^T * \alpha */
_MatrixMulBatched(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = (a * dE/dc)^T * \alpha
= dE/dc^T * a^T * \alpha */
_MatrixMulBatched(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
}
node->visitMark = NODE_FINISHED;
}
/*
......
......@@ -49,20 +49,25 @@ private:
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot production): c = a * b */
/* gradient for multiply (dot production): c = a * b * \alpha */
static
void GradMultiply(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) */
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) */
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
XTensor * dedc, DTYPE alpha);
/* gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha */
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
static
void GradLog(XTensor * node);
......
......@@ -472,6 +472,27 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
memcpy(dimSizeRDI, dimsRDI, sizeof(int) * order);
}
/*
reshape the tensor to a vector
>> num - number of elements
*/
void XTensor::Reshape(const int num)
{
int dim = num;
Reshape(1, &dim);
}
/*
reshape the tensor to a matrix
>> rowNum - number of rows
>> colNum - number of columns
*/
void XTensor::Reshape(const int rowNum, const int colNum)
{
int dims[2] = {rowNum, colNum};
Reshape(2, dims);
}
/* get the number of items in the data array */
int XTensor::GetSize() const
{
......
......@@ -229,6 +229,12 @@ public:
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
/* reshape the tensor to a vector */
void Reshape(const int num);
/* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum);
/* get the number of items in the data array */
int GetSize() const;
......
......@@ -150,12 +150,12 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order >= 2!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Input tensor and output tensor must have same order!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
"Input tensors must have a order >= 2!");
CheckNTErrors(a->order == b->order && a->order == c->order,
"Input tensor and output tensor must have same order!");
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
......@@ -165,7 +165,7 @@ CheckNTErrors((a && b && c), "Empty input tensors!");
int cn = c->dimSizeRDI[1];
int cm = c->dimSizeRDI[0];
CheckNTErrors((am == bn && an == cn && bm == cm), "Unmatched tensors in multiplication!");
CheckNTErrors(am == bn && an == cn && bm == cm, "Unmatched tensors in multiplication!");
int aBlockSize = a->dimSizeRDI[0] * a->dimSizeRDI[1];
int bBlockSize = b->dimSizeRDI[0] * b->dimSizeRDI[1];
......
......@@ -185,8 +185,8 @@ void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
int leadDim,
LOSS_FUNCTION_NAME lossName)
{
CheckNTErrors((dedx->isSparse == false), "The gradient tensor must be dense!");
CheckNTErrors((gold != NULL), "Incorrect x gold standard tensor!");
CheckNTErrors(dedx->isSparse == false, "The gradient tensor must be dense!");
CheckNTErrors(gold != NULL || lossName == NOLOSS, "Gold standard is required for computing loss!");
if(leadDim < 0)
leadDim = y->order - 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论