Commit 18a08a65 by xuchen

optimize the XBackward implementation to support efficient propagation and gradient accumulation

parent 0e585782
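The commit applies one pattern throughout the backward functions: instead of writing dE/dx straight into input->grad, which overwrites the gradient when a tensor feeds more than one operation, each backward step now computes its contribution into a temporary buffer, accumulates it into the existing gradient, and releases the buffer, all guarded by the isEfficient/isGrad check. A minimal sketch of that pattern (not part of the diff), using the buffer helpers that appear below; MakeGradSketch and _SomeBackward are hypothetical placeholders for the concrete backward routines:

/* sketch of the accumulation pattern used throughout this commit */
void MakeGradSketch(XTensor * node, bool isEfficient)
{
    XTensor * input = node->income.tails[0];

    /* skip the work entirely if no gradient is requested for this input */
    if (!isEfficient || input->isGrad) {
        /* allocate input->grad if it does not exist yet */
        XNoder::MakeGrad(input);

        /* compute dE/dx into a pooled buffer rather than into input->grad */
        XTensor * tmp = NewTensorBufV2(input, input->devID, input->mem);
        _SomeBackward(node, input, node->grad, tmp);

        /* accumulate instead of overwrite, then return the buffer to the pool */
        _SumMe(input->grad, tmp);
        DelTensorBuf(tmp);
    }

    node->visitMark = NODE_FINISHED;
}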
......@@ -40,30 +40,39 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
XTensor * input = income.tails[0];
XTensor * output = node;
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
if(operID == FUNC_HARDTANH)
_HardTanHBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_IDENTITY)
_IdentityBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_LOGSOFTMAX){
XTensor * dedx = input->grad;
XTensor * dedy = output->grad;
XTensor * tmp = NewTensorBufV2(output, output->devID, output->mem);
if (operID == FUNC_HARDTANH)
_HardTanHBackward(output, input, dedy, tmp);
else if (operID == FUNC_IDENTITY)
_IdentityBackward(output, input, dedy, tmp);
else if (operID == FUNC_LOGSOFTMAX) {
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
_LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
_LogSoftmaxBackward(NULL, output, input, dedy, tmp, NULL, leadDim, NOLOSS);
}
else if(operID == FUNC_RECTIFY)
_RectifyBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_SIGMOID)
_SigmoidBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_SOFTMAX){
else if (operID == FUNC_RECTIFY)
_RectifyBackward(output, input, dedy, tmp);
else if (operID == FUNC_SIGMOID)
_SigmoidBackward(output, input, dedy, tmp);
else if (operID == FUNC_SOFTMAX) {
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
_SoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
_SoftmaxBackward(NULL, output, input, dedy, tmp, NULL, leadDim, NOLOSS);
}
else{
else {
ShowNTErrors("Wrong activation function type!");
}
_SumMe(dedx, tmp);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......
......@@ -48,15 +48,16 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
XTensor * padding = NULL;
int leadingDim;
if (!isEfficient || output->isGrad) {
XNoder::MakeGrad(output);
XTensor * dedy = output->grad;
if (income.tailNum == 1) {
if(dedy->dataType == X_FLOAT)
if (dedy->dataType == X_FLOAT)
_SetDataFixedFloat(dedy, 1.0F);
else if(dedy->dataType == X_DOUBLE)
else if (dedy->dataType == X_DOUBLE)
_SetDataFixedDouble(dedy, 1.0);
else if(dedy->dataType == X_INT)
else if (dedy->dataType == X_INT)
_SetDataFixedInt(dedy, 1);
else
ShowNTErrors("TODO");
......@@ -66,16 +67,20 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
gold = income.tails[1];
if(operID == LOSS_CROSSENTROPY) {
XTensor* tmp = NewTensorBufV2(output, output->devID, output->mem);
if (operID == LOSS_CROSSENTROPY) {
if (income.tailNum == 3)
padding = income.tails[2];
leadingDim = income.GetParamInt(0);
CheckNTErrors(leadingDim >= 0 && leadingDim < output->order, "wrong leading dimension in cross entropy!");
_CrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
_CrossEntropyBackward(tmp, output, gold, weight, padding, leadingDim);
_SumMe(dedy, tmp);
}
else{
else {
ShowNTErrors("Wrong loss function type!");
}
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -87,79 +92,4 @@ bool XLossGrad::IsLossOP(XTensor * node)
return (income.typeID & LOSS_BASE) != 0;
}
/*
compute dE/dx for a given function y = f(x)
>> gold - gold standard to measure error (or loss)
>> y - output of the function
>> x - input of the function
>> dedy - dE/dy
>> dedx - dE/dx
>> funcID - id of the function f
>> params - parameters of the function
>> lossName - name of the loss, e.g., cross entropy
*/
//void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
// XTensor * dedy, XTensor * dedx, XTensor * padding,
// int funcID, void * params,
// LOSS_FUNCTION_NAME lossName)
//{
// CheckNTErrors(gold && y && x, "Empty input tensors!");
// CheckNTErrors(dedx, "Empty gradient tensors!");
// CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
//
// if(funcID == FUNC_HARDTANH){
// _HardTanHBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_IDENTITY){
// _IdentityBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_LOGSOFTMAX){
// int leadDim = *(int*)params;
// _LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
// }
// else if(funcID == FUNC_RECTIFY){
// _RectifyBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_SIGMOID){
// _SigmoidBackward(gold, y, x, dedy, dedx, lossName);
// }else if(funcID == FUNC_SOFTMAX){
// int leadDim = *(int*)params;
// _SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
// }
// else{
// ShowNTErrors("wrong function found when call the backward process!");
// }
//
//}
/*
compute dE/dy for variable y and error(loss) function E
>> gold - gold standard to measure error (or loss)
>> y - output of the function
>> dedy - dE/dy
>> lossName - name of the loss, e.g., cross entropy
*/
//void XLossGrad::Compute(XTensor * gold, XTensor * y,
// XTensor * dedy, XTensor * padding,
// LOSS_FUNCTION_NAME lossName)
//{
// if(gold == NULL){
// if(dedy->dataType == X_FLOAT)
// _SetDataFixedFloat(dedy, 1.0F);
// else if(dedy->dataType == X_DOUBLE)
// _SetDataFixedDouble(dedy, 1.0);
// else if(dedy->dataType == X_INT)
// _SetDataFixedInt(dedy, 1);
// else{
// ShowNTErrors("TODO");
// }
// return;
// }
//
// //_LossBackward(dedy, gold, y, lossName);
// if(lossName == CROSSENTROPY)
// _CrossEntropyBackward(dedy, y, gold, NULL, padding);
//
//}
}
\ No newline at end of file
......@@ -30,80 +30,80 @@ namespace nts{
/* compute dE/dx of a node */
void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
{
if(!isEfficient){
if (!isEfficient) {
CheckNTErrors(node->grad != NULL, "No gradient found!");
}
else{
else {
CheckNTErrors(!node->isGrad || node->grad != NULL, "No gradient found!");
}
XLink &income = node->income;
int operID = income.typeID;
if(operID == MATH_ABSOLUTE)
if (operID == MATH_ABSOLUTE)
GradAbsolute(node, isEfficient);
else if(operID == MATH_COS)
else if (operID == MATH_COS)
GradCos(node, isEfficient);
else if(operID == MATH_EXP)
else if (operID == MATH_EXP)
GradExp(node, isEfficient);
else if(operID == MATH_LOG)
else if (operID == MATH_LOG)
GradLog(node, isEfficient);
else if(operID == MATH_ROUND)
else if (operID == MATH_ROUND)
GradRound(node, isEfficient);
else if(operID == MATH_SIGN)
else if (operID == MATH_SIGN)
GradSign(node, isEfficient);
else if(operID == MATH_SIN)
else if (operID == MATH_SIN)
GradSin(node, isEfficient);
else if(operID == MATH_TAN)
else if (operID == MATH_TAN)
GradTan(node, isEfficient);
else if(operID == MATH_CLIP)
else if (operID == MATH_CLIP)
GradClip(node, isEfficient);
else if(operID == MATH_DIV)
else if (operID == MATH_DIV)
GradDiv(node, isEfficient);
else if(operID == MATH_DIVDIM)
else if (operID == MATH_DIVDIM)
GradDivDim(node, isEfficient);
else if(operID == MATH_MATRIXMUL)
else if (operID == MATH_MATRIXMUL)
GradMatrixMul(node, isEfficient);
else if(operID == MATH_MATRIXMULBATCHED)
else if (operID == MATH_MATRIXMULBATCHED)
GradMatrixMulBatched(node, isEfficient);
else if(operID == MATH_MULTIPLY)
else if (operID == MATH_MULTIPLY)
GradMultiply(node, isEfficient);
else if(operID == MATH_MULTIPLYDIM)
else if (operID == MATH_MULTIPLYDIM)
GradMultiplyDim(node, isEfficient);
else if (operID == MATH_MULTIPLYBROADCAST)
GradMultiplyBroadcast(node, isEfficient);
else if(operID == MATH_NEGATE)
else if (operID == MATH_NEGATE)
GradNegate(node, isEfficient);
else if(operID == MATH_NORMALIZE)
else if (operID == MATH_NORMALIZE)
GradNormalize(node, isEfficient);
else if(operID == MATH_POWER)
else if (operID == MATH_POWER)
GradPower(node, isEfficient);
else if(operID == MATH_SCALEANDSHIFT)
else if (operID == MATH_SCALEANDSHIFT)
GradScaleAndShift(node, isEfficient);
else if(operID == MATH_SCALE)
else if (operID == MATH_SCALE)
GradScale(node, isEfficient);
else if(operID == MATH_DESCALE)
else if (operID == MATH_DESCALE)
GradDescale(node, isEfficient);
else if(operID == MATH_SHIFT)
else if (operID == MATH_SHIFT)
GradShift(node, isEfficient);
else if(operID == MATH_SUB)
else if (operID == MATH_SUB)
GradSub(node, isEfficient);
else if(operID == MATH_SUBDIM)
else if (operID == MATH_SUBDIM)
GradSubDim(node, isEfficient);
else if(operID == MATH_SUM)
else if (operID == MATH_SUM)
GradSum(node, isEfficient);
else if(operID == MATH_SUMDIM)
else if (operID == MATH_SUMDIM)
GradSumDim(node, isEfficient);
else if(operID == MATH_SUMBROADCAST)
else if (operID == MATH_SUMBROADCAST)
GradSumBroadcast(node, isEfficient);
else if(operID == REDUCE_REDUCEMEAN)
else if (operID == REDUCE_REDUCEMEAN)
GradReduceMean(node, isEfficient);
else if(operID == REDUCE_REDUCESUM)
else if (operID == REDUCE_REDUCESUM)
GradReduceSum(node, isEfficient);
else if(operID == REDUCE_REDUCESUMSQUARED)
else if (operID == REDUCE_REDUCESUMSQUARED)
GradReduceSumSquared(node, isEfficient);
else if(operID == REDUCE_REDUCEVARIANCE)
else if (operID == REDUCE_REDUCEVARIANCE)
GradReduceVariance(node, isEfficient);
else if (operID == MATH_MULANDSHIFT)
GradMulAndShift(node, isEfficient);
......@@ -136,14 +136,17 @@ void XMathGrad::GradAbsolute(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for ABSOLUTE!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * sign(a) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sign(a, b);
_Multiply(node->grad, b, a->grad, 1.0F);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Sign(a, tmp);
_Multiply(node->grad, tmp, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -164,15 +167,18 @@ void XMathGrad::GradCos(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for COS!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * -sin(a) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sin(a, b);
_ScaleAndShiftMe(b, -1.0F);
_Multiply(node->grad, b, a->grad, 1.0F);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Sin(a, tmp);
_NegateMe(tmp);
_Multiply(node->grad, tmp, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -193,14 +199,17 @@ void XMathGrad::GradExp(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for EXP!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * exp(a) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Exp(a, b);
_Multiply(node->grad, b, a->grad, 1.0F);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Exp(a, tmp);
_Multiply(node->grad, tmp, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -222,9 +231,11 @@ void XMathGrad::GradLog(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
/* dE/da = dE/dc * 1/a */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Div(node->grad, a, a->grad, 1.0F);
}
node->visitMark = NODE_FINISHED;
}
......@@ -244,8 +255,12 @@ void XMathGrad::GradRound(XTensor * node, bool isEfficient)
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for ROUND!");
// we do nothing here
// TODO: set grad = 0 if the node is the only child
XTensor * a = income.tails[0];
/* dE/da = 0, we do nothing here */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
}
node->visitMark = NODE_FINISHED;
}
......@@ -265,8 +280,12 @@ void XMathGrad::GradSign(XTensor * node, bool isEfficient)
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SIGN!");
// we do nothing here
// TODO: set grad = 0 if the node is the only child
XTensor * a = income.tails[0];
/* dE/da = 0, we do nothing here */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
}
node->visitMark = NODE_FINISHED;
}
......@@ -287,14 +306,17 @@ void XMathGrad::GradSin(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SIN!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * cos(a) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Cos(a, b);
_Multiply(node->grad, b, a->grad, 1.0F);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Cos(a, tmp);
_Multiply(node->grad, tmp, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -315,15 +337,18 @@ void XMathGrad::GradTan(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TAN!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * 1/(cos(a))^2
= dE/dc * (cos(a))^-2 */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Cos(a, tmp);
_PowerMe(tmp, -2.0F);
_Multiply(node->grad, tmp, a->grad, 1.0F);
_Cos(a, b);
_PowerMe(b, -2.0F);
_Multiply(node->grad, b, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -343,17 +368,21 @@ void XMathGrad::GradClip(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for CLIP!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
DTYPE lower = income.GetParam(0);
DTYPE upper = income.GetParam(1);
/* dE/da = 1 lower < a < upper
= 0 otherwise */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_ClipBackward(node, a, node->grad, a->grad, lower, upper);
_Sum(a->grad, b, a->grad);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_ClipBackward(node, a, node->grad, tmp, lower, upper);
_SumMe(a->grad, tmp);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -376,21 +405,26 @@ void XMathGrad::GradDiv(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
XTensor * ab2 = NewTensorBufV2(a, a->devID, a->mem);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
CheckNTErrors(_IsSameShaped(a, b), "Wrong sized input tensors!");
/* dE/da = dE/dc / b */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Div(node->grad, b, a->grad, 1.0F);
}
_Power(b, ab2, -2.0F);
_Multiply(a, ab2, ab2);
_ScaleAndShiftMe(ab2, -1.0F);
_Multiply(node->grad, ab2, b->grad, 1.0F);
/* dE/db = dE/dc * a/(-b^2)
= dE/dc * a * (-b^-2) */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Power(b, tmp, -2.0F);
_NegateMe(tmp);
_MultiplyMe(tmp, a);
_Multiply(node->grad, tmp, b->grad, 1.0F);
DelTensorBuf(ab2);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -414,13 +448,17 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* dE/da = dE/dc * (1/b) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_DivDim(node->grad, b, a->grad, n, 1.0);
}
/* dE/db = dE/dc * dc/db */
/* dE/db = dE/dc * dc/db
= (dE/dc * (-a/b^2)).reduce(0,...,n-1,n+1,...) */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
......@@ -436,35 +474,30 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
_Multiply(node->grad, aTMP2, interGradTMP);
if(n == order - 1){
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum/dimSize[order - 1];
reshapedSize[0] = a->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc * a to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
interGradTMP->Reshape(2, reshapedSize);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(interGradTMP, bGradTMP, 0);
_Sum(b->grad, bGradTMP, b->grad);
_SumMe(b->grad, bGradTMP);
DelTensorBuf(bGradTMP);
/*}
else{
_ReduceSum(interGradTMP, b->grad, 0);
}*/
}
else{
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
......@@ -477,17 +510,12 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
XTensor * interGrad = NewTensorBufV2(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);
_ReduceSum(interGradTMP, interGrad, 2);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP2 = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP2, 0);
_Sum(b->grad, bGradTMP2, b->grad);
_SumMe(b->grad, bGradTMP2);
DelTensorBuf(bGradTMP2);
/*}
else{
_ReduceSum(interGrad, b->grad, 0);
}*/
DelTensorBuf(interGrad);
}
......@@ -495,6 +523,7 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
DelTensorBuf(bTMP);
DelTensorBuf(aTMP2);
DelTensorBuf(aTMP1);
}
node->visitMark = NODE_FINISHED;
}
......@@ -521,9 +550,9 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
DTYPE alpha = income.GetParam(2);
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
XNoder::MakeGrad(a);
if(!isEfficient || b->isGrad)
if (!isEfficient || b->isGrad)
XNoder::MakeGrad(b);
XTensor * c = node;
......@@ -531,9 +560,9 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
XTensor * deda = a->grad;
XTensor * dedb = b->grad;
if(a->order == 2 && b->order == 2)
if (a->order == 2 && b->order == 2)
GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha, isEfficient);
else if(transA == X_NOTRANS && a->order > 2 && b->order == 2){
else if (transA == X_NOTRANS && a->order > 2 && b->order == 2){
int orderBackupA = a->order;
int orderBackupC = c->order;
int dimsBackupA[MAX_TENSOR_DIM_NUM];
......@@ -543,7 +572,7 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
a->Reshape(a->unitNum/a->GetDim(-1), a->GetDim(-1));
c->Reshape(c->unitNum/c->GetDim(-1), c->GetDim(-1));
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
deda->Reshape(deda->unitNum/deda->GetDim(-1), deda->GetDim(-1));
dedc->Reshape(dedc->unitNum/dedc->GetDim(-1), dedc->GetDim(-1));
......@@ -551,7 +580,7 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
a->Reshape(orderBackupA, dimsBackupA);
c->Reshape(orderBackupC, dimsBackupC);
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
deda->Reshape(orderBackupA, dimsBackupA);
dedc->Reshape(orderBackupC, dimsBackupC);
}
......@@ -578,54 +607,54 @@ void XMathGrad::GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE tra
XTensor * dedc, DTYPE alpha, bool isEfficient)
{
/* c = a * b * \alpha */
if(transA == X_NOTRANS && transB == X_NOTRANS){
if (transA == X_NOTRANS && transB == X_NOTRANS) {
/* dE/da = dE/dc * b^T * \alpha */
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
_MatrixMul(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a^T * dE/dc * \alpha */
if(!isEfficient || b->isGrad)
if (!isEfficient || b->isGrad)
_MatrixMul(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
else if (transA == X_TRANS && transB == X_NOTRANS){
/* dE/da = (dE/dc * b^T)^T * \alpha
= b * dE/dc^T * \alpha */
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
_MatrixMul(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a * dE/dc * \alpha */
if(!isEfficient || b->isGrad)
if (!isEfficient || b->isGrad)
_MatrixMul(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a * b^T * \alpha */
else if(transA == X_NOTRANS && transB == X_TRANS){
else if (transA == X_NOTRANS && transB == X_TRANS){
/* dE/da = dE/dc * b * \alpha */
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
_MatrixMul(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
/* dE/db = (a^T * dE/dc)^T * \alpha
= dE/dc^T * a * \alpha */
if(!isEfficient || b->isGrad)
if (!isEfficient || b->isGrad)
_MatrixMul(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
else if (transA == X_TRANS && transB == X_TRANS){
/* dE/da = (dE/dc * b)^T * \alpha
= b^T * dE/dc^T * \alpha */
if(!isEfficient || a->isGrad)
if (!isEfficient || a->isGrad)
_MatrixMul(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = (a * dE/dc)^T * \alpha
= dE/dc^T * a^T * \alpha */
if(!isEfficient || b->isGrad)
if (!isEfficient || b->isGrad)
_MatrixMul(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
}
}
......@@ -653,7 +682,9 @@ void XMathGrad::GradMatrixMulBatched(XTensor * node, bool isEfficient)
MATRIX_TRANS_TYPE transB = income.GetParamTrans(1);
DTYPE alpha = income.GetParam(2);
if (!isEfficient || a->isGrad)
XNoder::MakeGrad(a);
if (!isEfficient || b->isGrad)
XNoder::MakeGrad(b);
XTensor * dedc = node->grad;
......@@ -661,46 +692,54 @@ void XMathGrad::GradMatrixMulBatched(XTensor * node, bool isEfficient)
XTensor * dedb = b->grad;
/* c = a * b * \alpha */
if(transA == X_NOTRANS && transB == X_NOTRANS){
if (transA == X_NOTRANS && transB == X_NOTRANS) {
/* dE/da = dE/dc * b^T * \alpha */
if (!isEfficient || a->isGrad)
_MatrixMulBatched(dedc, X_NOTRANS, b, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a^T * dE/dc * \alpha */
if (!isEfficient || b->isGrad)
_MatrixMulBatched(a, X_TRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b * \alpha */
else if(transA == X_TRANS && transB == X_NOTRANS){
else if (transA == X_TRANS && transB == X_NOTRANS) {
/* dE/da = (dE/dc * b^T)^T * \alpha
= b * dE/dc^T * \alpha */
if (!isEfficient || a->isGrad)
_MatrixMulBatched(b, X_NOTRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = a * dE/dc * \alpha */
if (!isEfficient || b->isGrad)
_MatrixMulBatched(a, X_NOTRANS, dedc, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a * b^T * \alpha */
else if(transA == X_NOTRANS && transB == X_TRANS){
else if (transA == X_NOTRANS && transB == X_TRANS) {
/* dE/da = dE/dc * b * \alpha */
if (!isEfficient || a->isGrad)
_MatrixMulBatched(dedc, X_NOTRANS, b, X_NOTRANS, deda, alpha, 1.0F);
/* dE/db = (a^T * dE/dc)^T * \alpha
= dE/dc^T * a * \alpha */
if (!isEfficient || b->isGrad)
_MatrixMulBatched(dedc, X_TRANS, a, X_NOTRANS, dedb, alpha, 1.0F);
}
/* c = a^T * b^T * \alpha */
else if(transA == X_TRANS && transB == X_TRANS){
else if (transA == X_TRANS && transB == X_TRANS) {
/* dE/da = (dE/dc * b)^T * \alpha
= b^T * dE/dc^T * \alpha */
if (!isEfficient || a->isGrad)
_MatrixMulBatched(b, X_TRANS, dedc, X_TRANS, deda, alpha, 1.0F);
/* dE/db = (a * dE/dc)^T * \alpha
= dE/dc^T * a^T * \alpha */
if (!isEfficient || b->isGrad)
_MatrixMulBatched(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
}
......@@ -728,11 +767,13 @@ void XMathGrad::GradMultiply(XTensor * node, bool isEfficient)
CheckNTErrors(_IsSameShaped(a, b), "Wrong sized input tensors!");
/* dE/da = dE/dc * b */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Multiply(node->grad, b, a->grad, 1.0F);
}
/* dE/db = dE/dc * a */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_Multiply(node->grad, a, b->grad, 1.0F);
......@@ -760,13 +801,16 @@ void XMathGrad::GradMultiplyDim(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* dE/da */
/* dE/da = dE/dc * b */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_MultiplyDim(node->grad, b, a->grad, n, 1.0F);
}
/* dE/db */
/* dE/db = (dE/dc * a).reduce(0,...,n-1,n+1,...) */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
......@@ -774,35 +818,30 @@ void XMathGrad::GradMultiplyDim(XTensor * node, bool isEfficient)
XTensor * bGradTMP = NewTensorBufV2(node->grad, node->devID, node->mem);
_Multiply(node->grad, a, bGradTMP);
if(n == order - 1){
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum/dimSize[order - 1];
reshapedSize[0] = a->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc * a to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
bGradTMP->Reshape(2, reshapedSize);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP2 = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(bGradTMP, bGradTMP2, 0);
_Sum(b->grad, bGradTMP2, b->grad);
DelTensorBuf(bGradTMP2);
/*}
else{
_ReduceSum(bGradTMP, b->grad, 0);
}*/
}
else{
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
......@@ -815,22 +854,17 @@ void XMathGrad::GradMultiplyDim(XTensor * node, bool isEfficient)
XTensor * interGrad = NewTensorBufV2(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);
_ReduceSum(bGradTMP, interGrad, 2);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP2 = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP2, 0);
_Sum(b->grad, bGradTMP2, b->grad);
DelTensorBuf(bGradTMP2);
/*}
else{
_ReduceSum(interGrad, b->grad, 0);
}*/
DelTensorBuf(interGrad);
}
DelTensorBuf(bGradTMP);
}
node->visitMark = NODE_FINISHED;
}
......@@ -857,11 +891,18 @@ void XMathGrad::GradMultiplyBroadcast(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
XNoder::MakeGrad(a);
/* dE/da = dE/dc * b */
if (!isEfficient || a->isGrad)
_MultiplyBroadcast(node->grad, b, a->grad, 1.0F);
if(b->isVar || b->income.tailNum > 0){
/* dE/db = (dE/dc * a).reduce(0...n) */
if (!isEfficient || b->isGrad) {
if (b->isVar || b->income.tailNum > 0)
ShowNTErrors("TODO");
}
node->visitMark = NODE_FINISHED;
}
/*
......@@ -880,14 +921,12 @@ void XMathGrad::GradNegate(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for NEGATE!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
/* dE/da = dE/dc * (-1) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_ScaleAndShift(node->grad, b, -1.0F);
_Sum(a->grad, b, a->grad);
DelTensorBuf(b);
_Sum(a->grad, node->grad, a->grad, -1.0F);
}
node->visitMark = NODE_FINISHED;
}
......@@ -901,7 +940,6 @@ gradient for normalize
void XMathGrad::GradNormalize(XTensor * node, bool isEfficient)
{
ShowNTErrors("TODO!");
}
/*
......@@ -920,17 +958,20 @@ void XMathGrad::GradPower(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for POWER!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
DTYPE p = income.GetParam(0);
/* dE/da = (dE/dc) * p * a^(p-1) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Power(a, b, p - 1.0F);
_ScaleAndShiftMe(b, p);
_Multiply(node->grad, b, a->grad, 1.0F);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Power(a, tmp, p - 1.0F);
_ScaleAndShiftMe(tmp, p);
_Multiply(node->grad, tmp, a->grad, 1.0F);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -954,9 +995,12 @@ void XMathGrad::GradScaleAndShift(XTensor * node, bool isEfficient)
DTYPE scale = income.GetParam(0);
/* dE/da = dE/dc * scale */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad, scale);
}
node->visitMark = NODE_FINISHED;
}
......@@ -980,9 +1024,12 @@ void XMathGrad::GradScale(XTensor * node, bool isEfficient)
DTYPE scale = income.GetParam(0);
/* dE/da = dE/dc * scale */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad, scale);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1006,9 +1053,12 @@ void XMathGrad::GradDescale(XTensor * node, bool isEfficient)
DTYPE descale = income.GetParam(0);
/* dE/da = dE/dc / descale */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad, 1/descale);
_Sum(a->grad, node->grad, a->grad, 1 / descale);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1030,9 +1080,12 @@ void XMathGrad::GradShift(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
/* dE/da = dE/dc */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1057,11 +1110,17 @@ void XMathGrad::GradSub(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
DTYPE beta = income.GetParam(0);
/* dE/da = dE/dc */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
_Sum(a->grad, node->grad, a->grad);
}
/* dE/db = -dE/dc * \beta */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_Sum(b->grad, node->grad, b->grad, -beta);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1085,16 +1144,21 @@ void XMathGrad::GradSubDim(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
DTYPE beta = income.GetParam(1);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* dE/da = dE/dc */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
}
/* dE/db = - dE/dc * b.reduce(0,...,n-1,n+1,...) * \beta */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
if(n == order - 1){
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
......@@ -1103,31 +1167,23 @@ void XMathGrad::GradSubDim(XTensor * node, bool isEfficient)
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
if(beta != 1.0F)
if (beta != 1.0F)
_ScaleAndShiftMe(bGradTMP, beta);
_Sub(b->grad, bGradTMP, b->grad);
DelTensorBuf(bGradTMP);
/*}
else{
_ReduceSum(node->grad, b->grad, 0);
if(beta != 1.0F)
_ScaleAndShiftMe(b->grad, beta);
_ScaleAndShiftMe(b->grad, -1.0F);
}*/
node->grad->Reshape(order, dimSize);
}
else{
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
......@@ -1141,25 +1197,17 @@ void XMathGrad::GradSubDim(XTensor * node, bool isEfficient)
_ReduceSum(node->grad, interGrad, 2);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
if(beta != 1.0F)
if (beta != 1.0F)
_ScaleAndShiftMe(bGradTMP, beta);
_Sub(b->grad, bGradTMP, b->grad);
DelTensorBuf(bGradTMP);
/*}
else{
_ReduceSum(interGrad, b->grad, 0);
if(beta != 1.0F)
_ScaleAndShiftMe(b->grad, beta);
_ScaleAndShiftMe(b->grad, -1.0F);
}*/
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
}
node->visitMark = NODE_FINISHED;
......@@ -1172,7 +1220,6 @@ c = a + b * \beta
we have
dE/da = dE/dc
dE/db = dE/dc * \beta
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
......@@ -1186,12 +1233,14 @@ void XMathGrad::GradSum(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
DTYPE beta = income.GetParam(0);
if(!isEfficient || a->isGrad){
/* dE/da = dE/dc */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
}
if(!isEfficient || b->isGrad){
/* dE/db = dE/dc * \beta */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_Sum(b->grad, node->grad, b->grad, beta);
}
......@@ -1219,48 +1268,46 @@ void XMathGrad::GradSumDim(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
DTYPE beta = income.GetParam(1);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
if (!isEfficient || a->isGrad) {
/* dE/da = dE/dc */
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
}
/* dE/db = dE/dc * a.reduce(0,...,n-1,n+1,...) * \beta */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
if(n == order - 1){
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum/dimSize[order - 1];
reshapedSize[0] = a->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
if(beta != 1.0F)
if (beta != 1.0F)
_ScaleAndShiftMe(bGradTMP, beta);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
/*}
else{
_ReduceSum(node->grad, b->grad, 0);
if(beta != 1.0F)
_ScaleAndShiftMe(b->grad, beta);
}*/
node->grad->Reshape(order, dimSize);
}
else{
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
......@@ -1274,24 +1321,17 @@ void XMathGrad::GradSumDim(XTensor * node, bool isEfficient)
_ReduceSum(node->grad, interGrad, 2);
//if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
if(beta != 1.0F)
if (beta != 1.0F)
_ScaleAndShiftMe(bGradTMP, beta);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
/*}
else{
_ReduceSum(interGrad, b->grad, 0);
if(beta != 1.0F)
_ScaleAndShiftMe(b->grad, beta);
}*/
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
}
node->visitMark = NODE_FINISHED;
......@@ -1320,12 +1360,20 @@ void XMathGrad::GradSumBroadcast(XTensor * node, bool isEfficient)
XTensor * b = income.tails[1];
//DTYPE beta = income.GetParam(0);
/* dE/da = dE/dc */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
}
if(b->isVar || b->income.tailNum > 0){
/* dE/db = dE/dc * a.reduce(0..n) * \beta */
if (!isEfficient || b->isGrad) {
if (b->isVar || b->income.tailNum > 0) {
ShowNTErrors("TODO");
}
}
node->visitMark = NODE_FINISHED;
}
/*
......@@ -1345,18 +1393,21 @@ void XMathGrad::GradReduceMean(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for Reduce!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
int dim = income.GetParamInt(0);
int n = a->GetDim(dim);
/* dE/da = Unsqueeze(dE/dc) * 1/dimSizeA[dim] */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Unsqueeze(node->grad, b, dim, n);
_ScaleAndShiftMe(b, 1.0F/n);
_Sum(a->grad, b, a->grad);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Unsqueeze(node->grad, tmp, dim, n);
_ScaleAndShiftMe(tmp, 1.0F / n);
_Sum(a->grad, tmp, a->grad);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1366,7 +1417,7 @@ gradient for reduceSum
for
c = reduceSum(a, dim)
we have
dE/da = Unsqueeze(dE/dc) * 1
dE/da = Unsqueeze(dE/dc)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
......@@ -1378,17 +1429,19 @@ void XMathGrad::GradReduceSum(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for Reduce!");
XTensor * a = income.tails[0];
XTensor * b = NewTensorBufV2(a, a->devID, a->mem);
int dim = income.GetParamInt(0);
int n = a->GetDim(dim);
/* dE/da = Unsqueeze(dE/dc) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Unsqueeze(node->grad, b, dim, n);
_Sum(a->grad, b, a->grad);
DelTensorBuf(b);
XTensor * tmp = NewTensorBufV2(a, a->devID, a->mem);
_Unsqueeze(node->grad, tmp, dim, n);
_Sum(a->grad, tmp, a->grad);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -1419,22 +1472,28 @@ void XMathGrad::GradReduceSumSquared(XTensor * node, bool isEfficient)
int dim = income.GetParamInt(0);
int n = a->GetDim(dim);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* compute a-b */
_Unsqueeze(b, c, dim, n);
_Sub(a, c, d);
_ReduceSum(d, f, dim);
/* dE/da_i = Unsqueeze(dE/dc) * 2 * (a_i - b) */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_ScaleAndShiftMe(d, 2.0F);
_Unsqueeze(node->grad, e, dim, n);
_Multiply(d, e, a->grad, 1.0F);
}
/* dE/db = dE/dc * -2 * n * \sum_i (a_i - b) */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_ReduceSum(d, f, dim);
_ScaleAndShiftMe(f, -2.0F);
_Multiply(node->grad, f, b->grad, 1.0F);
}
DelTensorBuf(f);
DelTensorBuf(e);
......@@ -1471,22 +1530,27 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
int dim = income.GetParamInt(0);
int n = a->GetDim(dim);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* compute a-b */
_Unsqueeze(b, c, dim, n);
_Sub(a, c, d);
_ReduceSum(d, f, dim);
/* dE/da_i = Unsqueeze(dE/dc) * 2 * (a_i - b) / n */
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_ScaleAndShiftMe(d, 2.0F / n);
_Unsqueeze(node->grad, e, dim, n);
_Multiply(d, e, a->grad, 1.0F);
}
/* dE/db = dE/dc * -2 * \sum_i (a_i - b) */
_ScaleAndShiftMe(f, -2.0F /n);
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_ReduceSum(d, f, dim);
_ScaleAndShiftMe(f, -2.0F / n);
_Multiply(node->grad, f, b->grad, 1.0F);
}
DelTensorBuf(f);
DelTensorBuf(e);
......@@ -1496,7 +1560,6 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
node->visitMark = NODE_FINISHED;
}
/*
gradient for operation
for c = matmul(x, w) + b
......@@ -1521,11 +1584,8 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);
if (!isEfficient || w->isGrad)
XNoder::MakeGrad(w);
if (!isEfficient || x->isGrad)
XNoder::MakeGrad(x);
if (!isEfficient || b->isGrad)
/* dE/db = dE/dc * x.reduce(0,...,n-1,n+1,...) */
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
int order = node->order;
......@@ -1567,7 +1627,6 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
node->grad->Reshape(3, reshapedSize);
XTensor * interGrad = NewTensorBufV2(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);
_ReduceSum(node->grad, interGrad, 2);
XTensor * bGradTMP = NewTensorBufV2(b->grad, b->devID, b->mem);
......@@ -1578,9 +1637,13 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
}
if (!isEfficient || w->isGrad)
XNoder::MakeGrad(w);
if (!isEfficient || x->isGrad)
XNoder::MakeGrad(x);
/* compute dE/dx, dE/dw */
XTensor * c = node;
......@@ -1590,7 +1653,7 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
if (x->order == 2 && w->order == 2)
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
else if (transX == X_NOTRANS && x->order > 2 && w->order == 2){
else if (transX == X_NOTRANS && x->order > 2 && w->order == 2) {
int orderBackupX = x->order;
int orderBackupC = c->order;
int dimsBackupX[MAX_TENSOR_DIM_NUM];
......
......@@ -32,33 +32,33 @@
namespace nts{
/* compute dE/dx of a node */
void XShapeGrad::MakeGrad(XTensor * node, bool isEfficent)
void XShapeGrad::MakeGrad(XTensor * node, bool isEfficient)
{
CheckNTErrors(node->grad != NULL, "No gradient found!");
XLink &income = node->income;
int operID = income.typeID;
if(operID == MOVEMENT_COPYINDEXED)
GradCopyIndexed(node, isEfficent);
else if(operID == MOVEMENT_GATHER)
GradGather(node, isEfficent);
if (operID == MOVEMENT_COPYINDEXED)
GradCopyIndexed(node, isEfficient);
else if (operID == MOVEMENT_GATHER)
GradGather(node, isEfficient);
else if (operID == MOVEMENT_DROPOUTWITHINDEX)
GradDropoutWithIndex(node, isEfficent);
else if(operID == SHAPE_MERGE)
GradMerge(node, isEfficent);
else if(operID == SHAPE_MERGE_LIST)
GradMergeList(node, isEfficent);
else if(operID == SHAPE_RESHAPE)
GradReshape(node, isEfficent);
else if(operID == SHAPE_SPLIT)
GradSplit(node, isEfficent);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node, isEfficent);
GradDropoutWithIndex(node, isEfficient);
else if (operID == SHAPE_MERGE)
GradMerge(node, isEfficient);
else if (operID == SHAPE_MERGE_LIST)
GradMergeList(node, isEfficient);
else if (operID == SHAPE_RESHAPE)
GradReshape(node, isEfficient);
else if (operID == SHAPE_SPLIT)
GradSplit(node, isEfficient);
else if (operID == SHAPE_SPLIT_LIST)
GradSplitList(node, isEfficient);
else if (operID == SHAPE_TRANSPOSE)
GradTranspose(node, isEfficent);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node, isEfficent);
GradTranspose(node, isEfficient);
else if (operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node, isEfficient);
else{
ShowNTErrors("TODO!");
}
......@@ -72,10 +72,10 @@ bool XShapeGrad::IsShapeOP(XTensor * node)
}
/* post processing of a node */
void XShapeGrad::PostProcessing(XTensor * node, int typeID, bool isEfficent)
void XShapeGrad::PostProcessing(XTensor * node, int typeID, bool isEfficient)
{
if(typeID == SHAPE_SPLIT_LIST)
GradSplitListPost(node, isEfficent);
if (typeID == SHAPE_SPLIT_LIST)
GradSplitListPost(node, isEfficient);
}
/*
......@@ -88,7 +88,7 @@ dE/da = spreadforcopyindexed(b)
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficent)
void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for CopyIndexed!");
......@@ -100,8 +100,15 @@ void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficent)
XTensor * srcIndex = income.tails[1];
XTensor * tgtIndex = income.tails[2];
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
_SpreadForCopyIndexed(input->grad, node->grad, dim, srcIndex, tgtIndex, copyNum);
XTensor * tmp = NewTensorBufV2(input, input->devID, input->mem);
_SpreadForCopyIndexed(tmp, node->grad, dim, srcIndex, tgtIndex, copyNum);
_SumMe(input->grad, tmp);
DelTensorBuf(tmp);
}
}
/*
......@@ -114,16 +121,23 @@ dE/da = spreadforgather(b)
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradGather(XTensor * node, bool isEfficent)
void XShapeGrad::GradGather(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for Gather!");
XTensor * input = income.tails[0];
XTensor * index = income.tails[1];
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
_SpreadForGather(input->grad, node->grad, index);
XTensor * tmp = NewTensorBufV2(input, input->devID, input->mem);
_SpreadForGather(tmp, node->grad, index);
_SumMe(input->grad, tmp);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -131,7 +145,7 @@ void XShapeGrad::GradGather(XTensor * node, bool isEfficent)
/*
gradient computation for DropoutWithIndex function
*/
void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficent)
void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for DropoutWithIndex!");
......@@ -139,28 +153,23 @@ void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficent)
XTensor * input = income.tails[0];
XTensor * index = income.tails[1];
DTYPE scale = income.GetParam(0);
XNoder::MakeGrad(input);
//_Identity(node->grad, input->grad);
_CopyValues(node->grad, input->grad);
int order = node->grad->order;
int * dimSize = new int[order];
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
for (int i = 0; i < order; i++) {
dimSize[i] = node->grad->dimSize[i];
}
XTensor * tmp = NewTensorBufV2(input, input->devID, input->mem);
_CopyValues(node->grad, tmp);
int order1 = 1;
int * dimSize1 = new int[order1];
dimSize1[0] = input->grad->unitNum;
tmp->Reshape(tmp->unitNum);
input->grad->Reshape(order1, dimSize1);
_DropoutWithIndex(node->grad, index, tmp);
_ScaleAndShiftMe(tmp, scale);
_DropoutWithIndex(node->grad, index, input->grad);
_ScaleAndShiftMe(input->grad, scale);
tmp->Reshape(input->order, input->dimSize);
_SumMe(input->grad, tmp);
input->grad->Reshape(order, dimSize);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -180,7 +189,7 @@ dE/da = split(dE/dc)
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
void XShapeGrad::GradMerge(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
......@@ -191,20 +200,13 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
int whereToMerge = income.GetParamInt(0);
int leadDim = income.GetParamInt(1);
int blockSize = 1;
int blockNum = 1;
for(int i = 0; i < input->order; i++){
if(i < leadDim)
blockNum *= input->dimSize[i];
}
blockSize = input->GetDataSizeInChar() / blockNum;
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
int * dims = new int[input->order];
memset(dims, 0, sizeof(int) * input->order);
for(int i = 0, j = 0; i < input->order; i++){
if(i >= leadDim){
for (int i = 0, j = 0; i < input->order; i++) {
if (i >= leadDim) {
dims[j++] = input->dimSize[i];
}
}
......@@ -218,10 +220,18 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
node->dataType, node->denseRatio,
node->devID, node->mem);
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
if (i < leadDim)
blockNum *= input->dimSize[i];
}
blockSize = input->GetDataSizeInChar() / blockNum;
/* we can simply split the gradient tensor
if the input is used in merging only */
if(input->outgo.tailNum == 1){
for(int i = 0; i < blockNum; i++){
if (input->outgo.tailNum == 1) {
for (int i = 0; i < blockNum; i++) {
gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
gradInputSmall.data = (char*)input->grad->data + i * blockSize;
_Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
......@@ -232,10 +242,10 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
other operations somewhere else. So we have to do gradient
accumulation after splitting, i.e., we need an additional
SUM operation */
else{
else {
XTensor gradInputSmallBuf(&gradInputSmall);
for(int i = 0; i < blockNum; i++){
for (int i = 0; i < blockNum; i++) {
gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
gradInputSmall.data = (char*)input->grad->data + i * blockSize;
_Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
......@@ -247,6 +257,7 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
gradInputSmall.data = NULL;
delete[] dims;
}
node->visitMark = NODE_FINISHED;
}
......@@ -274,18 +285,18 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
TensorList smalls(income.tailNum);
TensorList smallsGrad(income.tailNum);
bool mergeOnly = true;
for(int i = 0; i < income.tailNum; i++){
for (int i = 0; i < income.tailNum; i++) {
/* TODO! efficient backpropagate */
XTensor * tail = income.tails[i];
XNoder::MakeGrad(tail);
smalls.Add(tail);
smallsGrad.Add(tail->grad);
if(i > 1){
CheckNTErrors(_IsSameShaped(last, tail),
"Input tensors must be of the same size!");
}
if (i > 1)
CheckNTErrors(_IsSameShaped(last, tail), "Input tensors must be of the same size!");
if(tail->outgo.tailNum > 1)
if (tail->outgo.tailNum > 1)
mergeOnly = false;
last = tail;
......@@ -295,7 +306,7 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
/* we can simply split the gradient tensor into the input tensors
if the inputs are used in merging only */
if(mergeOnly)
if (mergeOnly)
_Split(node->grad, &smallsGrad, whereToMerge, smalls.count);
/* a more complicated case is that the input tensors are used for
......@@ -321,7 +332,7 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
last->devID, last->mem);
/* gradient accumulation for each split */
for(int i = 0; i < smalls.count; i++){
for (int i = 0; i < smalls.count; i++) {
XTensor * inputGrad = (XTensor*)smallsGrad.Get(i);
gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize;
_Sum(inputGrad, &gradSmall, inputGrad);
......@@ -344,17 +355,20 @@ dE/da = reshape(dE/db)
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradReshape(XTensor * node, bool isEfficent)
void XShapeGrad::GradReshape(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for RESHAPE!");
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
node->grad->Reshape(input->order, input->dimSize);
_CopyValues(node->grad, input->grad);
node->grad->Reshape(node->order, node->dimSize);
}
node->visitMark = NODE_FINISHED;
}
......@@ -381,16 +395,17 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
CheckNTErrors(splitNum == node->dimSize[0], "Wrong split number!");
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
/* we can simply merge the gradient tensor
if the input is used in splitting only */
if(input->outgo.tailNum == 1)
if (input->outgo.tailNum == 1)
_Merge(node->grad, input->grad, whereToSplit + 1, 0);
/* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */
else{
else {
XTensor * inputGradTMP = NewTensorBufV2(input, input->devID, input->mem);
_Merge(node->grad, inputGradTMP, whereToSplit + 1, 0);
......@@ -398,6 +413,7 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
DelTensorBuf(inputGradTMP);
}
}
node->visitMark = NODE_FINISHED;
}
......@@ -444,14 +460,14 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
int whereToSplit = -1;
int splitNum = 0;
for(int i = 0; i < outgo.tailNum; i++){
for (int i = 0; i < outgo.tailNum; i++) {
XTensor * parent = (XTensor*)outgo.tails[i];
XLink &income = parent->income;
if(income.typeID == SHAPE_SPLIT_LIST){
if (income.typeID == SHAPE_SPLIT_LIST) {
int w = income.GetParamInt(0);
int splitID = income.GetParamInt(1);
if(whereToSplit < 0)
if (whereToSplit < 0)
whereToSplit = w;
splitNum++;
......@@ -463,18 +479,19 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
}
}
if (!isEfficient || node->isGrad) {
XNoder::MakeGrad(node);
/* we can simply merge the gradient tensor
if the node is used in splitting only */
if(outgo.tailNum == splitNum){
if (outgo.tailNum == splitNum) {
_Merge(&splits, node->grad, whereToSplit);
}
/* if the tensor is used as input to other nodes
somewhere else, we need another SUM for gradient
accumulation */
else{
else {
XTensor * nodeGradTMP = NewTensorBufV2(node, node->devID, node->mem);
_Merge(&splits, nodeGradTMP, whereToSplit + 1);
......@@ -482,6 +499,7 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
DelTensorBuf(nodeGradTMP);
}
}
}
/*
......@@ -501,7 +519,9 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensorBufV2(input, input->devID, input->mem);
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
......@@ -510,10 +530,12 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
XTensor * tmp = NewTensorBufV2(input, input->devID, input->mem);
_Transpose(output->grad, tmp, i, j);
_Sum(input->grad, tmp, input->grad);
DelTensorBuf(b);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......@@ -535,7 +557,6 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
XTensor * output = node;
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
int dim = income.GetParamInt(0);
int dSize = income.GetParamInt(1);
......@@ -543,12 +564,16 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
CheckNTErrors(output->unitNum == input->unitNum * dSize, "Wrong tensor size!");
XTensor * g = NewTensorBufV2(input->grad, input->devID, input->mem);
if (!isEfficient || input->isGrad) {
XNoder::MakeGrad(input);
_ReduceSum(output->grad, g, dim);
_Sum(input->grad, g, input->grad);
XTensor * tmp = NewTensorBufV2(input->grad, input->devID, input->mem);
DelTensorBuf(g);
_ReduceSum(output->grad, tmp, dim);
_Sum(input->grad, tmp, input->grad);
DelTensorBuf(tmp);
}
node->visitMark = NODE_FINISHED;
}
......
......@@ -316,7 +316,6 @@ void XNet::ClearGrad(XTensor * node)
}
if(finished){
//fprintf(stderr, "del %d %ld\n", node->id, node->grad->unitNum);
delete node->grad;
node->grad = NULL;
}
......
......@@ -62,7 +62,7 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
/* we transform a higher order tensor to a matrix to kill the number
of calls of matrix multiplication */
if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
if (transposedA == X_NOTRANS && a->order > 2 && b->order == 2) {
int ncolA = a->dimSize[a->order - 1];
int ncolC = c->dimSize[c->order - 1];
XTensor * a2 = NewTensor2DV2(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
......
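The comment in the hunk above describes the shape trick: when transposedA is X_NOTRANS, a higher-order a can be viewed as a 2-D matrix with a->unitNum / ncolA rows, so one matrix multiplication replaces a call per leading slice. A small, self-contained illustration of the bookkeeping with hypothetical shapes (the real code builds the 2-D views with NewTensor2DV2 over the same data):

#include <cstdio>

int main()
{
    /* hypothetical shapes: a is (8, 16, 64), b is (64, 32) */
    int dimsA[3] = {8, 16, 64};
    int ncolA = dimsA[2];                              /* a->dimSize[a->order - 1] */
    int unitNumA = dimsA[0] * dimsA[1] * dimsA[2];
    int ncolB = 32;

    /* fold the leading dimensions: one 128 x 64 by 64 x 32 product
       instead of 8 separate 16 x 64 by 64 x 32 products */
    printf("a2: %d x %d, c2: %d x %d\n",
           unitNumA / ncolA, ncolA, unitNumA / ncolA, ncolB);
    return 0;
}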
......@@ -199,8 +199,8 @@ void funcName(const XTensor &a, const XTensor &b, XTensor c)
}
#ifdef USE_CUDA
_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, max)
_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, min)
_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, MAX)
_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, MIN)
#else
_SIMPLE_MAX_MIN_FUNCTION(_Max, max)
_SIMPLE_MAX_MIN_FUNCTION(_Min, min)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __SPLIT_H__
#define __SPLIT_H__
......
......@@ -85,7 +85,7 @@ XTensor Stack(const TensorList &smalls, int dim)
{
int count = smalls.count;
CheckNTErrors(count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
XTensor * tensor = smalls.GetItem(0);
int order = tensor->order + 1;
......@@ -95,7 +95,7 @@ XTensor Stack(const TensorList &smalls, int dim)
if (i < dim)
dimSize[i] = tensor->GetDim(i);
else if (i > dim)
dimSize[i] = tensor->GetDim(i);
dimSize[i] = tensor->GetDim(i-1);
else if (i == dim)
dimSize[i] = count;
}
......@@ -149,7 +149,7 @@ void Stack(const TensorList &smalls, XTensor &t, int dim)
{
int count = smalls.count;
CheckNTErrors(count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
if (!t.isInit || !CheckStackShape(smalls, t, dim)) {
XTensor * tensor = smalls.GetItem(0);
......
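The Stack fix above is easiest to check with concrete shapes: stacking count tensors along dim inserts a new axis of size count at position dim, so output dimensions before dim copy the input dimension at the same index, while dimensions after dim must come from index i-1 of the input, which is why GetDim(i-1) replaces GetDim(i). A self-contained sketch with hypothetical shapes:

#include <cstdio>

int main()
{
    /* hypothetical example: stack 4 tensors of shape (2, 3, 5) along dim = 1;
       the expected output shape is (2, 4, 3, 5) */
    const int inOrder = 3;
    int inDims[inOrder] = {2, 3, 5};
    int count = 4, dim = 1;

    int outDims[inOrder + 1];
    for (int i = 0; i < inOrder + 1; i++) {
        if (i < dim)
            outDims[i] = inDims[i];       /* axes before the new one are unchanged */
        else if (i == dim)
            outDims[i] = count;           /* the new axis has size count */
        else
            outDims[i] = inDims[i - 1];   /* axes after it shift by one */
    }

    for (int i = 0; i < inOrder + 1; i++)
        printf("%d ", outDims[i]);        /* prints: 2 4 3 5 */
    printf("\n");
    return 0;
}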