Commit a503fe6a by xiaotong

add Lin method to XTensor

parent fbef372b
......@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(78);
/* a tiny test */
if(0)
if(1)
SmallTest();
if(argc > 1 && !strcmp(argv[1], "-test"))
......@@ -78,7 +78,7 @@ void SmallTest()
XTensor c = a * b + a;
XTensor d = a + b + Linear(c, 0.5F);
XTensor d = a + b + c.Lin(0.5F);
a.Dump(stderr, "a: ");
b.Dump(stderr, "b: ");
......
......@@ -42,6 +42,7 @@
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#include "core/math/ScaleAndShift.h"
#ifdef USE_CUDA
......@@ -323,6 +324,16 @@ XTensor XTensor::operator* (const XTensor& tensor)
}
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
>> shift - the intercept
*/
XTensor XTensor::Lin(DTYPE scale, DTYPE shift)
{
return Linear(*this, scale, shift);
}
/*
judge whether the two matrices are in the same type and size
>> a - input tensor
>> b - anther tensor to compare with
......
......@@ -190,6 +190,9 @@ public:
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
/* judge whether the two matrices are in the same type and size */
static
bool IsIdentical(XTensor * a, XTensor * b);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论