Commit 9b87b785 by xiaotong

Back propagation for Unsqueeze

parent 1226160d
......@@ -41,6 +41,8 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMerge(node);
else if(operID == SHAPE_MERGE_LIST)
GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node);
else{
ShowNTErrors("TODO!");
}
......@@ -142,6 +144,7 @@ dE/da = dE/dc_{split_0}
dE/db = dE/dc_{split_1}
i.e.,
list(dE/da, dE/db, ...) = split(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradMergeList(XTensor * node)
{
......@@ -200,9 +203,9 @@ void XShapeGrad::GradMergeList(XTensor * node)
/* gradient accumulation for each split */
for(int i = 0; i < smalls.count; i++){
XTensor * smallGrad = (XTensor*)smallsGrad.Get(i);
XTensor * inputGrad = (XTensor*)smallsGrad.Get(i);
gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize;
_Sum(smallGrad, &gradSmall, smallGrad);
_Sum(inputGrad, &gradSmall, inputGrad);
}
gradSmall.data = NULL;
......@@ -210,4 +213,30 @@ void XShapeGrad::GradMergeList(XTensor * node)
}
}
/*
gradient for unsqueezing a tensor
for
c = unsqueeze(a)
we have
dE/da = reduecesum(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradUnsqueeze(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for UNSQUEEZE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
int dim = income.GetParamInt(0);
int dSize = income.GetParamInt(1);
CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
_ReduceSum(output->grad, input->grad, dim);
}
}
\ No newline at end of file
......@@ -48,6 +48,10 @@ private:
/* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
static
void GradMergeList(XTensor * node);
/* gradient for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
};
}
......
......@@ -32,6 +32,7 @@ void XNoder::MakeGrad(XTensor * node)
if(!XTensor::IsIdentical(node, node->grad)){
delete node->grad;
node->grad = NewTensor(node);
node->grad->SetZeroAll();
}
}
......
......@@ -129,6 +129,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
/* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
/* destroy variables */
delete[] dimSize;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论