Commit 0ec51854 by xiaotong

improve the way of using buf tensors

parent 0e1074ff
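
The change swaps per-call temporaries (a stack XTensor, or a NewTensor/delete pair) for scratch tensors borrowed from the memory pool's buffer via NewTensorBuf(tensor, devID, mem) and DelTensorBuf(tensor), so the backward pass no longer hits the allocator for every node it visits. Below is a minimal self-contained sketch of that pattern, not NiuTensor's actual implementation: the Mem and Tensor types and the bump-allocator behavior are assumptions for illustration, and only the NewTensorBuf/DelTensorBuf names and call sites follow the diff.

    // Toy illustration of the buf-tensor pattern (not NiuTensor's actual code):
    // backward-pass temporaries borrow scratch memory from a preallocated pool
    // buffer instead of hitting the system allocator on every call.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct Mem {                                  // stand-in for NiuTensor's XMem
        std::vector<char> buf;                    // one big preallocated buffer
        std::size_t used = 0;
        explicit Mem(std::size_t bytes) : buf(bytes) {}
        void * AllocBuf(std::size_t bytes) {      // bump-allocate scratch space
            assert(used + bytes <= buf.size());
            void * p = buf.data() + used;
            used += bytes;
            return p;
        }
        void ReleaseBuf(std::size_t bytes) {      // roll the bump pointer back
            assert(bytes <= used);
            used -= bytes;
        }
    };

    struct Tensor {                               // stand-in for XTensor
        float * data;
        std::size_t num;
        Mem * mem;
    };

    // Analogue of NewTensorBuf(tensor, devID, mem); devID is dropped here for
    // brevity. The tensor's data lives in the pool buffer, not on the heap.
    Tensor * NewTensorBuf(const Tensor * shapeLike, Mem * mem) {
        Tensor * t = new Tensor;
        t->num = shapeLike->num;
        t->mem = mem;
        t->data = static_cast<float *>(mem->AllocBuf(t->num * sizeof(float)));
        return t;
    }

    // Analogue of DelTensorBuf: returning the scratch space is O(1); only the
    // small Tensor header itself was heap-allocated.
    void DelTensorBuf(Tensor * t) {
        t->mem->ReleaseBuf(t->num * sizeof(float));
        delete t;
    }

    int main() {
        Mem mem(1 << 20);
        Tensor grad = { nullptr, 256, nullptr };

        // the pattern used in every hunk below: borrow, compute, release
        Tensor * tmp = NewTensorBuf(&grad, &mem);
        // ... e.g. _Merge or _ReduceSum into tmp, then _Sum into the real grad ...
        DelTensorBuf(tmp);
        return 0;
    }
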
@@ -268,10 +268,12 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
     /* if the tensor is used somewhere else, we need another SUM
        for gradient accumulation */
     else{
-        XTensor inputGradTMP(input);
+        XTensor * inputGradTMP = NewTensorBuf(input, input->devID, input->mem);
 
-        _Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
-        _Sum(input->grad, &inputGradTMP, input->grad);
+        _Merge(node->grad, inputGradTMP, whereToSplit + 1, 0);
+        _Sum(input->grad, inputGradTMP, input->grad);
+
+        DelTensorBuf(inputGradTMP);
     }
 
     node->visitMark = NODE_FINISHED;
@@ -347,10 +349,12 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
        somewhere else, we need another SUM for gradient
        accumulation */
     else{
-        XTensor nodeGradTMP(node);
+        XTensor * nodeGradTMP = NewTensorBuf(node, node->devID, node->mem);
 
-        _Merge(&splits, &nodeGradTMP, whereToSplit + 1);
-        _Sum(node->grad, &nodeGradTMP, node->grad);
+        _Merge(&splits, nodeGradTMP, whereToSplit + 1);
+        _Sum(node->grad, nodeGradTMP, node->grad);
+
+        DelTensorBuf(nodeGradTMP);
     }
 }
@@ -378,8 +382,13 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
     CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
     CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
 
-    _ReduceSum(output->grad, input->grad, dim);
+    XTensor * g = NewTensorBuf(input->grad, input->devID, input->mem);
+
+    _ReduceSum(output->grad, g, dim);
+    _Sum(input->grad, g, input->grad);
+
+    DelTensorBuf(g);
 
     node->visitMark = NODE_FINISHED;
 }
@@ -401,7 +410,7 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
     XTensor * output = node;
     XTensor * input = income.tails[0];
-    XTensor * b = NewTensor(input);
+    XTensor * b = NewTensorBuf(input, input->devID, input->mem);
     XNoder::MakeGrad(input);
 
     int i = income.GetParamInt(0);
@@ -412,10 +421,12 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
     _Transpose(output->grad, b, i, j);
     _Sum(input->grad, b, input->grad);
 
+    DelTensorBuf(b);
+
     node->visitMark = NODE_FINISHED;
 
     delete b;
 }
 }
\ No newline at end of file
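
A few things are worth noting beyond the allocation change. The GradUnsqueeze hunk also changes semantics: the old code reduced output->grad straight into input->grad, clobbering any gradient already accumulated there, while the new code reduces into the buf tensor g and then _Sum-accumulates into input->grad. If the pool buffer is bump-allocated (an assumption of the sketch above), each DelTensorBuf must mirror its NewTensorBuf in reverse order, which every hunk here trivially satisfies with a single temporary. Two apparent leftovers remain: GradTranspose still ends with delete b; after DelTensorBuf(b), which reads like a double free inherited from the old NewTensor path, and the second CheckNTErrors in GradUnsqueeze uses = (assignment) where == (comparison) was presumably intended.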