Commit 8703cf97 by xiaotong

bug in Concatenate and Split

parent a7f2f309
......@@ -68,8 +68,7 @@ or "Merge" by means of the tensor shapes
*/
XTensor Concatenate(const XList &smalls, int dim)
{
CheckNTErrors(&smalls != NULL, "Invalid list!");
CheckNTErrors((smalls.count > 0), "Empty list!");
CheckNTErrors(smalls.count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
bool uniform = true;
......@@ -90,7 +89,23 @@ XTensor Concatenate(const XList &smalls, int dim)
dimSize[i] = tensor->dimSize[i];
else
dimSize[i] = tensor->dimSize[dim] * smalls.count;
}
}
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
/* call _Merge function */
_Merge(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
/* destroy variables */
delete[] dimSize;
return big;
}
else {
for (int i = 0; i < tensor->order; i++)
......@@ -103,24 +118,23 @@ XTensor Concatenate(const XList &smalls, int dim)
catDimSize += tensor->dimSize[dim];
}
dimSize[dim] = catDimSize;
}
XTensor big = XTensor(order, dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
big.SetZeroAll();
big.SetTMP();
/* call _Merge function */
_Merge(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
/* destroy variables */
delete[] dimSize;
return big;
/* call _ConcatenateSolely function */
_ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
/* destroy variables */
delete[] dimSize;
return big;
}
}
/*
......@@ -152,14 +166,76 @@ make a new tensor to keep the result and return it.
*/
XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
{
ShowNTErrors("rewrite this!!!!!!!!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
XList smalls(2);
smalls.Add(&smallA);
smalls.Add(&smallB);
/* call Concatenate function */
return Concatenate(smalls, dim);
bool uniform = true;
for (int i = 1; i < smalls.count; i++) {
XTensor * a = (XTensor*)smalls.Get(i - 1);
XTensor * b = (XTensor*)smalls.Get(i);
CheckNTErrors((a && b), "Empty input tensors!");
if (!XTensor::IsIdentical(a, b))
uniform = false;
}
XTensor * tensor = (XTensor*)smalls.Get(0);
int order = tensor->order;
int * dimSize = new int[order];
if (uniform) {
for (int i = 0; i < tensor->order; i++) {
if (i != dim)
dimSize[i] = tensor->dimSize[i];
else
dimSize[i] = tensor->dimSize[dim] * smalls.count;
}
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
/* call _Merge function */
_Merge(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
/* destroy variables */
delete[] dimSize;
return big;
}
else {
for (int i = 0; i < tensor->order; i++)
if (i != dim)
dimSize[i] = tensor->dimSize[i];
int catDimSize = 0;
for (int i = 0; i < smalls.count; i++) {
XTensor * tensor = (XTensor*)smalls.Get(i);
catDimSize += tensor->dimSize[dim];
}
dimSize[dim] = catDimSize;
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
/* call _ConcatenateSolely function */
_ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
/* destroy variables */
delete[] dimSize;
return big;
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -150,11 +150,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
for (int i = 0; i < s.order; i++) {
if (i == whereToSplit)
dimSize[i] = s.dimSize[i] / splitNum;
dimSize[i + 1] = s.dimSize[i] / splitNum;
else
dimSize[i] = s.dimSize[i];
dimSize[i + 1] = s.dimSize[i];
}
dimSize[-1] = splitNum;
dimSize[0] = splitNum;
XTensor t = NewTensor(order, dimSize, s.dataType, s.denseRatio, s.devID, s.mem);
t.SetZeroAll();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论