Commit 515af68a by xiaotong

reload Merge to merge two input tensors

parent 5bf12b8c
......@@ -314,7 +314,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
/*
merge small tensors into a big tensor (return a XTensor structure)
make a new tensor to keep the result and return it
>> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension
......@@ -351,4 +350,47 @@ XTensor Merge(const XList &smalls, int whereToMerge)
return big;
}
/*
merge two tensors into a big tensor (return a XTensor structure)
>> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension
<< return - the big tensor merged by small tensors
*/
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
{
CheckNTErrors(XTensor::IsIdentical(&smallA, &smallB),
"The two tensors must be of the same size!");
int order = smallA.order;
int * dimSize = new int[order];
for (int i = 0; i < smallA.order; i++) {
if (i != whereToMerge)
dimSize[i] = smallA.dimSize[i];
else
dimSize[i] = smallA.dimSize[whereToMerge] * 2;
}
XTensor big = NewTensor(order, dimSize,
smallA.dataType, smallA.denseRatio,
smallA.devID, smallA.mem);
big.SetZeroAll();
big.SetTMP();
XList smalls(2);
smalls.Add(&smallA);
smalls.Add(&smallB);
/* call _Merge function */
_Merge(&smalls, &big, whereToMerge);
/* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
/* destroy variables */
delete[] dimSize;
return big;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -29,18 +29,19 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* transform a tensor by merging it alone with a dimension, e.g., (M, N/3, 3) -> (M, N) */
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1);
/* transform a tensor by merging it alone with a dimension (return a XTensor structure).
make a new tensor to keep the result and return it.
/* transform a tensor by merging it alone with a dimension (return a XTensor structure)
e.g., (M, N/3, 3) -> (M, N) */
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
/* merge small tensors into a big tensor */
void _Merge(const XList * smalls, XTensor * big, int whereToMerge);
/* merge small tensors into a big tensor (return a XTensor structure).
make a new tensor to keep the result and return it. */
/* merge small tensors into a big tensor (return a XTensor structure) */
XTensor Merge(const XList &smalls, int whereToMerge);
/* merge two tensors into a big tensor (return a XTensor structure) */
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge);
} // namespace nts(NiuTrans.Tensor)
#endif // __MERGE_H__
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论