Commit fbef372b by xiaotong

reload of *

parent 27635638
......@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78);
/* a tiny test */
if(1)
if(0)
SmallTest();
if(argc > 1 && !strcmp(argv[1], "-test"))
......@@ -76,9 +76,9 @@ void SmallTest()
b = Sum(a, Multiply(a, a));
XTensor c = b;
XTensor c = a * b + a;
XTensor d = b + a + Linear(c, 0.5F);
XTensor d = a + b + Linear(c, 0.5F);
a.Dump(stderr, "a: ");
b.Dump(stderr, "b: ");
......
......@@ -41,6 +41,7 @@
#include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#ifdef USE_CUDA
......@@ -315,6 +316,12 @@ XTensor XTensor::operator+ (const XTensor& tensor)
return Sum(*this, tensor);
}
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const XTensor& tensor)
{
return Multiply(*this, tensor);
}
/*
judge whether the two matrices are in the same type and size
>> a - input tensor
......
......@@ -187,6 +187,9 @@ public:
/* overloading of the plus-sign */
XTensor operator+ (const XTensor &tensor);
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* judge whether the two matrices are in the same type and size */
static
bool IsIdentical(XTensor * a, XTensor * b);
......
......@@ -53,18 +53,24 @@ void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift)
int num = a->unitNumNonZero;
char * d = (char*)a->data + sizeof(int);
char * f = d + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
char * db = (char*)b->data + sizeof(int);
char * fb = db + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
for(int i = 0; i < num; i++){
DTYPE * v = (DTYPE*)f;
*v = *v * scale + shift;
DTYPE * vb = (DTYPE*)fb;
*vb = *v * scale + shift;
f += sizeof(int) + sizeof(DTYPE);
fb += sizeof(int) + sizeof(DTYPE);
}
}
/* dense tensor */
else{
DTYPE * v = (DTYPE*)a->data;
for(int i = 0; i < a->unitNum; i++){
*v = *v * scale + shift;
v++;
DTYPE * va = (DTYPE*)a->data;
DTYPE * vb = (DTYPE*)b->data;
for(int i = 0; i < b->unitNum; i++){
*vb = *va * scale + shift;
va++;
vb++;
}
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论