Commit 18a08a65 by xuchen

optimize the xbackward implementation to support efficient propagation and gradient accumulation

parent 0e585782
@@ -40,30 +40,39 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
     XTensor * input = income.tails[0];
     XTensor * output = node;
 
-    XNoder::MakeGrad(input);
-
-    if(operID == FUNC_HARDTANH)
-        _HardTanHBackward(output, input, output->grad, input->grad);
-    else if(operID == FUNC_IDENTITY)
-        _IdentityBackward(output, input, output->grad, input->grad);
-    else if(operID == FUNC_LOGSOFTMAX){
-        int leadDim = income.GetParamInt(0);
-        CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
-        _LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
-    }
-    else if(operID == FUNC_RECTIFY)
-        _RectifyBackward(output, input, output->grad, input->grad);
-    else if(operID == FUNC_SIGMOID)
-        _SigmoidBackward(output, input, output->grad, input->grad);
-    else if(operID == FUNC_SOFTMAX){
-        int leadDim = income.GetParamInt(0);
-        CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
-        _SoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
-    }
-    else{
-        ShowNTErrors("Wrong activation function type!");
-    }
+    if (!isEfficient || input->isGrad) {
+        XNoder::MakeGrad(input);
+
+        XTensor * dedx = input->grad;
+        XTensor * dedy = output->grad;
+        XTensor * tmp = NewTensorBufV2(output, output->devID, output->mem);
+
+        if (operID == FUNC_HARDTANH)
+            _HardTanHBackward(output, input, dedy, tmp);
+        else if (operID == FUNC_IDENTITY)
+            _IdentityBackward(output, input, dedy, tmp);
+        else if (operID == FUNC_LOGSOFTMAX) {
+            int leadDim = income.GetParamInt(0);
+            CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
+            _LogSoftmaxBackward(NULL, output, input, dedy, tmp, NULL, leadDim, NOLOSS);
+        }
+        else if (operID == FUNC_RECTIFY)
+            _RectifyBackward(output, input, dedy, tmp);
+        else if (operID == FUNC_SIGMOID)
+            _SigmoidBackward(output, input, dedy, tmp);
+        else if (operID == FUNC_SOFTMAX) {
+            int leadDim = income.GetParamInt(0);
+            CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
+            _SoftmaxBackward(NULL, output, input, dedy, tmp, NULL, leadDim, NOLOSS);
+        }
+        else {
+            ShowNTErrors("Wrong activation function type!");
+        }
+
+        _SumMe(dedx, tmp);
+        DelTensorBuf(tmp);
+    }
 
     node->visitMark = NODE_FINISHED;
 }
......
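What the hunk above changes: each activation's backward kernel now writes dE/dx into a scratch buffer (tmp), and the result is added onto the existing gradient with _SumMe(dedx, tmp) instead of overwriting input->grad, so a node feeding several operations collects the sum of all incoming gradients. A minimal, library-free C++ sketch of that accumulation pattern (the names below are illustrative, not the NiuTrans API):

#include <cstddef>
#include <vector>

// Add one consumer's local gradient onto the node's accumulated gradient
// instead of overwriting it, so a node that feeds several operations ends
// up with the sum of all incoming gradients.
void AccumulateGrad(std::vector<float> &dedx,                // accumulated dE/dx
                    const std::vector<float> &localGrad)     // contribution of one consumer
{
    for (std::size_t i = 0; i < dedx.size(); ++i)
        dedx[i] += localGrad[i];
}

// usage sketch:
//   std::vector<float> tmp(x.size());   // scratch buffer (plays the role of NewTensorBufV2)
//   SomeOpBackward(y, x, dedy, tmp);    // hypothetical backward kernel writes into tmp
//   AccumulateGrad(dedx, tmp);          // plays the role of _SumMe(dedx, tmp)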
@@ -48,15 +48,16 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
     XTensor * padding = NULL;
     int leadingDim;
 
+    if (!isEfficient || output->isGrad) {
         XNoder::MakeGrad(output);
         XTensor * dedy = output->grad;
 
         if (income.tailNum == 1) {
-            if(dedy->dataType == X_FLOAT)
+            if (dedy->dataType == X_FLOAT)
                 _SetDataFixedFloat(dedy, 1.0F);
-            else if(dedy->dataType == X_DOUBLE)
+            else if (dedy->dataType == X_DOUBLE)
                 _SetDataFixedDouble(dedy, 1.0);
-            else if(dedy->dataType == X_INT)
+            else if (dedy->dataType == X_INT)
                 _SetDataFixedInt(dedy, 1);
             else
                 ShowNTErrors("TODO");
@@ -66,16 +67,20 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
         gold = income.tails[1];
 
-        if(operID == LOSS_CROSSENTROPY) {
+        XTensor* tmp = NewTensorBufV2(output, output->devID, output->mem);
+
+        if (operID == LOSS_CROSSENTROPY) {
             if (income.tailNum == 3)
                 padding = income.tails[2];
             leadingDim = income.GetParamInt(0);
             CheckNTErrors(leadingDim >= 0 && leadingDim < output->order, "wrong leading dimension in logsoftmax!");
-            _CrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
+            _CrossEntropyBackward(tmp, output, gold, weight, padding, leadingDim);
+            _SumMe(dedy, tmp);
         }
-        else{
+        else {
             ShowNTErrors("Wrong activation function type!");
         }
+
+        DelTensorBuf(tmp);
+    }
 
     node->visitMark = NODE_FINISHED;
 }
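The loss hunks above apply the same two ideas: the work is guarded by the efficiency flag, so gradients are only built for nodes that actually need them, and _CrossEntropyBackward now writes into a temporary buffer that is summed into dedy. A rough sketch of the guard, with made-up names (not the NiuTrans API):

#include <cstddef>
#include <vector>

struct Node {
    bool isGrad = false;          // does this node need a gradient at all?
    std::vector<float> grad;      // lazily created gradient accumulator
};

// In efficient mode, only nodes flagged with isGrad get a gradient buffer
// and backward computation; otherwise every node is processed.
void MakeGradIfNeeded(Node &node, bool isEfficient, std::size_t size)
{
    if (!isEfficient || node.isGrad) {
        if (node.grad.empty())
            node.grad.assign(size, 0.0f);   // start from a zero accumulator
        // ... compute the local gradient into a scratch buffer,
        //     then add it onto node.grad ...
    }
}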
@@ -87,79 +92,4 @@ bool XLossGrad::IsLossOP(XTensor * node)
     return (income.typeID & LOSS_BASE) != 0;
 }
 
-/*
-compute dE/dx for a given function y = f(x)
->> gold - gold standard to measure error (or loss)
->> y - output of the function
->> x - input of the function
->> dedy - dE/dy
->> dedx - dE/dx
->> funcID - id of the function f
->> params - parameters of the function
->> lossName - name of the loss, e.g., cross entropy
-*/
-//void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
-//                        XTensor * dedy, XTensor * dedx, XTensor * padding,
-//                        int funcID, void * params,
-//                        LOSS_FUNCTION_NAME lossName)
-//{
-//    CheckNTErrors(gold && y && x, "Empty input tensors!");
-//    CheckNTErrors(dedx, "Empty gradient tensors!");
-//    CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
-//
-//    if(funcID == FUNC_HARDTANH){
-//        _HardTanHBackward(gold, y, x, dedy, dedx, lossName);
-//    }
-//    else if(funcID == FUNC_IDENTITY){
-//        _IdentityBackward(gold, y, x, dedy, dedx, lossName);
-//    }
-//    else if(funcID == FUNC_LOGSOFTMAX){
-//        int leadDim = *(int*)params;
-//        _LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
-//    }
-//    else if(funcID == FUNC_RECTIFY){
-//        _RectifyBackward(gold, y, x, dedy, dedx, lossName);
-//    }
-//    else if(funcID == FUNC_SIGMOID){
-//        _SigmoidBackward(gold, y, x, dedy, dedx, lossName);
-//    }
-//    else if(funcID == FUNC_SOFTMAX){
-//        int leadDim = *(int*)params;
-//        _SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
-//    }
-//    else{
-//        ShowNTErrors("wrong function found when call the backward process!");
-//    }
-//
-//}
-
-/*
-compute dE/dy for variable y and error(loss) function E
->> gold - gold standard to measure error (or loss)
->> y - output of the function
->> dedy - dE/dy
->> lossName - name of the loss, e.g., cross entropy
-*/
-//void XLossGrad::Compute(XTensor * gold, XTensor * y,
-//                        XTensor * dedy, XTensor * padding,
-//                        LOSS_FUNCTION_NAME lossName)
-//{
-//    if(gold == NULL){
-//        if(dedy->dataType == X_FLOAT)
-//            _SetDataFixedFloat(dedy, 1.0F);
-//        else if(dedy->dataType == X_DOUBLE)
-//            _SetDataFixedDouble(dedy, 1.0);
-//        else if(dedy->dataType == X_INT)
-//            _SetDataFixedInt(dedy, 1);
-//        else{
-//            ShowNTErrors("TODO");
-//        }
-//        return;
-//    }
-//
-//    //_LossBackward(dedy, gold, y, lossName);
-//    if(lossName == CROSSENTROPY)
-//        _CrossEntropyBackward(dedy, y, gold, NULL, padding);
-//
-//}
 
 }
\ No newline at end of file
@@ -316,7 +316,6 @@ void XNet::ClearGrad(XTensor * node)
     }
 
     if(finished){
-        //fprintf(stderr, "del %d %ld\n", node->id, node->grad->unitNum);
         delete node->grad;
         node->grad = NULL;
     }
......
@@ -62,7 +62,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
     /* we transform a higher order tensor to a matrix to kill the number
        of calls of matrix multiplication */
-    if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
+    if (transposedA == X_NOTRANS && a->order > 2 && b->order == 2) {
         int ncolA = a->dimSize[a->order - 1];
         int ncolC = c->dimSize[c->order - 1];
         XTensor * a2 = NewTensor2DV2(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
......
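Context for the hunk above: when a has order > 2, b is a plain matrix, and a is not transposed, a is viewed as a two-dimensional matrix with ncolA columns and unitNum/ncolA rows, so a single matrix multiplication replaces a loop over the leading slices. A small illustrative sketch of that shape collapse (row-major layout assumed; not the library's API):

#include <cstdint>

// Collapse all leading dimensions of a row-major order-n tensor into one
// row dimension: shape (d0, d1, ..., d_{n-2}, ncol) is viewed as a matrix
// with (d0 * d1 * ... * d_{n-2}) rows and ncol columns.
void CollapseToMatrix(const int64_t *dimSize, int order, int64_t unitNum,
                      int64_t &rows, int64_t &cols)
{
    cols = dimSize[order - 1];   // the innermost dimension (ncolA in the code above)
    rows = unitNum / cols;       // product of all remaining dimensions
}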
@@ -199,8 +199,8 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c)
 }
 
 #ifdef USE_CUDA
-_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, max)
-_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, min)
+_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, MAX)
+_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, MIN)
 #else
 _SIMPLE_MAX_MIN_FUNCTION(_Max, max)
 _SIMPLE_MAX_MIN_FUNCTION(_Min, min)
......
 /* NiuTrans.Tensor - an open-source tensor library
  * Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
  * All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *   http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 /*
  * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
  */
 
 #ifndef __SPLIT_H__
 #define __SPLIT_H__
......
@@ -85,7 +85,7 @@ XTensor Stack(const TensorList &smalls, int dim)
 {
     int count = smalls.count;
     CheckNTErrors(count > 0, "Empty list!");
-    CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
+    CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
 
     XTensor * tensor = smalls.GetItem(0);
     int order = tensor->order + 1;
@@ -95,7 +95,7 @@ XTensor Stack(const TensorList &smalls, int dim)
         if (i < dim)
             dimSize[i] = tensor->GetDim(i);
         else if (i > dim)
-            dimSize[i] = tensor->GetDim(i);
+            dimSize[i] = tensor->GetDim(i-1);
         else if (i == dim)
             dimSize[i] = count;
     }
@@ -149,7 +149,7 @@ void Stack(const TensorList &smalls, XTensor &t, int dim)
 {
     int count = smalls.count;
     CheckNTErrors(count > 0, "Empty list!");
-    CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
+    CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
 
     if (!t.isInit || !CheckStackShape(smalls, t, dim)) {
         XTensor * tensor = smalls.GetItem(0);
......
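The substantive fix in the Stack hunks above is the shape rule: for dimensions after the stacking axis, the size must be read from position i-1 of the small tensors, because inserting the new axis shifts every following dimension by one. A standalone sketch of the corrected rule (illustrative names, not the library's API):

#include <vector>

// Output shape when stacking `count` tensors of shape `small` along axis `dim`:
// axes before `dim` are copied, `dim` itself becomes `count`, and axes after
// `dim` come from position i-1 of the small shape.
std::vector<int> StackShape(const std::vector<int> &small, int count, int dim)
{
    std::vector<int> dimSize(small.size() + 1);
    for (int i = 0; i < (int)dimSize.size(); i++) {
        if (i < dim)
            dimSize[i] = small[i];
        else if (i == dim)
            dimSize[i] = count;
        else
            dimSize[i] = small[i - 1];
    }
    return dimSize;
}

// e.g. stacking 3 tensors of shape (2, 4) along dim 1 gives (2, 3, 4).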