Commit 5bf12b8c by xiaotong

back propagation for Merge

parent b8415485
@@ -65,6 +65,7 @@ int main( int argc, const char ** argv )
    a.Dump(stderr, "a:");
    b.Dump(stderr, "b:");
    c.Dump(stderr, "c:");

    XLink::ShowNetwork(stderr, &c);

    net.Backward(c);
......
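For context, a minimal sketch of what this sample presumably does end to end; the tensor shapes, data type, and setup below are illustrative assumptions, not copied from the actual sample file (only the Dump/ShowNetwork/Backward lines above appear in the hunk):

    /* illustrative only: build two small tensors, merge them, and run
       the backward pass; net stands for the network object used above */
    int dimSize[2] = {2, 4};
    XTensor a = NewTensor(2, dimSize, X_FLOAT, 1.0F, -1, NULL);
    XTensor b = NewTensor(2, dimSize, X_FLOAT, 1.0F, -1, NULL);
    a.SetZeroAll();
    b.SetZeroAll();

    XList smalls(2);
    smalls.Add(&a);
    smalls.Add(&b);

    /* c = merge(list(a, b)): after this commit the call also records
       a SHAPE_MERGE_LIST link from a and b to c */
    XTensor c = Merge(smalls, 0);

    XLink::ShowNetwork(stderr, &c);
    net.Backward(c);   /* back-propagates dE/dc into a.grad and b.grad */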
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* backward computation for shaping operations
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-19
* It was chilly when I came into the office this morning ...
* because I forgot to turn the air conditioner off last night :(
*/
#include "XNoder.h"
#include "XBackwardShape.h"
#include "../tensor/XName.h"
#include "../tensor/core/CHeader.h"
namespace nts{
/* compute dE/dx of a node */
void XShapeGrad::MakeGrad(XTensor * node)
{
    CheckNTErrors(node->grad != NULL, "No gradient found!");

    XLink &income = node->income;
    int operID = income.typeID;

    if(operID == SHAPE_MERGE)
        GradMerge(node);
    else if(operID == SHAPE_MERGE_LIST)
        GradMergeList(node);
    else{
        ShowNTErrors("TODO!");
    }
}
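
/* for context: the dispatch above is presumably driven by the network's
   backward pass (net.Backward in the sample hunk). A hypothetical fragment
   of such a driver, with illustrative names that are not part of this commit:

       void BackwardNode(XTensor * node)
       {
           if(node == NULL || node->grad == NULL)
               return;

           if(XShapeGrad::IsShapeOP(node))
               XShapeGrad::MakeGrad(node);
           else{
               ShowNTErrors("Unsupported operation!");
           }
       }
*/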
/* indicates whether the node is for a shaping operation */
bool XShapeGrad::IsShapeOP(XTensor * node)
{
    XLink &income = node->income;
    return (income.typeID & DATA_BASE) != 0;
}

/*
gradient for merge
for
c = merge(a_0, a_1, ...)
where a_i is the i-th block in a tensor a
we have
dE/da_0 = dE/dc_{split_0}
dE/da_1 = dE/dc_{split_1}
...
i.e.,
dE/da = split(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradMerge(XTensor * node)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");

    XTensor * input = income.tails[0];
    int whereToMerge = income.GetParamInt(0);
    int leadDim = income.GetParamInt(1);

    int blockSize = 1;
    int blockNum = 1;
    for(int i = 0; i < input->order; i++){
        if(i < leadDim)
            blockNum *= input->dimSize[i];
    }
    blockSize = input->GetDataSizeInChar() / blockNum;

    XNoder::MakeGrad(input);

    int * dims = new int[input->order];
    for(int i = 0, j = 0; i < input->order; i++){
        if(i >= leadDim){
            dims[j++] = input->dimSize[i];
        }
    }
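
    /* note: negating the leading dimension below presumably tells the
       XTensor constructor not to allocate data memory; the data pointers
       of the small wrapper tensors are set by hand in the loops below and
       reset to NULL before the tensors are destroyed */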
    dims[0] = -dims[0];
    XTensor gradInputSmall(input->order - leadDim, dims,
                           input->dataType, input->denseRatio,
                           input->devID, input->mem);

    dims[whereToMerge - leadDim] *= dims[0];
    XTensor gradNodeSmall(node->order - leadDim, dims,
                          node->dataType, node->denseRatio,
                          node->devID, node->mem);

    /* we can simply split the gradient tensor
       if the input is used in merging only */
    if(input->outgo.tailNum == 1){
        for(int i = 0; i < blockNum; i++){
            gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
            gradInputSmall.data = (char*)input->grad->data + i * blockSize;
            _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim, input->dimSize[leadDim]);
        }
    }
    /* a more complicated case is that the input tensor is used for
       other operations somewhere else. So we have to do gradient
       accumulation after splitting, i.e., we need an additional
       SUM operation */
    else{
        XTensor gradInputSmallBuf(&gradInputSmall);
        for(int i = 0; i < blockNum; i++){
            gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
            gradInputSmall.data = (char*)input->grad->data + i * blockSize;
            _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim, input->dimSize[leadDim]);
            _Sum(&gradInputSmall, &gradInputSmallBuf, &gradInputSmall);
        }
    }

    gradNodeSmall.data = NULL;
    gradInputSmall.data = NULL;
    delete[] dims;
}
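
/* a concrete illustration, reusing the example from Merge.h: if the
   forward pass merges a of size (M, N/3, 3) into c of size (M, N),
   then the backward pass splits dE/dc of size (M, N) back into
   dE/da of size (M, N/3, 3); the shapes are illustrative only */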
/*
gradient for merging a list of tensors
for
c = merge(list(a, b, ...))
where a, b ... are of the same size
we have
dE/da = dE/dc_{split_0}
dE/db = dE/dc_{split_1}
i.e.,
list(dE/da, dE/db, ...) = split(dE/dc)
*/
void XShapeGrad::GradMergeList(XTensor * node)
{
    XLink &income = node->income;
    CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for MERGE!");

    XTensor * last = NULL;
    XList smalls(income.tailNum);
    XList smallsGrad(income.tailNum);
    bool mergeOnly = true;
    for(int i = 0; i < income.tailNum; i++){
        XTensor * tail = income.tails[i];
        XNoder::MakeGrad(tail);
        smalls.Add(tail);
        smallsGrad.Add(tail->grad);

        if(i > 0){
            CheckNTErrors(XTensor::IsIdentical(last, tail),
                          "Input tensors must be of the same size!");
        }

        if(tail->outgo.tailNum > 1)
            mergeOnly = false;

        last = tail;
    }
    int whereToMerge = income.GetParamInt(0);

    /* we can simply split the gradient tensor into the input tensors
       if the inputs are used in merging only */
    if(mergeOnly)
        _Split(node->grad, &smallsGrad, whereToMerge, smalls.count);
    /* a more complicated case is that the input tensors are used for
       other operations somewhere else. So we have to do gradient
       accumulation after splitting, i.e., we need an additional
       SUM operation */
    else{
        int * dims = new int[last->order + 1];
        dims[0] = smalls.count;
        for(int i = 0; i < last->order; i++)
            dims[i + 1] = last->dimSize[i];
        XTensor gradSplit(last->order + 1, dims,
                          last->dataType, last->denseRatio,
                          last->devID, last->mem);

        _Split(node->grad, &gradSplit, whereToMerge, smalls.count);

        memcpy(dims, last->dimSize, sizeof(int) * last->order);
        dims[0] = -dims[0];
        XTensor gradSmall(last->order, dims,
                          last->dataType, last->denseRatio,
                          last->devID, last->mem);

        /* gradient accumulation for each split */
        for(int i = 0; i < smalls.count; i++){
            XTensor * smallGrad = (XTensor*)smallsGrad.Get(i);
            gradSmall.data = (char*)gradSplit.data + i * last->unitNum * last->unitSize;
            _Sum(smallGrad, &gradSmall, smallGrad);
        }

        /* reset the borrowed data pointer so that gradSmall does not
           free memory owned by gradSplit when it goes out of scope */
        gradSmall.data = NULL;

        delete[] dims;
    }
}
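
/* a concrete illustration (shapes are illustrative only): if a and b are
   both of size (M, N) and c = merge(list(a, b)) at dimension 0, then c is
   of size (2M, N), and dE/da and dE/db are exactly the two blocks obtained
   by splitting dE/dc along that dimension */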
}
\ No newline at end of file
@@ -28,6 +28,28 @@
namespace nts{

/* this class computes the gradient for tensor shaping and movement given a node */
class XShapeGrad
{
public:
    /* compute dE/dx of a node */
    static
    void MakeGrad(XTensor * node);

    /* indicates whether the node is for a shaping operation */
    static
    bool IsShapeOP(XTensor * node);

private:
    /* gradient for merge: c = merge(a, b, ...) */
    static
    void GradMerge(XTensor * node);

    /* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
    static
    void GradMergeList(XTensor * node);
};
}
#endif
\ No newline at end of file
@@ -71,10 +71,14 @@ const char * GetOPName(int type)
        return "S_CONCATENATE";
    else if(type == SHAPE_MERGE)
        return "S_MERGE";
    else if(type == SHAPE_MERGE_LIST)
        return "S_MERGE_LIST";
    else if(type == SHAPE_PERMUTE)
        return "S_PERMUTE";
    else if(type == SHAPE_SPLIT)
        return "S_SPLIT";
    else if(type == SHAPE_SPLIT_LIST)
        return "S_SPLIT_LIST";
    else if(type == SHAPE_TRANSPOSE)
        return "S_TRANSPOSE";
    else if(type == SHAPE_UNSQUEEZE)
......
@@ -62,9 +62,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define SHAPE REDUCE_REDUCEVARIANCE + 1
#define SHAPE_CONCATENATE SHAPE + 1
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
#define SHAPE_PERMUTE SHAPE_MERGE_LIST + 1
#define SHAPE_SPLIT SHAPE_PERMUTE + 1
#define SHAPE_SPLIT_LIST SHAPE_SPLIT + 1
#define SHAPE_TRANSPOSE SHAPE_SPLIT_LIST + 1
#define SHAPE_UNSQUEEZE SHAPE_TRANSPOSE + 1

/* activation functions */
......
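Since the operator IDs above are defined as consecutive integers rather than independent bit flags, membership in the shape family can also be expressed as a range check. A minimal sketch; the helper name is hypothetical and not part of this commit:

    /* true if typeID names one of the shape operators defined above */
    bool IsShapeOpID(int typeID)
    {
        return typeID >= SHAPE && typeID <= SHAPE_UNSQUEEZE;
    }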
@@ -80,40 +80,19 @@ XTensor Concatenate(const XList &smalls, int dim)
        if (!XTensor::IsIdentical(a, b))
            uniform = false;
    }

    XTensor * tensor = (XTensor*)smalls.GetItem(0);
    int order = tensor->order;
    int * dimSize = new int[order];

    if (uniform) {
        for (int i = 0; i < tensor->order; i++) {
            if (i != dim)
                dimSize[i] = tensor->dimSize[i];
            else
                dimSize[i] = tensor->dimSize[dim] * smalls.count;
        }
    }
    else {
        for (int i = 0; i < tensor->order; i++)
            if (i != dim)
                dimSize[i] = tensor->dimSize[i];
@@ -124,19 +103,24 @@ XTensor Concatenate(const XList &smalls, int dim)
            catDimSize += tensor->dimSize[dim];
        }
        dimSize[dim] = catDimSize;
    }

    XTensor big = NewTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
    big.SetZeroAll();
    big.SetTMP();

    /* call _Merge function */
    _Merge(&smalls, &big, dim);

    /* tensor connection */
    XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
    XLink::AddParamToHeadInt(&big, dim);

    /* destroy variables */
    delete[] dimSize;

    return big;
}
/*
@@ -168,6 +152,8 @@ make a new tensor to keep the result and return it.
*/
XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
{
    ShowNTErrors("rewrite this!!!!!!!!");

    XList smalls(2);
    smalls.Add(&smallA);
    smalls.Add(&smallB);
......
@@ -187,6 +187,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
    /* call _Merge function */
    _Merge(&s, &t, whereToMerge, leadingDim);

    /* tensor connections */
    XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
    XLink::AddParamToHeadInt(&t, whereToMerge);
    XLink::AddParamToHeadInt(&t, leadingDim);

    /* destroy variables */
    delete[] dimSize;
@@ -327,13 +332,19 @@ XTensor Merge(const XList &smalls, int whereToMerge)
        dimSize[i] = tensor->dimSize[whereToMerge] * smalls.count;
    }

    XTensor big = NewTensor(order, dimSize,
                            tensor->dataType, tensor->denseRatio,
                            tensor->devID, tensor->mem);
    big.SetZeroAll();
    big.SetTMP();

    /* call _Merge function */
    _Merge(&smalls, &big, whereToMerge);

    /* tensor connections */
    XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
    XLink::AddParamToHeadInt(&big, whereToMerge);

    /* destroy variables */
    delete[] dimSize;
......
@@ -29,20 +29,16 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* transform a tensor by merging it along a dimension, e.g., (M, N/3, 3) -> (M, N) */
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1);

/* transform a tensor by merging it along a dimension (return an XTensor structure).
   make a new tensor to keep the result and return it.
   e.g., (M, N/3, 3) -> (M, N) */
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);

/* merge small tensors into a big tensor */
void _Merge(const XList * smalls, XTensor * big, int whereToMerge);

/* merge small tensors into a big tensor (return an XTensor structure).
   make a new tensor to keep the result and return it. */
XTensor Merge(const XList &smalls, int whereToMerge);

} // namespace nts(NiuTrans.Tensor)
......
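A short usage sketch for the two Merge overloads declared above; the tensors and dimension indices are illustrative assumptions:

    /* single-tensor form, presumably realizing the (M, N/3, 3) -> (M, N)
       example from the header comment; leadingDim is left at its default */
    XTensor t = Merge(s, 1);

    /* list form: merge equally-sized tensors a and b into one big tensor */
    XList smalls(2);
    smalls.Add(&a);
    smalls.Add(&b);
    XTensor big = Merge(smalls, 0);

After this commit, both calls also record the XLink connections that XShapeGrad::GradMerge and XShapeGrad::GradMergeList read back during the backward pass.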
@@ -19,10 +19,12 @@
 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
 */

#include "Split.h"
#include "MakeSplitBlockIndex.h"
#include "../../XName.h"
#include "../../XTensor.h"
#include "../../XUtility.h"
#include "../movement/CopyBlocksOnSite.h"

namespace nts { // namespace nts(NiuTrans.Tensor)
@@ -161,6 +163,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
    /* call _Split function */
    _Split(&s, &t, whereToSplit, splitNum);

    /* tensor connections */
    XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
    XLink::AddParamToHeadInt(&t, whereToSplit);
    XLink::AddParamToHeadInt(&t, splitNum);

    /* destroy variables */
    delete[] dimSize;
@@ -298,7 +305,9 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum)
    }

    for (int i = 0; i < splitNum; i++) {
        XTensor tensor = NewTensor(order, dimSize,
                                   big.dataType, big.denseRatio,
                                   big.devID, big.mem);
        tensor.SetZeroAll();
        tensor.SetTMP();
        smalls.Add(&tensor);
@@ -307,6 +316,17 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum)
    /* call _Split function */
    _Split(&big, &smalls, whereToSplit, splitNum);

    /* tensor connections */
    for(int i = 0; i < smalls.count; i++){
        XTensor * s = (XTensor*)smalls.Get(i);
        XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
        XLink::AddParamToHeadInt(s, whereToSplit);

        /* it is tricky here that we keep the id of each
           block, rather than the total number of splits */
        XLink::AddParamToHeadInt(s, i);
    }

    /* destroy variables */
    delete[] dimSize;
......
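Because each split block stores its own id as the second integer parameter (see the loop above), a backward function for SHAPE_SPLIT_LIST could presumably recover both values through the same XLink API that GradMerge already uses; a sketch with hypothetical context:

    /* inside a hypothetical GradSplitList(XTensor * node) */
    XLink &income = node->income;
    int whereToSplit = income.GetParamInt(0);
    int blockID = income.GetParamInt(1);   /* the block id, not the number of splits */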