Commit fbef372b by xiaotong

reload of *

parent 27635638
...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv ) ...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78); //_CrtSetBreakAlloc(78);
/* a tiny test */ /* a tiny test */
if(1) if(0)
SmallTest(); SmallTest();
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
...@@ -76,9 +76,9 @@ void SmallTest() ...@@ -76,9 +76,9 @@ void SmallTest()
b = Sum(a, Multiply(a, a)); b = Sum(a, Multiply(a, a));
XTensor c = b; XTensor c = a * b + a;
XTensor d = b + a + Linear(c, 0.5F); XTensor d = a + b + Linear(c, 0.5F);
a.Dump(stderr, "a: "); a.Dump(stderr, "a: ");
b.Dump(stderr, "b: "); b.Dump(stderr, "b: ");
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include "core/shape/MergeBlockLists.h" #include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h" #include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h" #include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -315,6 +316,12 @@ XTensor XTensor::operator+ (const XTensor& tensor) ...@@ -315,6 +316,12 @@ XTensor XTensor::operator+ (const XTensor& tensor)
return Sum(*this, tensor); return Sum(*this, tensor);
} }
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const XTensor& tensor)
{
return Multiply(*this, tensor);
}
/* /*
judge whether the two matrices are in the same type and size judge whether the two matrices are in the same type and size
>> a - input tensor >> a - input tensor
......
...@@ -187,6 +187,9 @@ public: ...@@ -187,6 +187,9 @@ public:
/* overloading of the plus-sign */ /* overloading of the plus-sign */
XTensor operator+ (const XTensor &tensor); XTensor operator+ (const XTensor &tensor);
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* judge whether the two matrices are in the same type and size */ /* judge whether the two matrices are in the same type and size */
static static
bool IsIdentical(XTensor * a, XTensor * b); bool IsIdentical(XTensor * a, XTensor * b);
......
...@@ -53,18 +53,24 @@ void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift) ...@@ -53,18 +53,24 @@ void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift)
int num = a->unitNumNonZero; int num = a->unitNumNonZero;
char * d = (char*)a->data + sizeof(int); char * d = (char*)a->data + sizeof(int);
char * f = d + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int); char * f = d + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
char * db = (char*)b->data + sizeof(int);
char * fb = db + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
for(int i = 0; i < num; i++){ for(int i = 0; i < num; i++){
DTYPE * v = (DTYPE*)f; DTYPE * v = (DTYPE*)f;
*v = *v * scale + shift; DTYPE * vb = (DTYPE*)fb;
*vb = *v * scale + shift;
f += sizeof(int) + sizeof(DTYPE); f += sizeof(int) + sizeof(DTYPE);
fb += sizeof(int) + sizeof(DTYPE);
} }
} }
/* dense tensor */ /* dense tensor */
else{ else{
DTYPE * v = (DTYPE*)a->data; DTYPE * va = (DTYPE*)a->data;
for(int i = 0; i < a->unitNum; i++){ DTYPE * vb = (DTYPE*)b->data;
*v = *v * scale + shift; for(int i = 0; i < b->unitNum; i++){
v++; *vb = *va * scale + shift;
va++;
vb++;
} }
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论