Commit cadda317 by xuchen

implement the left-constant overloading and TtypeAs function

parent 36903fdb
...@@ -704,6 +704,12 @@ void XTensor::ReshapeMerged(const int i, const int j) ...@@ -704,6 +704,12 @@ void XTensor::ReshapeMerged(const int i, const int j)
Reshape(order - 1, dims); Reshape(order - 1, dims);
} }
/* return a tensor that datatype is same as the special tensor */
XTensor XTensor::TypeAs(const XTensor input)
{
return ConvertDataType(*this, input.dataType);
}
/* get the number of items in the data array */ /* get the number of items in the data array */
int XTensor::GetSize() const int XTensor::GetSize() const
{ {
...@@ -3047,4 +3053,28 @@ void DelTensorBuf(XTensor * tensor) ...@@ -3047,4 +3053,28 @@ void DelTensorBuf(XTensor * tensor)
delete tensor; delete tensor;
} }
/* overloading of the plus-sign */
XTensor operator+ (const DTYPE shift, const XTensor &tensor)
{
return ScaleAndShift(tensor, 1, shift);
}
/* overloading of the minus-sign */
XTensor operator- (const DTYPE shift, const XTensor &tensor)
{
return ScaleAndShift(tensor, 1, -shift);
}
/* overloading of the multiply-sign */
XTensor operator* (const DTYPE scale, const XTensor &tensor)
{
return ScaleAndShift(tensor, scale, 0);
}
/* overloading of the division-sign */
XTensor operator/ (const DTYPE scale, const XTensor &tensor)
{
return ScaleAndShift(tensor, (DTYPE)1/scale, 0);
}
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
...@@ -283,6 +283,9 @@ public: ...@@ -283,6 +283,9 @@ public:
/* reshape the tensor by merging two consecutive dimensions */ /* reshape the tensor by merging two consecutive dimensions */
void ReshapeMerged(const int i, const int j = -1); void ReshapeMerged(const int i, const int j = -1);
/* return a tensor that datatype is same as the special tensor */
XTensor TypeAs(const XTensor input);
/* get the number of items in the data array */ /* get the number of items in the data array */
int GetSize() const; int GetSize() const;
...@@ -599,6 +602,18 @@ void DelTensor(XTensor * tensor); ...@@ -599,6 +602,18 @@ void DelTensor(XTensor * tensor);
/* free the data space of a given tensor (on the buffer) */ /* free the data space of a given tensor (on the buffer) */
void DelTensorBuf(XTensor * tensor); void DelTensorBuf(XTensor * tensor);
/* overloading of the plus-sign */
XTensor operator+ (const DTYPE shift, const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const DTYPE shift, const XTensor &tensor);
/* overloading of the multiply-sign */
XTensor operator* (const DTYPE scale, const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const DTYPE scale, const XTensor &tensor);
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
#endif #endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论