Commit a503fe6a by xiaotong

add Lin method to XTensor

parent fbef372b
...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv ) ...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78); //_CrtSetBreakAlloc(78);
/* a tiny test */ /* a tiny test */
if(0) if(1)
SmallTest(); SmallTest();
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
...@@ -78,7 +78,7 @@ void SmallTest() ...@@ -78,7 +78,7 @@ void SmallTest()
XTensor c = a * b + a; XTensor c = a * b + a;
XTensor d = a + b + Linear(c, 0.5F); XTensor d = a + b + c.Lin(0.5F);
a.Dump(stderr, "a: "); a.Dump(stderr, "a: ");
b.Dump(stderr, "b: "); b.Dump(stderr, "b: ");
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include "core/movement/CopyValues.h" #include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h" #include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h" #include "core/arithmetic/Multiply.h"
#include "core/math/ScaleAndShift.h"
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -323,6 +324,16 @@ XTensor XTensor::operator* (const XTensor& tensor) ...@@ -323,6 +324,16 @@ XTensor XTensor::operator* (const XTensor& tensor)
} }
/* /*
linear transformation b = a * \scale + \shift
>> scale - the slope
>> shift - the intercept
*/
XTensor XTensor::Lin(DTYPE scale, DTYPE shift)
{
return Linear(*this, scale, shift);
}
/*
judge whether the two matrices are in the same type and size judge whether the two matrices are in the same type and size
>> a - input tensor >> a - input tensor
>> b - anther tensor to compare with >> b - anther tensor to compare with
......
...@@ -190,6 +190,9 @@ public: ...@@ -190,6 +190,9 @@ public:
/* overloading of the multiply-sign */ /* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor); XTensor operator* (const XTensor &tensor);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
/* judge whether the two matrices are in the same type and size */ /* judge whether the two matrices are in the same type and size */
static static
bool IsIdentical(XTensor * a, XTensor * b); bool IsIdentical(XTensor * a, XTensor * b);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论