Commit de3aeee1 by xiaotong

new code for split

parent ea3f2a21
......@@ -43,6 +43,10 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node);
else if(operID == SHAPE_SPLIT)
GradSplit(node);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node);
else{
ShowNTErrors("TODO!");
}
......@@ -234,7 +238,7 @@ we have
dE/da = merge(dE/dc)
>> node - the node (c) for backward computation
*/
void GradSplit(XTensor * node)
void XShapeGrad::GradSplit(XTensor * node)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
......@@ -256,17 +260,10 @@ void GradSplit(XTensor * node)
/* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */
else{
int * dims = new int[input->order];
memcpy(dims, input->dimSize, sizeof(int) * input->order);
dims[0] = -dims[0];
XTensor inputGradTMP(input->order, dims,
input->dataType, input->denseRatio,
input->devID, input->mem);
XTensor inputGradTMP(input);
_Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, &inputGradTMP, input->grad);
delete[] dims;
}
node->visitMark = NODE_FINISHED;
......@@ -293,7 +290,7 @@ void XShapeGrad::GradSplitList(XTensor * node)
}
/*
gradient computation for spliting where we return
gradient computation for splitting. We return
the list of the splits : list(c_1, ...) = split(a).
This method is called only when all nodes of splitting
have been processed. We do this in a post-processing
......@@ -303,6 +300,46 @@ one time. This is good for system speed up.
*/
void XShapeGrad::GradSplitListPost(XTensor * node)
{
    /* we compute the gradient for the current node (the tensor that was
       split), rather than for a child node, i.e., we walk the outgoing
       edges here */
    XLink &outgo = node->outgo;

    /* gradients of the split results, collected in split-id order */
    XList splits(outgo.tailNum);
    int whereToSplit = -1;
    int splitNum = 0;

    for(int i = 0; i < outgo.tailNum; i++){
        XTensor * parent = (XTensor*)outgo.tails[i];
        XLink &income = parent->income;

        /* only consumers produced by a list-split take part in the merge */
        if(income.typeID == SHAPE_SPLIT_LIST){
            int w = income.GetParamInt(0);
            int splitID = income.GetParamInt(1);

            /* the first split we meet fixes the dimension; all the
               others must agree with it */
            if(whereToSplit < 0)
                whereToSplit = w;
            splitNum++;

            CheckNTErrors(whereToSplit == w, "Wrong dimension for spliting");
            CheckNTErrors(income.tailNum == 1, "Something wrong with outgoing edge!");
            CheckNTErrors(splitNum - 1 == splitID, "Wrong split id!");

            /* dE/da = merge(dE/dc_1, ...): it is the gradient of each split
               that has to be merged, not the split tensor itself */
            splits.Add(parent->grad);
        }
    }

    /* we can simply merge the gradient tensors
       if the node is used in spliting only */
    if(outgo.tailNum == splitNum){
        _Merge(&splits, node->grad, whereToSplit + 1);
    }
    /* if the tensor is used as input to other nodes
       somewhere else, we need another SUM for gradient
       accumulation */
    else{
        XTensor nodeGradTMP(node);

        _Merge(&splits, &nodeGradTMP, whereToSplit + 1);
        _Sum(node->grad, &nodeGradTMP, node->grad);
    }
}
/*
......
......@@ -57,11 +57,11 @@ private:
static
void GradSplit(XTensor * node);
/* gradient computation for spliting where we return the list of the splits : list(c_1, ...) = split(a) */
/* gradient computation for spliting. we return the list of the splits : list(c_1, ...) = split(a) */
static
void GradSplitList(XTensor * node);
/* gradient computation for spliting where we return the list of the splits : list(c_1, ...) = split(a).
/* gradient computation for splitting. We return the list of the splits : list(c_1, ...) = split(a).
this method is called only when all nodes of splitting have been processed. We do this in a post-processing
manner because we can fuse multiple memory copy jobs into one. This is good for system speed up. */
static
......
......@@ -492,15 +492,15 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
gradList.Add(&grad.embeddingW);
}
else{
paraList.Add(model.outputW.grad);
paraList.Add(&model.outputB.grad);
gradList.Add(model.outputW.grad);
gradList.Add(model.outputB.grad);
for (int i = 0; i < model.hDepth; i++) {
paraList.Add(&model.hiddenW[i].grad);
paraList.Add(&model.hiddenB[i].grad);
gradList.Add(model.hiddenW[i].grad);
gradList.Add(model.hiddenB[i].grad);
}
paraList.Add(&model.embeddingW.grad);
gradList.Add(model.embeddingW.grad);
}
for (int i = 0; i < paraList.count; i++) {
......
......@@ -208,22 +208,16 @@ void XList::Insert(int pos, void * item)
/*
get the item at position i
>> i - index of the item, expected to satisfy 0 <= i < count
<< return - the pointer stored at position i
*/
void * XList::GetItem(int i) const
{
    /* an out-of-range index is reported through CheckNTErrors rather
       than being hidden behind a NULL result */
    CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");

    return items[i];
}
/*
get the integer-typed item at position i
>> i - index of the item, expected to satisfy 0 <= i < count
<< return - the integer that the stored pointer at position i refers to
*/
int XList::GetItemInt(int i)
{
    CheckNTErrors(isIntList, "An int list is required!");

    /* an out-of-range index is reported through CheckNTErrors rather
       than being hidden behind a 0 result */
    CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");

    return *(int*)(items[i]);
}
/* set the item at position i */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论