Commit 0ec51854 by xiaotong

improve the way of using buf tensors

parent 0e1074ff
......@@ -268,10 +268,12 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
/* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */
else{
XTensor inputGradTMP(input);
XTensor * inputGradTMP = NewTensorBuf(input, input->devID, input->mem);
_Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, &inputGradTMP, input->grad);
_Merge(node->grad, inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, inputGradTMP, input->grad);
DelTensorBuf(inputGradTMP);
}
node->visitMark = NODE_FINISHED;
......@@ -347,10 +349,12 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
somewhere else, we need another SUM for gradient
accumulation */
else{
XTensor nodeGradTMP(node);
XTensor * nodeGradTMP = NewTensorBuf(node, node->devID, node->mem);
_Merge(&splits, nodeGradTMP, whereToSplit + 1);
_Sum(node->grad, nodeGradTMP, node->grad);
_Merge(&splits, &nodeGradTMP, whereToSplit + 1);
_Sum(node->grad, &nodeGradTMP, node->grad);
DelTensorBuf(nodeGradTMP);
}
}
......@@ -379,7 +383,12 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
_ReduceSum(output->grad, input->grad, dim);
XTensor * g = NewTensorBuf(input->grad, input->devID, input->mem);
_ReduceSum(output->grad, g, dim);
_Sum(input->grad, g, input->grad);
DelTensorBuf(g);
node->visitMark = NODE_FINISHED;
}
......@@ -401,7 +410,7 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensor(input);
XTensor * b = NewTensorBuf(input, input->devID, input->mem);
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
......@@ -413,6 +422,8 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
DelTensorBuf(b);
node->visitMark = NODE_FINISHED;
delete b;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论