Commit 8703cf97 by xiaotong

Fix bugs in Concatenate and Split

parent a7f2f309
@@ -68,8 +68,7 @@ or "Merge" by means of the tensor shapes
 */
 XTensor Concatenate(const XList &smalls, int dim)
 {
-    CheckNTErrors(&smalls != NULL, "Invalid list!");
-    CheckNTErrors((smalls.count > 0), "Empty list!");
+    CheckNTErrors(smalls.count > 0, "Empty list!");
     CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
 
     bool uniform = true;
@@ -91,6 +90,22 @@ XTensor Concatenate(const XList &smalls, int dim)
             else
                 dimSize[i] = tensor->dimSize[dim] * smalls.count;
         }
+
+        float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
+        XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
+        big.SetTMP();
+
+        /* call _Merge function */
+        _Merge(&smalls, &big, dim);
+
+        /* tensor connection */
+        XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
+        XLink::AddParamToHeadInt(&big, dim);
+
+        /* destroy variables */
+        delete[] dimSize;
+
+        return big;
     }
     else {
         for (int i = 0; i < tensor->order; i++)
@@ -103,15 +118,13 @@ XTensor Concatenate(const XList &smalls, int dim)
             catDimSize += tensor->dimSize[dim];
         }
         dimSize[dim] = catDimSize;
-    }
-
-    XTensor big = XTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
-    big.SetZeroAll();
-    big.SetTMP();
-
-    /* call _Merge function */
-    _Merge(&smalls, &big, dim);
-
-    /* tensor connection */
-    XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
+
+        float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
+        XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
+        big.SetTMP();
+
+        /* call _ConcatenateSolely function */
+        _ConcatenateSolely(&smalls, &big, dim);
+
+        /* tensor connection */
+        XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
@@ -121,6 +134,7 @@ XTensor Concatenate(const XList &smalls, int dim)
         delete[] dimSize;
 
         return big;
+    }
 }
 
 /*
@@ -152,14 +166,76 @@ make a new tensor to keep the result and return it.
 */
 XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
 {
-    ShowNTErrors("rewrite this!!!!!!!!");
+    CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
 
     XList smalls(2);
     smalls.Add(&smallA);
     smalls.Add(&smallB);
 
-    /* call Concatenate function */
-    return Concatenate(smalls, dim);
+    bool uniform = true;
+    for (int i = 1; i < smalls.count; i++) {
+        XTensor * a = (XTensor*)smalls.Get(i - 1);
+        XTensor * b = (XTensor*)smalls.Get(i);
+        CheckNTErrors((a && b), "Empty input tensors!");
+        if (!XTensor::IsIdentical(a, b))
+            uniform = false;
+    }
+
+    XTensor * tensor = (XTensor*)smalls.Get(0);
+    int order = tensor->order;
+    int * dimSize = new int[order];
+
+    if (uniform) {
+        for (int i = 0; i < tensor->order; i++) {
+            if (i != dim)
+                dimSize[i] = tensor->dimSize[i];
+            else
+                dimSize[i] = tensor->dimSize[dim] * smalls.count;
+        }
+
+        float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
+        XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
+        big.SetTMP();
+
+        /* call _Merge function */
+        _Merge(&smalls, &big, dim);
+
+        /* tensor connection */
+        XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
+        XLink::AddParamToHeadInt(&big, dim);
+
+        /* destroy variables */
+        delete[] dimSize;
+
+        return big;
+    }
+    else {
+        for (int i = 0; i < tensor->order; i++)
+            if (i != dim)
+                dimSize[i] = tensor->dimSize[i];
+
+        int catDimSize = 0;
+        for (int i = 0; i < smalls.count; i++) {
+            XTensor * tensor = (XTensor*)smalls.Get(i);
+            catDimSize += tensor->dimSize[dim];
+        }
+        dimSize[dim] = catDimSize;
+
+        float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
+        XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
+        big.SetTMP();
+
+        /* call _ConcatenateSolely function */
+        _ConcatenateSolely(&smalls, &big, dim);
+
+        /* tensor connection */
+        XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
+        XLink::AddParamToHeadInt(&big, dim);
+
+        /* destroy variables */
+        delete[] dimSize;
+
+        return big;
+    }
 }
 
 } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
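The net effect of the Concatenate changes: each branch now builds its own result tensor and returns directly, the uniform-shape case goes through _Merge with a SHAPE_MERGE link, and the general case goes through _ConcatenateSolely with a SHAPE_CONCATENATE link, instead of both falling through to a single _Merge call. Below is a minimal standalone sketch of the output-shape rule this implies; it does not use the library API, and ConcatShape, main, and the example sizes are illustrative only.

// Standalone sketch (not library code) of the shape rule in the fixed Concatenate:
// if all inputs share one shape, the concatenated dimension becomes
// dimSize[dim] * count (the _Merge path); otherwise it is the sum of the
// inputs' sizes along dim (the _ConcatenateSolely path).
#include <cstdio>
#include <vector>

std::vector<int> ConcatShape(const std::vector<std::vector<int>> &smalls, int dim)
{
    std::vector<int> out = smalls[0];
    bool uniform = true;
    for (size_t i = 1; i < smalls.size(); i++)
        if (smalls[i] != smalls[0])
            uniform = false;

    if (uniform) {
        out[dim] = smalls[0][dim] * (int)smalls.size();   /* uniform (merge) case */
    }
    else {
        int catDimSize = 0;
        for (size_t i = 0; i < smalls.size(); i++)
            catDimSize += smalls[i][dim];
        out[dim] = catDimSize;                            /* general case */
    }
    return out;
}

int main()
{
    /* two 2x3 inputs concatenated along dim 1 -> 2x6 */
    std::vector<int> a = ConcatShape({{2, 3}, {2, 3}}, 1);
    /* a 2x3 and a 2x5 input concatenated along dim 1 -> 2x8 */
    std::vector<int> b = ConcatShape({{2, 3}, {2, 5}}, 1);
    printf("%d x %d\n", a[0], a[1]);
    printf("%d x %d\n", b[0], b[1]);
    return 0;
}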
@@ -150,11 +150,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
     for (int i = 0; i < s.order; i++) {
         if (i == whereToSplit)
-            dimSize[i] = s.dimSize[i] / splitNum;
+            dimSize[i + 1] = s.dimSize[i] / splitNum;
         else
-            dimSize[i] = s.dimSize[i];
+            dimSize[i + 1] = s.dimSize[i];
     }
-    dimSize[-1] = splitNum;
+    dimSize[0] = splitNum;
 
     XTensor t = NewTensor(order, dimSize, s.dataType, s.denseRatio, s.devID, s.mem);
     t.SetZeroAll();
...
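The Split change fixes the output-shape indexing: the old code wrote to dimSize[-1] (out of bounds) and shifted nothing, while the intended result has one extra leading dimension of size splitNum and the split dimension divided by splitNum. Below is a minimal standalone sketch of that rule; SplitShape and the sizes are illustrative, and it assumes the result order is s.order + 1 as the new indexing implies.

// Standalone sketch (not library code) of the shape rule in the fixed Split:
// prepend splitNum as the leading dimension and shrink the split dimension.
#include <cstdio>
#include <vector>

std::vector<int> SplitShape(const std::vector<int> &s, int whereToSplit, int splitNum)
{
    std::vector<int> dimSize(s.size() + 1);
    for (size_t i = 0; i < s.size(); i++) {
        if ((int)i == whereToSplit)
            dimSize[i + 1] = s[i] / splitNum;   /* split dimension is divided */
        else
            dimSize[i + 1] = s[i];              /* other dimensions are kept */
    }
    dimSize[0] = splitNum;                      /* split count leads; not dimSize[-1] */
    return dimSize;
}

int main()
{
    /* splitting a 4x6 tensor along dimension 1 into 2 parts -> 2x4x3 */
    std::vector<int> t = SplitShape({4, 6}, 1, 2);
    printf("%d x %d x %d\n", t[0], t[1], t[2]);
    return 0;
}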