Commit 4bcf6c54 by xiaotong

fix the bug in Merge

parent 2dadc66a
......@@ -246,13 +246,16 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficient)
dims[j++] = input->dimSize[i];
}
}
dims[0] = -dims[0];
dims[0] = -abs(dims[0]);
XTensor gradInputSmall(input->order - leadDim, dims,
input->dataType, input->denseRatio,
input->devID, input->mem);
dims[whereToMerge - leadDim] *= dims[0];
XTensor gradNodeSmall(node->order - leadDim, dims + leadDim + 1,
dims[whereToMerge - leadDim] *= abs(dims[0]);
int * dimsNode = dims + 1;
dimsNode[0] = -abs(dimsNode[0]);
XTensor gradNodeSmall(node->order - leadDim, dimsNode,
node->dataType, node->denseRatio,
node->devID, node->mem);
......
......@@ -32,14 +32,14 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
transform a tensor by merging it along with a dimension.
e.g., (N/3, M, 3) -> (N, M)
e.g., (3, M, N/3) -> (M, N)
>> s - the source tensor
>> t - the target tensor (for return)
>> whereToMerge - the merging operation is along with which dimension
>> leadingDim - the leading dimension of merging, take (N/3, M, 3) -> (N, M)
for example, whereToMerge = 0 (i.e., the dimension for "N/3")
leadingDim = 2 (i.e., the dimension for "3")
>> leadingDim - the leading dimension of merging, take (3, M, N/3) -> (M, N)
for example, whereToMerge = 2 (i.e., the dimension for "N/3")
leadingDim = 0 (i.e., the dimension for "3")
*/
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论