Commit 43331674 by xiaotong

new code for Multiply

parent f1e69d00
......@@ -34,18 +34,20 @@ where i is the index of the item
>> b - matrix b
>> c - result matrix
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
>>
*/
void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha)
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Unmatched tensors!");
#ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
CudaMultiply(a, b, c, leadingDim, alpha);
_CudaMultiply(a, b, c, alpha, leadingDim);
return;
}
#endif
......@@ -118,4 +120,46 @@ void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha
}
}
/*
element-wise product of two tensors and keep the result in the input
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
{
_Multiply(a, b, a, alpha, leadingDim);
}
/*
make a tensor of the element-wise product for two input tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* computation */
_Multiply(&a, &b, &c, alpha, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
return c;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -117,11 +117,11 @@ where i is the item index
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> leadingDim - leading dimension
>> alpha - the coefficient
>> leadingDim - dimension along which we perform broadcasting
*/
extern "C"
void CudaMultiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim, DTYPE alpha)
void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
......
......@@ -42,7 +42,7 @@ void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE al
/* element-wise product of two tensors */
extern "C"
void CudaMultiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim = 0, DTYPE alpha = 0);
void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
#endif // USE_CUDA
......
......@@ -26,9 +26,20 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* element-wise product of two tensors */
extern "C"
void Multiply(XTensor * a, XTensor * b, XTensor * c, int leadingDim = 0, DTYPE alpha = 0);
/* element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element */
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
/* element-wise product of two tensors and keep the result in the input tensor:
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the element */
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
/* make a tensor of the element-wise product for two input tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element */
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论