Commit 37b7e09b by xiaotong

fix the bug in dimension setting in the back propagation of Merge

parent 9e5887dd
......@@ -71,9 +71,11 @@ dE/da = split(dE/dc)
 void XShapeGrad::GradMerge(XTensor * node)
 {
     XLink &income = node->income;
-    CheckNTErrors(income.tailNum == 0, "Wrong input tensor number for MERGE!");
     XTensor * input = income.tails[0];
+    CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
+    CheckNTErrors(node->order == input->order - 1, "wrong tensor orders!");
     int whereToMerge = income.GetParamInt(0);
     int leadDim = income.GetParamInt(1);
......@@ -95,13 +97,13 @@ void XShapeGrad::GradMerge(XTensor * node)
     }
     dims[0] = -dims[0];
     XTensor gradInputSmall(input->order - leadDim, dims,
-                           input->dataType, input->denseRatio,
-                           input->devID, input->mem);
+                           input->dataType, input->denseRatio,
+                           input->devID, input->mem);
     dims[whereToMerge - leadDim] *= dims[0];
-    XTensor gradNodeSmall(node->order - leadDim, dims,
-                          node->dataType, node->denseRatio,
-                          node->devID, node->mem);
+    XTensor gradNodeSmall(node->order - leadDim, dims + leadDim + 1,
+                          node->dataType, node->denseRatio,
+                          node->devID, node->mem);
     /* we can simply split the gradient tensor
        if the input is used in merging only */
......@@ -109,7 +111,7 @@ void XShapeGrad::GradMerge(XTensor * node)
     for(int i = 0; i < blockNum; i++){
         gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
         gradInputSmall.data = (char*)input->grad->data + i * blockSize;
-        _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim, input->dimSize[leadDim]);
+        _Split(&gradNodeSmall, &gradInputSmall, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
     }
 }
......@@ -123,7 +125,7 @@ void XShapeGrad::GradMerge(XTensor * node)
     for(int i = 0; i < blockNum; i++){
         gradNodeSmall.data = (char*)node->grad->data + i * blockSize;
         gradInputSmall.data = (char*)input->grad->data + i * blockSize;
-        _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim, input->dimSize[leadDim]);
+        _Split(&gradNodeSmall, &gradInputSmallBuf, whereToMerge - leadDim - 1, input->dimSize[leadDim]);
         _Sum(&gradInputSmall, &gradInputSmallBuf, &gradInputSmall);
     }
 }
......
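For a MERGE node, the backward rule noted in the hunk header is dE/da = split(dE/dc): the gradient of the merged node is split back into per-block gradients of the input, with the split dimension expressed in the coordinates of the reduced "small" views (hence the corrected whereToMerge - leadDim - 1). The standalone sketch below (plain C++, not NiuTensor code, with hypothetical shapes) shows the block-copy picture behind that rule.

/* standalone sketch of dE/da = split(dE/dc) with hypothetical shapes:
   two blocks of three elements were merged into c, so the merged
   gradient is copied back block by block into the input gradient */
#include <cstdio>

int main()
{
    const int blockNum = 2;
    const int blockSize = 3;
    double gradC[blockNum * blockSize] = {1, 2, 3, 4, 5, 6};  /* dE/dc */
    double gradA[blockNum][blockSize];                        /* dE/da */

    /* splitting the merged gradient routes each slice of dE/dc
       back to the corresponding block of dE/da */
    for (int i = 0; i < blockNum; i++)
        for (int j = 0; j < blockSize; j++)
            gradA[i][j] = gradC[i * blockSize + j];

    for (int i = 0; i < blockNum; i++)
        printf("dE/da block %d: %.0f %.0f %.0f\n",
               i, gradA[i][0], gradA[i][1], gradA[i][2]);
    return 0;
}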
......@@ -73,8 +73,7 @@ void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSiz
 void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net);
 void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
               FNNModel &model, FNNModel &grad, FNNNet &net);
-void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
-             FNNModel &model, XNet &net);
+void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model);
 /*
 entry of the program
......@@ -415,7 +414,10 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
     }
     else{
         /* forward + backward process */
-        FBInOne(inputs, output, gold, CROSSENTROPY, model, autoDiffer);
+        ForwardAutoDiff(inputs, output, model);
+        /* automatic differentiation */
+        autoDiffer.Backward(output, gold, CROSSENTROPY);
         /* update model parameters */
         Update(model, grad, learningRate, true);
......@@ -902,17 +904,14 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
 }
 /*
-forward + backward in one procedure
+forward process (with tensor connections)
 >> inputs - input word representations
 >> output - output probability
->> gold - gold standard
->> loss - loss function name
 >> model - the fnn model
 */
-void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold,
-             LOSS_FUNCTION_NAME loss, FNNModel &model, XNet &net)
+void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model)
 {
-    int batchSize = gold.GetDim(0);
+    int batchSize = inputs[0].GetDim(0);
     int n = model.n;
     int depth = model.hDepth;
......@@ -945,9 +944,6 @@ void FBInOne(XTensor inputs[], XTensor &output, XTensor &gold,
     /* output layer */
     output = LogSoftmax(MMul(hidden, model.outputW) + b, 1);
-    /* automatic differentiation */
-    net.Backward(output);
 }
 /*
......
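In the refactored code the forward pass ends at output = LogSoftmax(...), and the loss is handled at the call site by autoDiffer.Backward(output, gold, CROSSENTROPY). As a point of reference only (the textbook definition, not necessarily NiuTensor's exact implementation), the cross entropy of a log-softmax output against a one-hot gold vector, and its gradient with respect to the log-probabilities, can be sketched as:

/* standalone reference sketch (not NiuTensor code): cross entropy
   E = -sum_i gold[i] * logP[i] over log-softmax outputs, and its
   gradient dE/dlogP[i] = -gold[i], which is what a backward pass
   starting from the CROSSENTROPY loss would feed into "output" */
#include <cstdio>

int main()
{
    const int vSize = 3;
    double logP[vSize] = {-0.5, -1.2, -2.3};   /* log-softmax output row */
    double gold[vSize] = {0.0, 1.0, 0.0};      /* one-hot gold standard  */

    double loss = 0.0;
    double dLogP[vSize];
    for (int i = 0; i < vSize; i++) {
        loss += -gold[i] * logP[i];
        dLogP[i] = -gold[i];                   /* dE/dlogP */
    }

    printf("cross-entropy loss = %.2f\n", loss);   /* prints 1.20 */
    printf("dE/dlogP = %.1f %.1f %.1f\n", dLogP[0], dLogP[1], dLogP[2]);
    return 0;
}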
......@@ -49,7 +49,7 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
     CheckNTErrors((s != NULL && t != NULL), "Invalid tensors!");
     CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
-                  "the data must be kept on the same device!");
+                  "the data must be kept on the same device!");
     CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
     CheckNTErrors((s->order == t->order + 1), "Unmatched tensors!");
......@@ -58,11 +58,11 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
     for (int i = 0; i < s->order; i++) {
         if (i == whereToMergeRDI) {
             CheckNTErrors((t->dimSizeRDI[i] == s->dimSizeRDI[i] * s->dimSizeRDI[leadingDimRDI]),
-                          "Unmatched tensor sizes!");
+                          "Unmatched tensor sizes!");
         }
         else if (i > leadingDimRDI) {
             CheckNTErrors((s->dimSizeRDI[i - 1] == t->dimSizeRDI[i]),
-                          "Unmatched tensor sizes!");
+                          "Unmatched tensor sizes!");
         }
     }
......
......@@ -41,7 +41,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
 {
     CheckNTErrors((s && t), "Invalid tensors!");
     CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
-                  "the data must be kept on the same device!");
+                  "the data must be kept on the same device!");
     CheckNTErrors((s->unitNum == t->unitNum && s->unitSize == t->unitSize), "Unmatched tensors!");
     CheckNTErrors((s->order == t->order - 1), "Unmatched tensors!");
......@@ -51,11 +51,11 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
     for (int i = 0; i < s->order; i++) {
         if (i == whereToSplitRDI) {
             CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i] * splitNum),
-                          "Unmatched tensor sizes!");
+                          "Unmatched tensor sizes!");
         }
         else {
             CheckNTErrors((s->dimSizeRDI[i] == t->dimSizeRDI[i]),
-                          "Unmatched tensor sizes!");
+                          "Unmatched tensor sizes!");
         }
     }
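The hunks above only realign the error-message strings, but the checks they touch spell out the shape contract between _Merge and _Split: merging removes the leading dimension and multiplies the target dimension by its size, and splitting is the inverse. The standalone sketch below (plain C++ in ordinary dimension order, whereas the library indexes dimSizeRDI in reverse; shapes are hypothetical) verifies that contract for one example.

/* standalone sketch of the Merge/Split shape contract checked above,
   written in ordinary dimension order (the dimSizeRDI reversal used by
   the library is skipped here). Shapes are hypothetical. */
#include <cassert>

int main()
{
    /* hypothetical source s of shape (2, 3, 4), merged with
       leadingDim = 0 and whereToMerge = 2 */
    int sDims[3] = {2, 3, 4};
    int leadingDim = 0, whereToMerge = 2;

    /* merged tensor t: drop the leading dimension, grow the merge
       dimension by its size -> t has shape (3, 8) and one lower order */
    int tDims[2] = {sDims[1], sDims[2] * sDims[leadingDim]};

    /* with leadingDim in front of whereToMerge, dimension whereToMerge
       of s lands at whereToMerge - 1 in t; the _Merge checks enforce
       exactly this size relation */
    assert(tDims[whereToMerge - 1] == sDims[whereToMerge] * sDims[leadingDim]);
    assert(tDims[0] == sDims[1]);

    /* _Split is the inverse: splitting t along dimension 1 into
       splitNum = sDims[leadingDim] pieces recovers the block size */
    int splitNum = sDims[leadingDim];
    assert(tDims[1] / splitNum == sDims[whereToMerge]);
    return 0;
}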
......@@ -301,7 +301,7 @@ void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
         XLink::AddParamToHeadInt(s, whereToSplit);
         /* it is tricky here that we keep the id of each
-           block, rather than the total number of splits */
+           block, rather than the total number of the splits */
         XLink::AddParamToHeadInt(s, i);
     }
 }
......