Commit 515af68a by xiaotong

reload Merge to merge two input tensors

parent 5bf12b8c
...@@ -314,7 +314,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge) ...@@ -314,7 +314,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
/* /*
merge small tensors into a big tensor (return a XTensor structure) merge small tensors into a big tensor (return a XTensor structure)
make a new tensor to keep the result and return it
>> smalls - the list of the small tensors >> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension >> whereToMerge - the merging operation is along with which dimension
...@@ -351,4 +350,47 @@ XTensor Merge(const XList &smalls, int whereToMerge) ...@@ -351,4 +350,47 @@ XTensor Merge(const XList &smalls, int whereToMerge)
return big; return big;
} }
/*
merge two tensors into a big tensor (return a XTensor structure)
>> smalls - the list of the small tensors
>> whereToMerge - the merging operation is along with which dimension
<< return - the big tensor merged by small tensors
*/
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
{
CheckNTErrors(XTensor::IsIdentical(&smallA, &smallB),
"The two tensors must be of the same size!");
int order = smallA.order;
int * dimSize = new int[order];
for (int i = 0; i < smallA.order; i++) {
if (i != whereToMerge)
dimSize[i] = smallA.dimSize[i];
else
dimSize[i] = smallA.dimSize[whereToMerge] * 2;
}
XTensor big = NewTensor(order, dimSize,
smallA.dataType, smallA.denseRatio,
smallA.devID, smallA.mem);
big.SetZeroAll();
big.SetTMP();
XList smalls(2);
smalls.Add(&smallA);
smalls.Add(&smallB);
/* call _Merge function */
_Merge(&smalls, &big, whereToMerge);
/* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
/* destroy variables */
delete[] dimSize;
return big;
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
...@@ -29,18 +29,19 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -29,18 +29,19 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* transform a tensor by merging it alone with a dimension, e.g., (M, N/3, 3) -> (M, N) */ /* transform a tensor by merging it alone with a dimension, e.g., (M, N/3, 3) -> (M, N) */
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1); void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1);
/* transform a tensor by merging it alone with a dimension (return a XTensor structure). /* transform a tensor by merging it alone with a dimension (return a XTensor structure)
make a new tensor to keep the result and return it.
e.g., (M, N/3, 3) -> (M, N) */ e.g., (M, N/3, 3) -> (M, N) */
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1); XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
/* merge small tensors into a big tensor */ /* merge small tensors into a big tensor */
void _Merge(const XList * smalls, XTensor * big, int whereToMerge); void _Merge(const XList * smalls, XTensor * big, int whereToMerge);
/* merge small tensors into a big tensor (return a XTensor structure). /* merge small tensors into a big tensor (return a XTensor structure) */
make a new tensor to keep the result and return it. */
XTensor Merge(const XList &smalls, int whereToMerge); XTensor Merge(const XList &smalls, int whereToMerge);
/* merge two tensors into a big tensor (return a XTensor structure) */
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // __MERGE_H__ #endif // __MERGE_H__
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论