Commit 37b7e09b by xiaotong

fix the bug in dimension setting in the back propagation of Merge

parent 9e5887dd
@@ -71,9 +71,11 @@ dE/da = split(dE/dc)
 void XShapeGrad::GradMerge(XTensor * node)
 {
     XLink &income = node->income;
-    CheckNTErrors(income.tailNum == 0, "Wrong input tensor number for MERGE!");
     XTensor * input = income.tails[0];
+    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
+    CheckNTErrors(node->order == input->order - 1, "wrong tensor orders!");
+
     int whereToMerge = income.GetParamInt(0);
     int leadDim = income.GetParamInt(1);
@@ -95,13 +97,13 @@ void XShapeGrad::GradMerge(XTensor * node)
     }
     dims[0] = -dims[0];
     XTensor gradInputSmall(input->order - leadDim, dims,
                            input->dataType, input->denseRatio,
                            input->devID, input->mem);
     dims[whereToMerge - leadDim] *= dims[0];
-    XTensor gradNodeSmall(node->order - leadDim, dims,
+    XTensor gradNodeSmall(node->order - leadDim, dims + leadDim + 1,
                           node->dataType, node->denseRatio,
                           node->devID, node->mem);
     /* we can simply split the gradient tensor
        if the input is used in merging only */
@@ -109,7 +111,7 @@ void XShapeGrad::GradMerge(XTensor * node)
         for(int i = 0; i < blockNum; i++){
             gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
             gradInputSmall.data = (char*)input->grad->data + i * blockSize;
-            _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim, input->dimSize[leadDim]);
+            _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
         }
     }
@@ -123,7 +125,7 @@ void XShapeGrad::GradMerge(XTensor * node)
         for(int i = 0; i < blockNum; i++){
             gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
             gradInputSmall.data = (char*)input->grad->data + i * blockSize;
-            _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim, input->dimSize[leadDim]);
+            _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
             _Sum(&gradInputSmall, &gradInputSmallBuf, &gradInputSmall);
         }
     }
...
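
The relation in the hunk header above, dE/da = split(dE/dc), is the idea behind GradMerge: merging only rearranges elements, so the gradient flows back through the inverse rearrangement. A minimal standalone round-trip sketch of that fact (plain C++, not NiuTensor code; the row-major block layout and shapes below are made up for illustration):

// merging K row-major blocks of shape rows x cols side by side, then
// splitting the result, recovers the original blocks -- which is why the
// backward of MERGE is a SPLIT of the output gradient
#include <cassert>
#include <cstdio>
#include <vector>

// merge K blocks of shape rows x cols into one rows x (K*cols) matrix,
// with block b occupying columns [b*cols, (b+1)*cols)
std::vector<float> MergeBlocks(const std::vector<std::vector<float> > &blocks,
                               int rows, int cols)
{
    int k = (int)blocks.size();
    std::vector<float> merged(rows * k * cols);
    for (int b = 0; b < k; b++)
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                merged[r * k * cols + b * cols + c] = blocks[b][r * cols + c];
    return merged;
}

// split is the exact inverse rearrangement
std::vector<std::vector<float> > SplitBlocks(const std::vector<float> &merged,
                                             int rows, int cols, int k)
{
    std::vector<std::vector<float> > blocks(k, std::vector<float>(rows * cols));
    for (int b = 0; b < k; b++)
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                blocks[b][r * cols + c] = merged[r * k * cols + b * cols + c];
    return blocks;
}

int main()
{
    const int rows = 2, cols = 3, k = 4;
    std::vector<std::vector<float> > a(k, std::vector<float>(rows * cols));
    for (int b = 0; b < k; b++)
        for (int i = 0; i < rows * cols; i++)
            a[b][i] = (float)(100 * b + i);

    /* c = merge(a); since merge only moves elements around,
       split(c) gives a back, just as dE/da = split(dE/dc) */
    std::vector<float> c = MergeBlocks(a, rows, cols);
    assert(SplitBlocks(c, rows, cols, k) == a);
    printf("split(merge(a)) == a\n");
    return 0;
}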
@@ -73,8 +73,7 @@ void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSiz
 void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net);
 void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
               FNNModel &model, FNNModel &grad, FNNNet &net);
-void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
-             FNNModel &model, XNet &net);
+void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model);
 /*
 entry of the program
@@ -415,7 +414,10 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
         }
         else{
             /* forward + backward process */
-            FBInOne(inputs, output, gold, CROSSENTROPY, model, autoDiffer);
+            ForwardAutoDiff(inputs, output, model);
+
+            /* automatic differentiation */
+            autoDiffer.Backward(output, gold, CROSSENTROPY);
             /* update model parameters */
             Update(model, grad, learningRate, true);
@@ -902,17 +904,14 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
 }
 /*
-forward + backward in one procedure
+forward process (with tensor connections)
 >> inputs - input word representations
 >> output - output probability
->> gold - gold standard
->> loss - loss function name
 >> model - the fnn model
 */
-void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold,
-             LOSS_FUNCTION_NAME loss, FNNModel &model, XNet &net)
+void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model)
 {
-    int batchSize = gold.GetDim(0);
+    int batchSize = inputs[0].GetDim(0);
     int n = model.n;
     int depth = model.hDepth;
@@ -945,9 +944,6 @@ void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold,
     /* output layer */
     output = LogSoftmax(MMul(hidden, model.outputW) + b, 1);
-
-    /* automatic differentiation */
-    net.Backward(output);
 }
 /*
...
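
For reference, what the output layer and loss named in the FNN model hunks above compute, written out for a single row: a log-softmax over the scores, then cross entropy against the gold index. A standalone sketch in plain C++ (not NiuTensor code; the scores and gold index are made-up values):

// the scalar picture behind output = LogSoftmax(MMul(hidden, model.outputW) + b, 1)
// and the CROSSENTROPY loss passed to the backward pass in Train()
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// numerically stable log-softmax over one row of scores
std::vector<double> LogSoftmaxRow(const std::vector<double> &x)
{
    double maxV = *std::max_element(x.begin(), x.end());
    double sum = 0.0;
    for (double v : x)
        sum += std::exp(v - maxV);
    double logZ = maxV + std::log(sum);

    std::vector<double> y(x.size());
    for (size_t i = 0; i < x.size(); i++)
        y[i] = x[i] - logZ;
    return y;
}

int main()
{
    std::vector<double> scores = {1.0, 2.0, 0.5};  /* one row of scores (made up) */
    int gold = 1;                                  /* index of the gold word */

    std::vector<double> logProb = LogSoftmaxRow(scores);
    double loss = -logProb[gold];                  /* cross entropy with a one-hot gold */

    printf("log P(gold) = %.4f, cross entropy = %.4f\n", logProb[gold], loss);
    return 0;
}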
@@ -49,7 +49,7 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
     CheckNTErrors((s != NULL && t != NULL), "Invalid tensors!");
     CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
                   "the data must be kept on the same device!");
     CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
     CheckNTErrors((s->order == t->order + 1), "Unmatched tensors!");
@@ -58,11 +58,11 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
     for (int i = 0; i < s->order; i++) {
         if (i == whereToMergeRDI) {
             CheckNTErrors((t->dimSizeRDI[i] == s->dimSizeRDI[i] * s->dimSizeRDI[leadingDimRDI]),
                           "Unmatched tensor sizes!");
         }
         else if (i > leadingDimRDI) {
             CheckNTErrors((s->dimSizeRDI[i - 1] == t->dimSizeRDI[i]),
                           "Unmatched tensor sizes!");
         }
     }
...
@@ -41,7 +41,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
 {
     CheckNTErrors((s && t), "Invalid tensors!");
     CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
                   "the data must be kept on the same device!");
     CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
     CheckNTErrors((s->order == t->order - 1), "Unmatched tensors!");
@@ -51,11 +51,11 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
     for (int i = 0; i < s->order; i++) {
         if (i == whereToSplitRDI) {
             CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i] * splitNum),
                           "Unmatched tensor sizes!");
         }
         else {
             CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
                           "Unmatched tensor sizes!");
         }
     }
@@ -301,7 +301,7 @@ void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
         XLink::AddParamToHeadInt(s, whereToSplit);
         /* it is tricky here that we keep the id of each
-           block, rather than the total number of splits */
+           block, rather than the total number of the splits */
         XLink::AddParamToHeadInt(s, i);
     }
 }
...
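
The shape checks in the _Merge and _Split hunks above encode one relation: merging removes the leading dimension and multiplies the size at whereToMerge by it, and splitting reverses that. A small sketch of the relation (plain C++, not NiuTensor code; ordinary dimension order rather than the reversed dimSizeRDI order used internally):

// shape bookkeeping for merge: drop the leading dimension and fold its size
// into the dimension being merged into, so the result has one dimension fewer;
// split inverts this by dividing that dimension by splitNum and adding a
// dimension of size splitNum back
#include <cstdio>
#include <vector>

std::vector<int> MergedShape(const std::vector<int> &src, int whereToMerge, int leadingDim)
{
    std::vector<int> dst;
    for (int i = 0; i < (int)src.size(); i++) {
        if (i == leadingDim)
            continue;                                /* this dimension disappears */
        else if (i == whereToMerge)
            dst.push_back(src[i] * src[leadingDim]); /* and is folded in here */
        else
            dst.push_back(src[i]);
    }
    return dst;
}

int main()
{
    std::vector<int> s = {4, 2, 3};            /* e.g. 4 blocks of shape 2 x 3 */
    std::vector<int> t = MergedShape(s, 2, 0); /* merge dim 0 into dim 2 */

    /* t has one dimension fewer and t[last] = 3 * 4 = 12, matching the
       order and dimSize checks asserted above */
    printf("t = {%d, %d}\n", t[0], t[1]);      /* prints t = {2, 12} */
    return 0;
}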