Commit 75385ebe by xiaotong

add MODX to make the functions accept dimension index with a minus value

parent e87cf208
...@@ -157,7 +157,9 @@ extern bool useCUDA; ...@@ -157,7 +157,9 @@ extern bool useCUDA;
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}} #define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define XPRINT8(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8);FFLUSH(FILEH);}} #define XPRINT8(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true #define B2I(V) V == 0 ? false : true
#define MODX(a, b) int(b == 0 ? a : a - floor(double(a)/b) * b)
/* BLAS interfaces */ /* BLAS interfaces */
#ifdef DOUBELPRICSION #ifdef DOUBELPRICSION
......
...@@ -42,6 +42,8 @@ i.e., a is divided with b by broadcasting ...@@ -42,6 +42,8 @@ i.e., a is divided with b by broadcasting
*/ */
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha) void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{ {
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!"); CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!"); CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
...@@ -151,6 +153,8 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha) ...@@ -151,6 +153,8 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{ {
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Div function */ /* call _Div function */
_DivDim(&a, &b, &c, n, alpha); _DivDim(&a, &b, &c, n, alpha);
......
...@@ -42,8 +42,10 @@ i.e., a is multiplied with b by broadcasting ...@@ -42,8 +42,10 @@ i.e., a is multiplied with b by broadcasting
>> n - the dimension index >> n - the dimension index
>> alpha - the scaling factor >> alpha - the scaling factor
*/ */
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha) { void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!"); CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!"); CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
...@@ -151,6 +153,8 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n) ...@@ -151,6 +153,8 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Multiply function */ /* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, 0); _MultiplyDim(&a, &b, &c, n, 0);
......
...@@ -42,6 +42,8 @@ i.e., a is subtracted with b by broadcasting ...@@ -42,6 +42,8 @@ i.e., a is subtracted with b by broadcasting
*/ */
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta) void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{ {
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!"); CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!"); CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
...@@ -152,6 +154,8 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta) ...@@ -152,6 +154,8 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
n = MODX(n, a.order);
/* call _Sub function */ /* call _Sub function */
_SubDim(&a, &b, &c, n, beta); _SubDim(&a, &b, &c, n, beta);
......
...@@ -46,6 +46,8 @@ i.e., a is summed with b by broadcasting ...@@ -46,6 +46,8 @@ i.e., a is summed with b by broadcasting
*/ */
void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta) void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{ {
n = MODX(n, a->order);
CheckNTErrors(a && b && c, "Empty tensor input!"); CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in addition!"); CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
...@@ -169,6 +171,8 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta) ...@@ -169,6 +171,8 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{ {
XTensor c(&a); XTensor c(&a);
c.SetTMPFlag(); c.SetTMPFlag();
n = MODX(n, a.order);
/* call _SumDim function */ /* call _SumDim function */
_SumDim(&a, &b, &c, n, beta); _SumDim(&a, &b, &c, n, beta);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论