Commit 75385ebe by xiaotong

add MODX to make the functions accept dimension index with a minus value

parent e87cf208
......@@ -157,7 +157,9 @@ extern bool useCUDA;
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define XPRINT8(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
#define B2I(V) V == 0 ? false : true
#define MODX(a, b) int(b == 0 ? a : a - floor(double(a)/b) * b)
/* BLAS interfaces */
#ifdef DOUBELPRICSION
......
......@@ -42,6 +42,8 @@ i.e., a is divided with b by broadcasting
*/
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
......@@ -151,6 +153,8 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Div function */
_DivDim(&a, &b, &c, n, alpha);
......
......@@ -42,8 +42,10 @@ i.e., a is multiplied with b by broadcasting
>> n - the dimension index
>> alpha - the scaling factor
*/
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha) {
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
......@@ -151,6 +153,8 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
XTensor c(&a);
c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, 0);
......
......@@ -42,6 +42,8 @@ i.e., a is subtracted with b by broadcasting
*/
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
......@@ -152,6 +154,8 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
XTensor c(&a);
c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Sub function */
_SubDim(&a, &b, &c, n, beta);
......
......@@ -46,6 +46,8 @@ i.e., a is summed with b by broadcasting
*/
void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
......@@ -169,6 +171,8 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMPFlag();
n = MODX(n, a.order);
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论