Commit ea3f2a21 by xiaotong

new code for Split

parent 7ac8e731
...@@ -63,6 +63,8 @@ void XFuncGrad::MakeGrad(XTensor * node)
    else{
        ShowNTErrors("Wrong activation function type!");
    }

    node->visitMark = NODE_FINISHED;
}

/* indicates whether the node is for an activation function */
...
...@@ -75,6 +75,8 @@ void XMathGrad::GradSum(XTensor * node)
    _Sum(a->grad, node->grad, a->grad);
    _Sum(b->grad, node->grad, b->grad, beta);

    node->visitMark = NODE_FINISHED;
}

/*
...@@ -99,6 +101,8 @@ void XMathGrad::GradMultiply(XTensor * node)
    CheckNTErrors(XTensor::IsSameShaped(a, b), "Wrong sized input tensors!");
    _Multiply(node->grad, b, a->grad, 1.0F);
    _Multiply(node->grad, a, b->grad, 1.0F);

    node->visitMark = NODE_FINISHED;
}

/*
...@@ -167,6 +171,8 @@ void XMathGrad::GradMatrixMul(XTensor * node)
        /* dE/db = a * dE/dc * \alpha */
        _MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
    }

    node->visitMark = NODE_FINISHED;
}
}
...@@ -55,6 +55,13 @@ bool XShapeGrad::IsShapeOP(XTensor * node)
    return (income.typeID & DATA_BASE) != 0;
}

/* post processing of a node */
void XShapeGrad::PostProcessing(XTensor * node, int typeID)
{
    if(typeID == SHAPE_SPLIT_LIST)
        GradSplitListPost(node);
}
/*
gradient for merge
for
...@@ -134,6 +141,8 @@ void XShapeGrad::GradMerge(XTensor * node)
    gradInputSmall.data = NULL;
    delete[] dims;

    node->visitMark = NODE_FINISHED;
}
/*
...@@ -213,6 +222,87 @@ void XShapeGrad::GradMergeList(XTensor * node)
        gradSmall.data = NULL;
        delete[] dims;
    }

    node->visitMark = NODE_FINISHED;
}
/*
gradient computation for split:
for
c = split(a)
we have
dE/da = merge(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplit(XTensor * node)
{
    XLink &income = node->income;
    XTensor * input = income.tails[0];

    int whereToSplit = income.GetParamInt(0);
    int splitNum = income.GetParamInt(1);

    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
    CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
    CheckNTErrors(splitNum == node->dimSize[0], "Wrong split number!");

    XNoder::MakeGrad(input);

    /* we can simply merge the gradient tensor
       if the input is used in splitting only */
    if(input->outgo.tailNum == 1)
        _Merge(node->grad, input->grad, whereToSplit + 1, 0);

    /* if the tensor is used somewhere else, we need another SUM
       for gradient accumulation */
    else{
        int * dims = new int[input->order];
        memcpy(dims, input->dimSize, sizeof(int) * input->order);
        dims[0] = -dims[0];

        XTensor inputGradTMP(input->order, dims,
                             input->dataType, input->denseRatio,
                             input->devID, input->mem);

        _Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
        _Sum(input->grad, &inputGradTMP, input->grad);

        delete[] dims;
    }

    node->visitMark = NODE_FINISHED;
}
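To make the two branches above concrete (a plain merge when the input feeds only the split, versus merge-then-SUM when it also has other uses), here is a minimal standalone sketch. This is plain C++ over flat arrays, not the NiuTensor API, and every name in it is hypothetical:

#include <cstdio>
#include <vector>

/* sketch: for c = split(a), the backward rule is dE/da = merge(dE/dc);
   when a already carries a gradient from another use, the merged
   gradient is accumulated with a SUM instead of overwriting it */
int main()
{
    std::vector<float> gradA = { 0.5f, 0.5f, 0.5f, 0.5f };  /* existing dE/da */
    std::vector<std::vector<float>> gradC = { { 1.0f, 2.0f },   /* dE/dc_1 */
                                              { 3.0f, 4.0f } }; /* dE/dc_2 */

    /* merge(dE/dc_1, dE/dc_2) into a temporary buffer */
    std::vector<float> merged;
    for (const auto &part : gradC)
        merged.insert(merged.end(), part.begin(), part.end());

    /* accumulate, mirroring _Sum(input->grad, &inputGradTMP, input->grad) */
    for (size_t i = 0; i < gradA.size(); ++i)
        gradA[i] += merged[i];

    for (float g : gradA)
        printf("%g ", g); /* prints: 1.5 2.5 3.5 4.5 */
    printf("\n");
    return 0;
}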
/*
gradient computation for splitting
where we return the list of the splits
for
list(c_1, ...) = split(a)
we have
dE/da = merge(dE/dc_1, ...)
>> node - the node (c_i) for backward computation
*/
void XShapeGrad::GradSplitList(XTensor * node)
{
    XLink &income = node->income;
    XTensor * input = income.tails[0];

    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
    CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");

    node->visitMark = NODE_DOING;
}
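Note that the function only marks the node as NODE_DOING and returns: a single split output cannot trigger the merge on its own, because the gradients of its siblings may not exist yet. A small hypothetical sketch (plain C++, not the NiuTensor types) of why the merge has to wait:

#include <cstdio>
#include <vector>

/* each split output only knows its own gradient, so merging must wait
   until every dE/dc_i has been produced (hypothetical types) */
struct SplitOutput { std::vector<float> grad; bool ready; };

bool AllReady(const std::vector<SplitOutput> &outs)
{
    for (const SplitOutput &o : outs)
        if (!o.ready)
            return false;
    return true;
}

int main()
{
    std::vector<SplitOutput> outs(2); /* ready == false after value-init */
    outs[0].grad = { 1.0f, 2.0f };
    outs[0].ready = true;
    printf("%s\n", AllReady(outs) ? "merge now" : "defer to post-processing");
    outs[1].grad = { 3.0f, 4.0f };
    outs[1].ready = true;
    printf("%s\n", AllReady(outs) ? "merge now" : "defer to post-processing");
    return 0;
}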
/*
gradient computation for splitting where we return
the list of the splits : list(c_1, ...) = split(a).
this method is called only when all nodes of the splitting
have been processed. We do this in a post-processing
manner because we can fuse multiple memory-copy jobs
into one. This is good for speeding up the system.
>> node - the node (a) for backward computation
*/
void XShapeGrad::GradSplitListPost(XTensor * node)
{
}
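The stub above is the hook for that fusion; its body is left empty in this commit. The intended benefit can be illustrated with a standalone sketch (plain C++, hypothetical buffers and sizes, not the NiuTensor implementation): many small per-split copies versus one bulk copy over a contiguous region:

#include <cstdio>
#include <cstring>
#include <vector>

int main()
{
    const size_t splitNum = 4, sliceSize = 1024;
    std::vector<float> src(splitNum * sliceSize, 1.0f);
    std::vector<float> dst(splitNum * sliceSize, 0.0f);

    /* unfused: one small copy per split, as per-node processing would do */
    for (size_t i = 0; i < splitNum; ++i)
        std::memcpy(&dst[i * sliceSize], &src[i * sliceSize],
                    sliceSize * sizeof(float));

    /* fused: once all splits are known, the contiguous region can be
       moved with a single copy (or the copy can be skipped entirely) */
    std::memcpy(dst.data(), src.data(), splitNum * sliceSize * sizeof(float));

    printf("%g\n", dst[0]); /* prints: 1 */
    return 0;
}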
/*
...@@ -239,6 +329,8 @@ void XShapeGrad::GradUnsqueeze(XTensor * node)
    CheckNTErrors(output->unitNum == input->unitNum * dSize, "Wrong tensor size!");

    _ReduceSum(output->grad, input->grad, dim);

    node->visitMark = NODE_FINISHED;
}

}
\ No newline at end of file
...@@ -40,18 +40,37 @@ public:
    static
    bool IsShapeOP(XTensor * node);

    /* post processing of a node */
    static
    void PostProcessing(XTensor * node, int typeID);

private:
    /* gradient computation for merge: c = merge(a, b, ...) */
    static
    void GradMerge(XTensor * node);

    /* gradient computation for merging a list of tensors : c = merge(list(a, b, ...)) */
    static
    void GradMergeList(XTensor * node);

    /* gradient computation for split: c = split(a) */
    static
    void GradSplit(XTensor * node);

    /* gradient computation for splitting where we return the list of the splits : list(c_1, ...) = split(a) */
    static
    void GradSplitList(XTensor * node);

    /* gradient computation for splitting where we return the list of the splits : list(c_1, ...) = split(a).
       this method is called only when all nodes of the splitting have been processed. We do this in a
       post-processing manner because we can fuse multiple memory-copy jobs into one. This is good for
       speeding up the system. */
    static
    void GradSplitListPost(XTensor * node);

    /* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
    static
    void GradUnsqueeze(XTensor * node);
};

}
...
...@@ -176,6 +176,10 @@ void XNet::BackwardNode(XTensor * node)
        return;

    if(!XNoder::IsLeaf(node)){
        /* post processing for parent nodes */
        BackwardNodePost(node);

        /* process the current node */
        if(XMathGrad::IsMathOP(node))
            XMathGrad::MakeGrad(node);
        else if(XFuncGrad::IsFunc(node))
...@@ -186,8 +190,24 @@ void XNet::BackwardNode(XTensor * node)
            ShowNTErrors("Wrong node type!");
        }
    }
}
/*
backward computation (in post processing) for a given node
>> node - the node whose parent nodes have not been processed yet,
          so we do the job at this (child) node
*/
void XNet::BackwardNodePost(XTensor * node)
{
    bool isSplitList = false;
    XLink &outgo = node->outgo;
    for(int i = 0; i < outgo.tailNum; i++){
        if(outgo.tails[i]->income.typeID == SHAPE_SPLIT_LIST)
            isSplitList = true;
    }

    if(isSplitList)
        XShapeGrad::PostProcessing(node, SHAPE_SPLIT_LIST);
}
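The scan over outgoing links above can be read in isolation; here is a hypothetical, self-contained rendering of the same check (plain C++, stand-in types rather than the NiuTensor ones):

#include <cstdio>
#include <vector>

/* hypothetical stand-ins for the real op-type IDs and tensor nodes */
enum OpType { OP_SUM, OP_SPLIT_LIST };
struct Node { OpType incomeType; };

/* a node needs split-list post-processing iff at least one of its
   parents was produced by a SHAPE_SPLIT_LIST operation */
bool NeedsSplitListPost(const std::vector<const Node*> &parents)
{
    for (const Node * p : parents)
        if (p->incomeType == OP_SPLIT_LIST)
            return true;
    return false;
}

int main()
{
    Node sumParent = { OP_SUM };
    Node splitParent = { OP_SPLIT_LIST };
    std::vector<const Node*> parents = { &sumParent, &splitParent };
    printf("post-processing needed: %s\n",
           NeedsSplitListPost(parents) ? "yes" : "no"); /* yes */
    return 0;
}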
/*
...
...@@ -73,6 +73,9 @@ struct XNet
    /* backward computation for a given node */
    void BackwardNode(XTensor * node);

    /* backward computation (in post processing) for a given node */
    void BackwardNodePost(XTensor * node);

    /* traverse the net and find the topological order by
       depth-first search (Tarjan's algorithm) */
    void Traverse(XTensor &root);
...