Commit 4bcf6c54 by xiaotong

fix the bug in Merge

parent 2dadc66a
...@@ -246,13 +246,16 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficient) ...@@ -246,13 +246,16 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficient)
dims[j++] = input->dimSize[i]; dims[j++] = input->dimSize[i];
} }
} }
dims[0] = -dims[0];
dims[0] = -abs(dims[0]);
XTensor gradInputSmall(input->order - leadDim, dims, XTensor gradInputSmall(input->order - leadDim, dims,
input->dataType, input->denseRatio, input->dataType, input->denseRatio,
input->devID, input->mem); input->devID, input->mem);
dims[whereToMerge - leadDim] *= dims[0]; dims[whereToMerge - leadDim] *= abs(dims[0]);
XTensor gradNodeSmall(node->order - leadDim, dims + leadDim + 1, int * dimsNode = dims + 1;
dimsNode[0] = -abs(dimsNode[0]);
XTensor gradNodeSmall(node->order - leadDim, dimsNode,
node->dataType, node->denseRatio, node->dataType, node->denseRatio,
node->devID, node->mem); node->devID, node->mem);
......
...@@ -32,14 +32,14 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -32,14 +32,14 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* /*
transform a tensor by merging it along with a dimension. transform a tensor by merging it along with a dimension.
e.g., (N/3, M, 3) -> (N, M) e.g., (3, M, N/3) -> (M, N)
>> s - the source tensor >> s - the source tensor
>> t - the target tensor (for return) >> t - the target tensor (for return)
>> whereToMerge - the merging operation is along with which dimension >> whereToMerge - the merging operation is along with which dimension
>> leadingDim - the leading dimension of merging, take (N/3, M, 3) -> (N, M) >> leadingDim - the leading dimension of merging, take (3, M, N/3) -> (M, N)
for example, whereToMerge = 0 (i.e., the dimension for "N/3") for example, whereToMerge = 2 (i.e., the dimension for "N/3")
leadingDim = 2 (i.e., the dimension for "3") leadingDim = 0 (i.e., the dimension for "3")
*/ */
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim) void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论