Commit a7d832bc by liyinqiao

Bug fixed and clean the codes.

1. Fix the bug in DivMe function which cannot handle the scalar tensor and broadcast case.
2. Clean the codes.
3. Fix minor errors.
parent a4b98ac6
......@@ -146,7 +146,7 @@ void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
......@@ -154,9 +154,35 @@ where i is the index of the item
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void DivMe(XTensor& a, const XTensor& b, DTYPE alpha, int leadingDim)
void DivMe(XTensor & a, const XTensor & b, DTYPE alpha, int leadingDim)
{
_Div(&a, &b, &a, alpha, leadingDim);
if (b.order == 0){
DTYPE scale = 1.0F / b.Get0D();
XTensor * tmp1 = NewTensorBufV2(&a, a.devID, a.mem);
XTensor * tmp2 = NewTensorBufV2(&a, a.devID, a.mem);
_ScaleAndShift(&a, tmp1, scale, 0.0F);
_ScaleAndShift(&a, tmp2, alpha, 0.0F);
_Sum(tmp2, tmp1, &a);
DelTensorBuf(tmp1);
DelTensorBuf(tmp2);
}
else {
int n = GetBroadcastDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &a, alpha, leadingDim);
}
else if (n >= 0 && n < a.order)
/* call _DivDim function */
_DivDim(&a, &b, &a, n, alpha);
else
ShowNTErrors("Something is wrong!");
}
}
/*
......@@ -172,7 +198,7 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
XTensor Div(const XTensor & a, const XTensor & b, int leadingDim)
{
XTensor c(&a);
c.SetTMPFlag();
......@@ -226,7 +252,7 @@ where i is the index of the item
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
void Div(const XTensor & a, const XTensor & b, XTensor & c, DTYPE alpha, int leadingDim)
{
if (!c.isInit || !IsSameShaped(a, c)) {
InitTensorV2(&c, &a);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论