Commit dcd3a86b by liyinqiao

Optimize the functions.

Optimize the MultiplyMe and DivMe functions for the case where b is a scalar (order-0) tensor.
parent a7d832bc
......@@ -157,16 +157,9 @@ where i is the index of the item
void DivMe(XTensor & a, const XTensor & b, DTYPE alpha, int leadingDim)
{
if (b.order == 0){
DTYPE scale = 1.0F / b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&a, a.devID, a.mem);
_ScaleAndShift(&a, tmp1, scale, 0.0F);
_ScaleAndShift(&a, tmp2, alpha, 0.0F);
_Sum(tmp2, tmp1, &a);
DTYPE scale = 1.0F / b.Get0D() + alpha;
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
_ScaleAndShift(&a, &a, scale, 0.0F);
}
else {
int n = GetBroadcastDimIndex(a, b);
......
......@@ -158,16 +158,9 @@ where i is the index of the item
void MultiplyMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
{
if (b.order == 0){
DTYPE scale = b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&a, a.devID, a.mem);
_ScaleAndShift(&a, tmp1, scale, 0.0F);
_ScaleAndShift(&a, tmp2, alpha, 0.0F);
_Sum(tmp2, tmp1, &a);
DTYPE scale = b.Get0D() + alpha;
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
_ScaleAndShift(&a, &a, scale, 0.0F);
}
else {
int n = GetBroadcastDimIndex(a, b);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论