Commit de3aeee1 by xiaotong

new code for split

parent ea3f2a21
...@@ -43,6 +43,10 @@ void XShapeGrad::MakeGrad(XTensor * node) ...@@ -43,6 +43,10 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMergeList(node); GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE) else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node); GradUnsqueeze(node);
else if(operID == SHAPE_SPLIT)
GradSplit(node);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node);
else{ else{
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
} }
...@@ -234,7 +238,7 @@ we have ...@@ -234,7 +238,7 @@ we have
dE/da = merge(dE/dc) dE/da = merge(dE/dc)
>> node - the node (c) for backward computation >> node - the node (c) for backward computation
*/ */
void GradSplit(XTensor * node) void XShapeGrad::GradSplit(XTensor * node)
{ {
XLink &income = node->income; XLink &income = node->income;
XTensor * input = income.tails[0]; XTensor * input = income.tails[0];
...@@ -256,17 +260,10 @@ void GradSplit(XTensor * node) ...@@ -256,17 +260,10 @@ void GradSplit(XTensor * node)
/* if the tensor is used somewhere else, we need another SUM /* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */ for gradient accumulation */
else{ else{
int * dims = new int[input->order]; XTensor inputGradTMP(input);
memcpy(dims, input->dimSize, sizeof(int) * input->order);
dims[0] = -dims[0];
XTensor inputGradTMP(input->order, dims,
input->dataType, input->denseRatio,
input->devID, input->mem);
_Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0); _Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, &inputGradTMP, input->grad); _Sum(input->grad, &inputGradTMP, input->grad);
delete[] dims;
} }
node->visitMark = NODE_FINISHED; node->visitMark = NODE_FINISHED;
...@@ -293,7 +290,7 @@ void XShapeGrad::GradSplitList(XTensor * node) ...@@ -293,7 +290,7 @@ void XShapeGrad::GradSplitList(XTensor * node)
} }
/* /*
gradient computation for splitting where we return gradient computation for splitting. We return
the list of the splits : list(c_1, ...) = split(a). the list of the splits : list(c_1, ...) = split(a).
This method is called only when all nodes of splitting This method is called only when all nodes of splitting
have been processed. We do this in a post-processing have been processed. We do this in a post-processing
...@@ -303,6 +300,46 @@ one time. This is good for system speed up. ...@@ -303,6 +300,46 @@ one time. This is good for system speed up.
*/ */
void XShapeGrad::GradSplitListPost(XTensor * node) void XShapeGrad::GradSplitListPost(XTensor * node)
{ {
/* we compute the gradient for current node, rather than for
child node, i.e., we use the outgoing edge here */
XLink &outgo = node->outgo;
XList splits(outgo.tailNum);
int whereToSplit = -1;
int splitNum = 0;
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = (XTensor*)outgo.tails[i];
XLink &income = parent->income;
if(income.typeID == SHAPE_SPLIT_LIST){
int w = income.GetParamInt(0);
int splitID = income.GetParamInt(1);
if(whereToSplit < 0)
whereToSplit = w;
splitNum++;
CheckNTErrors(whereToSplit == w, "Wrong dimension for spliting");
CheckNTErrors(income.tailNum == 1, "Something wrong with outgoing edge!");
CheckNTErrors(splitNum - 1 == splitID, "Wrong split id!");
splits.Add(parent);
}
}
/* we can simply merge the gradient tensor
if the node is used in spliting only */
if(outgo.tailNum == splitNum){
_Merge(&splits, node->grad, whereToSplit + 1);
}
/* if the tensor is used as input to other nodes
somewhere else, we need another SUM for gradient
accumulation */
else{
XTensor nodeGradTMP(node);
_Merge(&splits, &nodeGradTMP, whereToSplit + 1);
_Sum(node->grad, &nodeGradTMP, node->grad);
}
} }
/* /*
......
...@@ -57,11 +57,11 @@ private: ...@@ -57,11 +57,11 @@ private:
static static
void GradSplit(XTensor * node); void GradSplit(XTensor * node);
/* gradient computation for splitting where we return the list of the splits : list(c_1, ...) = split(a) */ /* gradient computation for splitting. We return the list of the splits : list(c_1, ...) = split(a) */
static static
void GradSplitList(XTensor * node); void GradSplitList(XTensor * node);
/* gradient computation for splitting where we return the list of the splits : list(c_1, ...) = split(a). /* gradient computation for splitting. We return the list of the splits : list(c_1, ...) = split(a).
This method is called only when all nodes of splitting have been processed. We do this in a post-processing This method is called only when all nodes of splitting have been processed. We do this in a post-processing
manner because we can fuse multiple memory copy jobs into one. This is good for system speed up. */ manner because we can fuse multiple memory copy jobs into one. This is good for system speed up. */
static static
......
...@@ -492,15 +492,15 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad) ...@@ -492,15 +492,15 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
gradList.Add(&grad.embeddingW); gradList.Add(&grad.embeddingW);
} }
else{ else{
paraList.Add(model.outputW.grad); gradList.Add(model.outputW.grad);
paraList.Add(&model.outputB.grad); gradList.Add(model.outputB.grad);
for (int i = 0; i < model.hDepth; i++) { for (int i = 0; i < model.hDepth; i++) {
paraList.Add(&model.hiddenW[i].grad); gradList.Add(model.hiddenW[i].grad);
paraList.Add(&model.hiddenB[i].grad); gradList.Add(model.hiddenB[i].grad);
} }
paraList.Add(&model.embeddingW.grad); gradList.Add(model.embeddingW.grad);
} }
for (int i = 0; i < paraList.count; i++) { for (int i = 0; i < paraList.count; i++) {
......
...@@ -208,22 +208,16 @@ void XList::Insert(int pos, void * item) ...@@ -208,22 +208,16 @@ void XList::Insert(int pos, void * item)
/* get the item at position i */ /* get the item at position i */
void * XList::GetItem(int i) const void * XList::GetItem(int i) const
{ {
if( i >= 0 && i < count ) CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
return items[i]; return items[i];
else
return NULL;
} }
/* get the integer-typed item at position i */ /* get the integer-typed item at position i */
int XList::GetItemInt(int i) int XList::GetItemInt(int i)
{ {
CheckNTErrors(isIntList, "An int list is required!"); CheckNTErrors(isIntList, "An int list is required!");
CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
if( i >= 0 && i < count ){ return *(int*)(items[i]);
return *(int*)(items[i]);
}
else
return 0;
} }
/* set the item at position i */ /* set the item at position i */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论