Commit 9b87b785 by xiaotong

Back propagation for Unsqueeze

parent 1226160d
...@@ -41,6 +41,8 @@ void XShapeGrad::MakeGrad(XTensor * node) ...@@ -41,6 +41,8 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMerge(node); GradMerge(node);
else if(operID == SHAPE_MERGE_LIST) else if(operID == SHAPE_MERGE_LIST)
GradMergeList(node); GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node);
else{ else{
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
} }
...@@ -142,6 +144,7 @@ dE/da = dE/dc_{split_0} ...@@ -142,6 +144,7 @@ dE/da = dE/dc_{split_0}
dE/db = dE/dc_{split_1} dE/db = dE/dc_{split_1}
i.e., i.e.,
list(dE/da, dE/db, ...) = split(dE/dc) list(dE/da, dE/db, ...) = split(dE/dc)
>> node - the node (c) for backward computation
*/ */
void XShapeGrad::GradMergeList(XTensor * node) void XShapeGrad::GradMergeList(XTensor * node)
{ {
...@@ -200,9 +203,9 @@ void XShapeGrad::GradMergeList(XTensor * node) ...@@ -200,9 +203,9 @@ void XShapeGrad::GradMergeList(XTensor * node)
/* gradient accumulation for each split */ /* gradient accumulation for each split */
for(int i = 0; i < smalls.count; i++){ for(int i = 0; i < smalls.count; i++){
XTensor * smallGrad = (XTensor*)smallsGrad.Get(i); XTensor * inputGrad = (XTensor*)smallsGrad.Get(i);
gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize; gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize;
_Sum(smallGrad, &gradSmall, smallGrad); _Sum(inputGrad, &gradSmall, inputGrad);
} }
gradSmall.data = NULL; gradSmall.data = NULL;
...@@ -210,4 +213,30 @@ void XShapeGrad::GradMergeList(XTensor * node) ...@@ -210,4 +213,30 @@ void XShapeGrad::GradMergeList(XTensor * node)
} }
} }
/*
gradient for unsqueezing a tensor
for
c = unsqueeze(a)
we have
dE/da = reducesum(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradUnsqueeze(XTensor * node)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for UNSQUEEZE!");

    XTensor * output = node;
    XTensor * input = income.tails[0];

    /* allocate (zero-initialized) gradient storage for the input node */
    XNoder::MakeGrad(input);

    /* the dimension that was inserted and its size, recorded at forward time */
    int dim = income.GetParamInt(0);
    int dSize = income.GetParamInt(1);

    CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
    /* fixed: was "=" (assignment), which silently overwrote output->unitNum
       and made this consistency check always succeed for non-zero sizes */
    CheckNTErrors(output->unitNum == input->unitNum * dSize, "Wrong tensor size!");

    /* dE/da = reducesum(dE/dc): sum the output gradient over the
       unsqueezed dimension and accumulate it into the input gradient */
    _ReduceSum(output->grad, input->grad, dim);
}
} }
\ No newline at end of file
...@@ -48,6 +48,10 @@ private: ...@@ -48,6 +48,10 @@ private:
/* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */ /* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
static static
void GradMergeList(XTensor * node); void GradMergeList(XTensor * node);
/* gradient for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
}; };
} }
......
...@@ -32,6 +32,7 @@ void XNoder::MakeGrad(XTensor * node) ...@@ -32,6 +32,7 @@ void XNoder::MakeGrad(XTensor * node)
if(!XTensor::IsIdentical(node, node->grad)){ if(!XTensor::IsIdentical(node, node->grad)){
delete node->grad; delete node->grad;
node->grad = NewTensor(node); node->grad = NewTensor(node);
node->grad->SetZeroAll();
} }
} }
......
...@@ -129,6 +129,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize) ...@@ -129,6 +129,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
/* call _Unsqueeze function */ /* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize); _Unsqueeze(&a, &b, dim, dSize);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论