Commit bc49d32a by xuchen

Merge with liyinqiao brach and add the max/min function

parent cadda317
...@@ -71,6 +71,9 @@ void BackwardTest() ...@@ -71,6 +71,9 @@ void BackwardTest()
XTensor a; XTensor a;
XTensor b; XTensor b;
XTensor c; XTensor c;
a.enableGrad = true;
b.enableGrad = false;
c.enableGrad = false;
XTensor mean; XTensor mean;
XTensor origin; XTensor origin;
InitTensor2D(&a, 2, 3); InitTensor2D(&a, 2, 3);
...@@ -88,14 +91,15 @@ void BackwardTest() ...@@ -88,14 +91,15 @@ void BackwardTest()
b.Set1D(2.0F, 0); b.Set1D(2.0F, 0);
b.Set1D(1.0F, 1); b.Set1D(1.0F, 1);
c = DivDim(a, b, 0); DivDim(a, b, c, 0);
c.Dump(stderr, "c:"); c.Dump(stderr, "c:");
auto loss = CrossEntropy(c, a);
//XLink::ShowNetwork(stderr, &c); //XLink::ShowNetwork(stderr, &c);
net.Backward(c); net.Backward(loss);
net.Dump(stderr); a.grad->Dump(stderr);
} }
......
...@@ -26,183 +26,9 @@ ...@@ -26,183 +26,9 @@
* *
*/ */
#ifdef WIN32
#include <wtypes.h>
#endif
#include <stdlib.h>
#include <stdio.h>
#include "XBLAS.h"
#include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts{ namespace nts{
#ifdef WIN32
HINSTANCE hBLASDll;
#endif
/* single-precision floating matrix-matrix multiplication */
void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */
void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
double *, OPENBLAS_CONST BLASINT);
/* single-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int);
/* get the number of threads */
//int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
int (*XBLAS_GET_CORE_NUM)();
/* get the CPU corename */
//char * (*XBLAS_GET_CORE_NAME)();
/* get the parallelization type used by OpenBLAS */
//int (*XBLAS_GET_PARALLEL_TYPE)(void);
#if defined(USE_BLAS)
/* load some stuff for BLAS */
void LoadBLAS(const char * dllFileName)
{
#ifndef CUDA_BLAS
#ifdef _WIN32
#if defined(OPENBLAS)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
//(FARPROC&)XBLAS_GET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_get_num_threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "openblas_get_num_procs");
//(FARPROC&)XBLAS_GET_CORE_NAME = GetProcAddress(hBLASDll, "openblas_get_corename");
//(FARPROC&)XBLAS_GET_PARALLEL_TYPE = GetProcAddress(hBLASDll, "openblas_get_parallel");
delete[] fn;
#endif // defined(OPENBLAS)
#if defined(MKL)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
#endif // defined(OPENBLAS)
#if defined(MKL)
XBLAS_SET_THREAD_NUM = &mkl_set_num_threads;
XBLAS_GET_CORE_NUM = &mkl_get_max_num_threads;
#endif // defined(MKL)
#endif // _WIN32
XBLAS_SET_THREAD_NUM(1);
#endif // ndef(CUDA_BLAS)
}
/* unload the libs */
void UnloadBLAS()
{
#ifdef _WIN32
if(!FreeLibrary(hBLASDll)){
XPRINT(0, stderr, "[UnloadBLAS] Error! Cannot free the BLAS dll!\n");
exit(1);
}
#else
#endif
}
#else // undefined(USE_BLAS) || undefined(OPENBLAS)
void LoadBLAS(const char * dllFileName)
{
XPRINT(0, stderr, "[LoadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
void UnloadBLAS()
{
XPRINT(0, stderr, "[UnloadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
#endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -34,7 +34,6 @@ namespace nts{ ...@@ -34,7 +34,6 @@ namespace nts{
/* some of the code below is from OpenBLAS (https://github.com/xianyi/OpenBLAS) */ /* some of the code below is from OpenBLAS (https://github.com/xianyi/OpenBLAS) */
//#define OPENBLAS
#define OPENBLAS_CONST const #define OPENBLAS_CONST const
typedef int BLASINT; typedef int BLASINT;
...@@ -46,7 +45,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE; ...@@ -46,7 +45,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#if defined(USE_BLAS) #if defined(USE_BLAS)
#ifdef OPENBLAS
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM openblas_set_num_threads
#define XBLAS_GET_CORE_NUM openblas_get_num_procs
#endif
#ifdef MKL
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM MKL_Set_Num_Threads
#define XBLAS_GET_CORE_NUM MKL_Get_Max_Threads
#endif
/* /*
single/double-precision floating matrix-matrix multiplication (rank-3) single/double-precision floating matrix-matrix multiplication (rank-3)
- SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) - SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
...@@ -62,14 +80,14 @@ where A, B and C are matrices, ...@@ -62,14 +80,14 @@ where A, B and C are matrices,
LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program, LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program,
and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program. and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program.
*/ */
extern "C" void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_SGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */ /* double-precision floating matrix-matrix multiplication */
extern "C" void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_DGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
...@@ -88,24 +106,33 @@ where X and Y are vectors with m and n elements respectively, ...@@ -88,24 +106,33 @@ where X and Y are vectors with m and n elements respectively,
E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A. E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A.
*/ */
extern "C" void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha, extern "C" void XBLAS_SGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */ /* double-precision floating vector-vector multiplication (rank-1) */
extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha, extern "C" void XBLAS_DGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT); double *, OPENBLAS_CONST BLASINT);
/*
some description
*/
extern "C" void XBLAS_SAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* double-precision floating sumMe function */
extern "C" void XBLAS_DAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */ /* set the number of threads */
extern "C" void (*XBLAS_SET_THREAD_NUM)(int); extern "C" void XBLAS_SET_THREAD_NUM(int);
/* get the number of threads */ /* get the number of threads */
//extern "C" int (*XBLAS_GET_THREAD_NUM)(); //extern "C" int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/ /* get the number of physical processors (cores).*/
extern "C" int (*XBLAS_GET_CORE_NUM)(); extern "C" int XBLAS_GET_CORE_NUM();
/* get the CPU corename */ /* get the CPU corename */
//extern "C" char * (*XBLAS_GET_CORE_NAME)(); //extern "C" char * (*XBLAS_GET_CORE_NAME)();
...@@ -113,58 +140,6 @@ extern "C" int (*XBLAS_GET_CORE_NUM)(); ...@@ -113,58 +140,6 @@ extern "C" int (*XBLAS_GET_CORE_NUM)();
/* get the parallelization type used by OpenBLAS */ /* get the parallelization type used by OpenBLAS */
//extern "C" int (*XBLAS_GET_PARALLEL_TYPE)(void); //extern "C" int (*XBLAS_GET_PARALLEL_TYPE)(void);
/* linux systems */
#ifndef _WIN32
/* cblas functions that are imported from the lib. See cblas.h in OpenBlas for more information */
extern "C" void cblas_sgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST BLASINT K, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *A, OPENBLAS_CONST BLASINT lda,
OPENBLAS_CONST float *B, OPENBLAS_CONST BLASINT ldb,
OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST BLASINT ldc);
extern "C" void cblas_dgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST BLASINT K, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *A, OPENBLAS_CONST BLASINT lda,
OPENBLAS_CONST double *B, OPENBLAS_CONST BLASINT ldb,
OPENBLAS_CONST double beta, double *C, OPENBLAS_CONST BLASINT ldc);
extern "C" void cblas_sger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT incY,
float *A, OPENBLAS_CONST BLASINT lda);
extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda);
#if defined(OPENBLAS)
/* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads);
extern "C" void goto_set_num_threads(int num_threads);
//extern "C" int openblas_get_num_threads(void);
extern "C" int openblas_get_num_procs(void);
//extern "C" char* openblas_get_config(void);
//extern "C" char* openblas_get_corename(void);
//extern "C" int openblas_get_parallel(void);
#endif
#endif
#if defined(MKL)
/* better control of multi-threading */
//_Mkl_Api(void,MKL_Set_Num_Threads,(int nth))
//_Mkl_Api(int,MKL_Get_Max_Threads,(void))
extern "C" void MKL_Set_Num_Threads(int num_threads);
extern "C" int MKL_Get_Max_Threads();
#define mkl_set_num_threads MKL_Set_Num_Threads
#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void mkl_set_num_threads(int num_threads);
//extern "C" void omp_set_num_threads(int num_threads);
//extern "C" int mkl_get_max_num_threads();
#endif
#if defined(CUDA_BLAS) #if defined(CUDA_BLAS)
...@@ -186,24 +161,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int ...@@ -186,24 +161,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int
#endif #endif
#endif
#ifdef _WIN32
#include "windows.h"
extern HINSTANCE hBLASDll;
#else
#endif #endif
/* load some stuff for BLAS */
extern void LoadBLAS(const char * dllFileName);
/* unload the libs */
extern void UnloadBLAS();
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
#endif #endif
...@@ -160,8 +160,10 @@ extern bool useCUDA; ...@@ -160,8 +160,10 @@ extern bool useCUDA;
/* BLAS interfaces */ /* BLAS interfaces */
#ifdef DOUBELPRICSION #ifdef DOUBELPRICSION
#define GEMM XBLAS_DGEMM #define GEMM XBLAS_DGEMM
#define AXPY XBLAS_DAXPY
#else #else
#define GEMM XBLAS_SGEMM #define GEMM XBLAS_SGEMM
#define AXPY XBLAS_SAXPY
#endif #endif
extern void InitGlobalAll(); extern void InitGlobalAll();
......
...@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id ...@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
if(h == NULL) if(h == NULL)
return; return;
if (!t1->enableGrad)
return;
TensorList list(2); TensorList list(2);
list.Add((XTensor*)t1); list.Add((XTensor*)t1);
list.Add((XTensor*)t2); list.Add((XTensor*)t2);
...@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3, ...@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,
if (h == NULL) if (h == NULL)
return; return;
if (!t1->enableGrad || !t2->enableGrad)
return;
TensorList list(3); TensorList list(3);
list.Add((XTensor*)t1); list.Add((XTensor*)t1);
list.Add((XTensor*)t2); list.Add((XTensor*)t2);
...@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors ...@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors
*/ */
void XLink::MakeLink(XTensor * t, TensorList * list, int id) void XLink::MakeLink(XTensor * t, TensorList * list, int id)
{ {
if (!t->enableGrad)
return;
/* forward */ /* forward */
for(int i = 0; i < list->count; i++){ for(int i = 0; i < list->count; i++){
XTensor * h = (XTensor*)list->GetItem(i); XTensor * h = (XTensor*)list->GetItem(i);
......
...@@ -23,15 +23,11 @@ ...@@ -23,15 +23,11 @@
* *
*/ */
#include "XList.h" #include "time.h"
#include "XMem.h" #include "XMem.h"
#include "XList.h"
#include "XGlobal.h" #include "XGlobal.h"
#include <ctime>
#include <utility>
#include <algorithm>
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts { namespace nts {
...@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem) ...@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
template <typename T> template <typename T>
TensorListBase<T>::~TensorListBase() TensorListBase<T>::~TensorListBase()
{ {
delete[] items; if(items && mem)
delete[] items;
} }
...@@ -90,7 +87,7 @@ template <typename T> ...@@ -90,7 +87,7 @@ template <typename T>
void TensorListBase<T>::Add(T&& item) void TensorListBase<T>::Add(T&& item)
{ {
if (count == maxNum) { if (count == maxNum) {
T* newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T[maxNum * 2 + 1]; newItems = new T[maxNum * 2 + 1];
...@@ -101,7 +98,13 @@ void TensorListBase<T>::Add(T&& item) ...@@ -101,7 +98,13 @@ void TensorListBase<T>::Add(T&& item)
maxNum = maxNum * 2 + 1; maxNum = maxNum * 2 + 1;
} }
items[count++] = item; items[count++] = item;
}
/* return number of elements */
template<typename T>
size_t TensorListBase<T>::Size()
{
return count;
} }
/* /*
...@@ -111,18 +114,18 @@ add an item into the list ...@@ -111,18 +114,18 @@ add an item into the list
template <typename T> template <typename T>
void TensorListBase<T>::Add(const T& item) void TensorListBase<T>::Add(const T& item)
{ {
if (count == maxNum) { if (count == maxNum) {
T* newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T[maxNum * 2 + 1]; newItems = new T[maxNum * 2 + 1];
else else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1)); newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum); memcpy(newItems, items, sizeof(T) * maxNum);
items = newItems; items = newItems;
maxNum = maxNum * 2 + 1; maxNum = maxNum * 2 + 1;
} }
items[count++] = item; items[count++] = item;
} }
/* /*
...@@ -131,7 +134,7 @@ add a number of items into the list ...@@ -131,7 +134,7 @@ add a number of items into the list
>> inputItemCount - number of input items >> inputItemCount - number of input items
*/ */
template <typename T> template <typename T>
void TensorListBase<T>::Add(T* inputItems, int inputItemCount) void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
{ {
if (count + inputItemCount >= maxNum) { if (count + inputItemCount >= maxNum) {
int newMaxNum = (count + inputItemCount) * 2 + 1; int newMaxNum = (count + inputItemCount) * 2 + 1;
...@@ -186,31 +189,31 @@ void TensorListBase<T>::Insert(int pos, const T& item) ...@@ -186,31 +189,31 @@ void TensorListBase<T>::Insert(int pos, const T& item)
template<typename T> template<typename T>
void TensorListBase<T>::Insert(int pos, T&& item) void TensorListBase<T>::Insert(int pos, T&& item)
{ {
if (count == maxNum) { if (count == maxNum) {
T* newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T[maxNum * 2 + 1]; newItems = new T[maxNum * 2 + 1];
else else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1)); newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum); memcpy(newItems, items, sizeof(T) * maxNum);
items = newItems; items = newItems;
maxNum = maxNum * 2 + 1; maxNum = maxNum * 2 + 1;
} }
for (int i = count - 1; i >= pos; i--) for (int i = count - 1; i >= pos; i--)
items[i + 1] = items[i]; items[i + 1] = items[i];
items[pos] = item; items[pos] = item;
count++; count++;
} }
/* get the item at position i */ /* get the item at position i */
template <typename T> template <typename T>
T& TensorListBase<T>::GetItem(int i) const T& TensorListBase<T>::GetItem(int i) const
{ {
CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!"); CheckNTErrors(i >= -count && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannt index the item in an empty list!"); CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
if (i == -1) if (i < 0)
return items[count - 1]; return items[count + i];
else else
return items[i]; return items[i];
} }
...@@ -226,8 +229,8 @@ inline void TensorListBase<T>::SetItem(int i, const T& item) ...@@ -226,8 +229,8 @@ inline void TensorListBase<T>::SetItem(int i, const T& item)
template<typename T> template<typename T>
inline void TensorListBase<T>::SetItem(int i, T&& item) inline void TensorListBase<T>::SetItem(int i, T&& item)
{ {
if (i >= 0 && i < count) if (i >= 0 && i < count)
items[i] = std::move(item); items[i] = item;
} }
/* /*
...@@ -246,11 +249,31 @@ inline int TensorListBase<T>::FindFirst(const T& item) ...@@ -246,11 +249,31 @@ inline int TensorListBase<T>::FindFirst(const T& item)
return -1; return -1;
} }
template <>
inline int TensorListBase<Example>::FindFirst(const Example& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
template <>
inline int TensorListBase<Result>::FindFirst(const Result& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
/* clear the data array */ /* clear the data array */
template <typename T> template <typename T>
void TensorListBase<T>::Clear() void TensorListBase<T>::Clear()
{ {
count = 0; count = 0;
} }
/* /*
...@@ -295,6 +318,17 @@ void TensorListBase<T>::Remove(int i) ...@@ -295,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
count--; count--;
} }
template<typename T>
void TensorListBase<T>::Reserve(int n)
{
if (items) {
/* reserve failed */
return;
}
items = new T[n];
}
/* /*
copy the list copy the list
>> myMem - memory pool used for allocating the data in the new list >> myMem - memory pool used for allocating the data in the new list
...@@ -349,6 +383,8 @@ template struct TensorListBase<long>; ...@@ -349,6 +383,8 @@ template struct TensorListBase<long>;
template struct TensorListBase<float>; template struct TensorListBase<float>;
template struct TensorListBase<short>; template struct TensorListBase<short>;
template struct TensorListBase<XTensor*>; template struct TensorListBase<XTensor*>;
template struct TensorListBase<Result>;
template struct TensorListBase<Example>;
template struct TensorListBase<void*>; template struct TensorListBase<void*>;
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts { namespace nts {
/* the TensorListBase class */ /* the TensorListBase class */
template <typename T> template <typename T>
struct TensorListBase { struct TensorListBase {
...@@ -66,68 +66,85 @@ public: ...@@ -66,68 +66,85 @@ public:
/* add an item into the list */ /* add an item into the list */
void Add(T&& item); void Add(T&& item);
/* add an item into the list */ /* return number of elements */
void Add(const T& item); size_t Size();
/* add a number of items into the list */ /* add an item into the list */
void Add(T* inputItems, int inputItemCount); void Add(const T& item);
/* add a number of items into the list */
void Add(const T* inputItems, int inputItemCount);
/* append a list to the current list */ /* append a list to the current list */
void AddList(TensorListBase* l); void AddList(TensorListBase* l);
/* insert an item to the given position of the list */ /* insert an item to the given position of the list */
void Insert(int pos, const T& item); void Insert(int pos, const T& item);
/* insert an item to the given position of the list */ /* insert an item to the given position of the list */
void Insert(int pos, T&& item); void Insert(int pos, T&& item);
/* get the item at position i */ /* get the item at position i */
T& GetItem(int i) const; T& GetItem(int i) const;
/* set the item at position i */ /* set the item at position i */
void SetItem(int i, const T& item); void SetItem(int i, const T& item);
/* set the item at position i */ /* set the item at position i */
void SetItem(int i, T&& item); void SetItem(int i, T&& item);
/* find the position of the first matched item */ /* find the position of the first matched item */
int FindFirst(const T& item); int FindFirst(const T& item);
/* clear the data array */ /* clear the data array */
void Clear(); void Clear();
/* sort the list */ /* sort the list */
void Sort(int itemSize); void Sort(int itemSize);
/* reverse the list */ /* reverse the list */
void Reverse(); void Reverse();
/* remove the item at position i */ /* remove the item at position i */
void Remove(int i); void Remove(int i);
/* copy the list */ /* reserve space for data entry */
void Reserve(int n);
/* copy the list */
TensorListBase* Copy(XMem* myMem); TensorListBase* Copy(XMem* myMem);
/* shuffle the list */ /* shuffle the list */
void Shuffle(int nround = 10, int beg = -1, int len = 0); void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */ /* short */
T& operator[] (int i) { T& operator[] (int i) { return GetItem(i); };
return GetItem(i);
};
T& Get(int i) { return GetItem(i); }; T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); }; void Set(int i, T item) { SetItem(i, item); };
}; };
struct XTensor; struct XTensor;
typedef TensorListBase<void*> XList;
typedef TensorListBase<int> IntList; typedef TensorListBase<int> IntList;
typedef TensorListBase<char> CharList; typedef TensorListBase<char> CharList;
typedef TensorListBase<char*> StrList; typedef TensorListBase<char*> StrList;
typedef TensorListBase<long> LongList; typedef TensorListBase<long> LongList;
typedef TensorListBase<float> FloatList; typedef TensorListBase<float> FloatList;
typedef TensorListBase<short> ShortList; typedef TensorListBase<short> ShortList;
typedef TensorListBase<void*> XList;
struct Example {
int id;
IntList data;
};
struct Result {
int id;
IntList data;
};
typedef TensorListBase<Result> ResultList;
typedef TensorListBase<Example> ExampleList;
typedef TensorListBase<XTensor*> TensorList; typedef TensorListBase<XTensor*> TensorList;
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
......
...@@ -51,7 +51,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -51,7 +51,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MASK MATH_DIVDIM + 1 #define MATH_MASK MATH_DIVDIM + 1
#define MATH_MATRIXMUL MATH_MASK + 1 #define MATH_MATRIXMUL MATH_MASK + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1 #define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1 #define MATH_MAX MATH_MATRIXMULBATCHED + 1
#define MATH_MIN MATH_MAX + 1
#define MATH_MULTIPLY MATH_MIN + 1
#define MATH_MULTIPLYDIM MATH_MULTIPLY + 1 #define MATH_MULTIPLYDIM MATH_MULTIPLY + 1
#define MATH_MULTIPLYBROADCAST MATH_MULTIPLYDIM + 1 #define MATH_MULTIPLYBROADCAST MATH_MULTIPLYDIM + 1
#define MATH_NEGATE MATH_MULTIPLYBROADCAST + 1 #define MATH_NEGATE MATH_MULTIPLYBROADCAST + 1
......
...@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim) ...@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
_Div(&a, &b, &c, alpha, leadingDim); _Div(&a, &b, &c, alpha, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, alpha); XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim); XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
} }
else if(n >= 0 && n < a.order){ else if(n >= 0 && n < a.order){
/* call _DivDim function */ /* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha); _DivDim(&a, &b, &c, n, alpha);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
} }
else{ else{
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
...@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin ...@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _Div function */ /* call _Div function */
_Div(&a, &b, &c, 0, leadingDim); _Div(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV); XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHead(&c, alpha);
...@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin ...@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _DivDim function */ /* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha); _DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM); XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
......
...@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha) ...@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
_DivDim(&a, &b, &c, n, alpha); _DivDim(&a, &b, &c, n, alpha);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
return c; return c;
} }
...@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha) ...@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
/* call _Div function */ /* call _Div function */
_DivDim(&a, &b, &c, n, alpha); _DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad == true) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM); XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
......
...@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha) ...@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
_Mask(&a, &mask, &c, alpha); _Mask(&a, &mask, &c, alpha);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &mask, &c, MATH_MASK); if (a.enableGrad) {
XLink::AddParamToHead(&c, alpha); XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
}
return c; return c;
} }
...@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha) ...@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
/* call _Mask function */ /* call _Mask function */
_Mask(&a, &mask, &c, alpha); _Mask(&a, &mask, &c, alpha);
if (c.enableGrad) { if (a.enableGrad) {
XLink::MakeLink(&a, &mask, &c, MATH_MASK); XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHead(&c, alpha);
} }
......
...@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, ...@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner); _MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadTrans(&c, transposedA); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedB); XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, ...@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/* call _MatrixMul function */ /* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner); _MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA); XLink::AddParamToHeadTrans(&c, transposedA);
...@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b, ...@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner); _MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c, ...@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
/* call _MatrixMul function */ /* call _MatrixMul function */
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner); _MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
......
...@@ -54,15 +54,15 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -54,15 +54,15 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2), CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
"Input tensors must have a order = 2!"); "Input tensors must have a order = 2!");
int an = a->dimSize[0], am = a->dimSize[1]; int an = a->dimSize[0], am = a->dimSize[1];
int bn = b->dimSize[0], bm = b->dimSize[1]; int bn = b->dimSize[0], bm = b->dimSize[1];
int cn = c->dimSize[0], cm = c->dimSize[1]; int cn = c->dimSize[0], cm = c->dimSize[1];
int am2 = transposedA == X_TRANS ? an : am; int am2 = transposedA == X_TRANS ? an : am;
int an2 = transposedA == X_TRANS ? am : an; int an2 = transposedA == X_TRANS ? am : an;
int bm2 = transposedB == X_TRANS ? bn : bm; int bm2 = transposedB == X_TRANS ? bn : bm;
int bn2 = transposedB == X_TRANS ? bm : bn; int bn2 = transposedB == X_TRANS ? bm : bn;
int cm2 = cm; int cm2 = cm;
int cn2 = cn; int cn2 = cn;
CheckNTErrors((am2 == bn2 && an2 == cn2 && bm2 == cm2), CheckNTErrors((am2 == bn2 && an2 == cn2 && bm2 == cm2),
"Unmatched tensors in multiplication!"); "Unmatched tensors in multiplication!");
...@@ -82,10 +82,11 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -82,10 +82,11 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
b->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE) c->dataType == DEFAULT_DTYPE)
{ {
if (useBLAS) #if defined(USE_BLAS)
_MatrixMULCPU(a, transposedA, b, transposedB, c, alpha, beta); _MatrixMULCPU(a, transposedA, b, transposedB, c, alpha, beta);
else #else
_MatrixMul2DParallel(a, transposedA, b, transposedB, c, alpha, beta, parallelRunner); _MatrixMul2DParallel(a, transposedA, b, transposedB, c, alpha, beta, parallelRunner);
#endif
} }
else { else {
// TODO!! // TODO!!
......
...@@ -199,10 +199,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -199,10 +199,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
bi->data = (char*)b->data + i * bRealBlockSize; bi->data = (char*)b->data + i * bRealBlockSize;
ci->data = (char*)c->data + i * cRealBlockSize; ci->data = (char*)c->data + i * cRealBlockSize;
#ifdef USE_BLAS #ifdef USE_BLAS
if (useBLAS) _MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
_MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else #else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta); _MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif #endif
...@@ -262,10 +259,7 @@ void _MatrixMulBatchedCPU(const TensorList * a, MATRIX_TRANS_TYPE transposedA, ...@@ -262,10 +259,7 @@ void _MatrixMulBatchedCPU(const TensorList * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((bi->order == 2), "2d tensor (i.e., matrix) is required!"); CheckNTErrors((bi->order == 2), "2d tensor (i.e., matrix) is required!");
CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!"); CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!");
#ifdef USE_BLAS #ifdef USE_BLAS
if (useBLAS)
_MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta); _MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else #else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta); _MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif #endif
...@@ -320,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const ...@@ -320,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
_MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner); _MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadTrans(&c, transposedA); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, transposedB); XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -376,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b, ...@@ -376,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
_MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner); _MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, ...@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
} }
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT); if (w.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
//XLink::AddParamToHead(&c, beta); XLink::AddParamToHeadTrans(&c, X_NOTRANS);
}
/* destroy variables */
delete[] dimSize;
DelTensorBuf(tmp);
return c;
}
/*
operation c = x * w + b MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
const XTensor& w, MATRIX_TRANS_TYPE transposedB,
const XTensor& b, DTYPE alpha, XPRunner* parallelRunner)
{
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
int xn = transposedA == X_TRANS ? x.dimSizeRDI[0] : x.dimSizeRDI[1];
int xm = transposedA == X_TRANS ? x.dimSizeRDI[1] : x.dimSizeRDI[0];
int wn = transposedB == X_TRANS ? w.dimSizeRDI[0] : w.dimSizeRDI[1];
int wm = transposedB == X_TRANS ? w.dimSizeRDI[1] : w.dimSizeRDI[0];
int order = x.order + w.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < x.order; i++)
dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
for (int i = 2; i < w.order; i++)
dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
dimSize[sub++] = xn;
dimSize[sub++] = wm;
float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */
_MatrixMul(&x, transposedA, &w, transposedB, tmp, alpha, 0, parallelRunner);
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b, XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL); DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedA,
const XTensor &w, MATRIX_TRANS_TYPE transposedB,
const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim ...@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
_Multiply(&a, &b, &c, 0, leadingDim); _Multiply(&a, &b, &c, 0, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, alpha); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHeadInt(&c, leadingDim); XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
} }
else if(n >= 0 && n < a.order){ else if(n >= 0 && n < a.order){
/* call _MultiplyDim function */ /* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha); _MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
} }
else{ else{
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
...@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l ...@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _Multiply function */ /* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim); _Multiply(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha); XLink::AddParamToHead(&c, alpha);
...@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l ...@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _MultiplyDim function */ /* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha); _MultiplyDim(&a, &b, &c, n, alpha);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
......
...@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n) ...@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
_MultiplyDim(&a, &b, &c, n, 0); _MultiplyDim(&a, &b, &c, n, 0);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHead(&c, 0); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, 0);
}
return c; return c;
} }
...@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n) ...@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n)
/* call _Multiply function */ /* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, 0); _MultiplyDim(&a, &b, &c, n, 0);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
...@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b) ...@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
_MultiplyBroadcast(&a, &b, &c, 0); _MultiplyBroadcast(&a, &b, &c, 0);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, 0); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0);
}
return c; return c;
} }
...@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c) ...@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c)
/* call _SumBroadcast function */ /* call _SumBroadcast function */
_MultiplyBroadcast(&a, &b, &c, 0); _MultiplyBroadcast(&a, &b, &c, 0);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0); XLink::AddParamToHead(&c, 0);
......
...@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
_Sub(&a, &b, &c, beta); _Sub(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, beta); XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
}
} }
else if(n >= 0 && n < a.order){ else if(n >= 0 && n < a.order){
/* call _SubDim function */ /* call _SubDim function */
_SubDim(&a, &b, &c, n, beta); _SubDim(&a, &b, &c, n, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
} }
else{ else{
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
...@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _Sub function */ /* call _Sub function */
_Sub(&a, &b, &c, beta); _Sub(&a, &b, &c, beta);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB); XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
...@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SubDim function */ /* call _SubDim function */
_SubDim(&a, &b, &c, n, beta); _SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
......
...@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta) ...@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SubDim(&a, &b, &c, n, beta); _SubDim(&a, &b, &c, n, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
return c; return c;
} }
...@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta) ...@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _Sub function */ /* call _Sub function */
_SubDim(&a, &b, &c, n, beta); _SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM); XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h" #include "../movement/CopyValues.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
...@@ -84,29 +85,57 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) ...@@ -84,29 +85,57 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
DTYPE * ap = (DTYPE*)a->data; DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data; DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data; DTYPE * cp = (DTYPE*)c->data;
/* when c != a, OpenBLAS needs to copy a to c first. This operation
/* unrolling */ slow down the speed, so just use OpenBLAS when c == a */
int num = a->unitNum; #if defined(USE_BLAS)
if (num % 4 == 0) { if( c == a){
for (int i = 0; i < num; i += 4) { AXPY(a->unitNum,beta,bp,1,cp,1);
cp[i] = ap[i] + bp[i] * beta; } else{
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta; int num = a->unitNum;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta; if (num % 4 == 0) {
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta; for (int i = 0; i < num; i += 4) {
} cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] + bp[i] * beta;
}
}
} }
else if (num % 2 == 0) { #else
for (int i = 0; i < num; i += 2) { /* unrolling */
cp[i] = ap[i] + bp[i] * beta; int num = a->unitNum;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta; if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
}
} }
} else if (num % 2 == 0) {
else { for (int i = 0; i < num; i += 2) {
for (int i = 0; i < num; i++) { cp[i] = ap[i] + bp[i] * beta;
cp[i] = ap[i] + bp[i] * beta; cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
}
} }
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] + bp[i] * beta;
}
}
#endif
} }
}
else { else {
// TODO!! // TODO!!
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
...@@ -195,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -195,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
_Sum(&a, &b, &c, beta); _Sum(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, beta); XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
} }
else if(n >= 0 && n < a.order){ else if(n >= 0 && n < a.order){
/* call _SumDim function */ /* call _SumDim function */
_SumDim(&a, &b, &c, n, beta); _SumDim(&a, &b, &c, n, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
} }
else{ else{
ShowNTErrors("Something is wrong!"); ShowNTErrors("Something is wrong!");
...@@ -232,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -232,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
if (n == -1) { if (n == -1) {
/* call _Sum function */ /* call _Sum function */
_Sum(&a, &b, &c, beta); _Sum(&a, &b, &c, beta);
if (c.enableGrad) { /* tensor connections */
/* tensor connections */ if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUM); XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
} }
...@@ -242,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -242,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
else if (n >= 0 && n < a.order) { else if (n >= 0 && n < a.order) {
/* call _SumDim function */ /* call _SumDim function */
_SumDim(&a, &b, &c, n, beta); _SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) { /* tensor connections */
/* tensor connections */ if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM); XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
......
...@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta) ...@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SumDim(&a, &b, &c, n, beta); _SumDim(&a, &b, &c, n, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHeadInt(&c, n); XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
return c; return c;
} }
...@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta) ...@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _SumDim function */ /* call _SumDim function */
_SumDim(&a, &b, &c, n, beta); _SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM); XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n); XLink::AddParamToHeadInt(&c, n);
...@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
_SumBroadcast(&a, &b, &c, beta); _SumBroadcast(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST); if (a.enableGrad && b.enableGrad) {
XLink::AddParamToHead(&c, beta); XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta);
}
return c; return c;
} }
...@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta) ...@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SumBroadcast function */ /* call _SumBroadcast function */
_SumBroadcast(&a, &b, &c, beta); _SumBroadcast(&a, &b, &c, beta);
if (c.enableGrad) { if (a.enableGrad && b.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST); XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
......
...@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType) ...@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
_ConvertDataType(&input, &output); _ConvertDataType(&input, &output);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE); if(input.enableGrad)
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
return output; return output;
} }
...@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE d ...@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE d
_ConvertDataType(&input, &output); _ConvertDataType(&input, &output);
/* tensor connection */ /* tensor connection */
if (output.enableGrad) if (input.enableGrad)
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE); XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
} }
......
...@@ -32,65 +32,43 @@ convert onehot tensor to index tensor ...@@ -32,65 +32,43 @@ convert onehot tensor to index tensor
>> index - index tensor, which value is an integer num >> index - index tensor, which value is an integer num
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
*/ */
void _OnehotToIndex(XTensor * onehot, XTensor * index, int dim) void _OnehotToIndex(const XTensor * onehot, XTensor * index, int size)
{ {
dim = (dim < 0 ? onehot->GetDim(-1) : dim); CheckNTErrors(onehot->GetDim(-1) == size, "Illegal tensor dimension!");
CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!"); CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!");
CheckNTErrors(dim < onehot->order, "Illegal speficied dimension!")
CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!") CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!") CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!")
for (int i = 0; i < index->order; i++) { for (int i = 0; i < index->order; i++)
if (i < dim) { CheckNTErrors(index->GetDim(i) == onehot->GetDim(i), "Illegal tensor order!");
CheckNTErrors(index->GetDim(i) == onehot->GetDim(i), "Illegal tensor order!");
}
else {
CheckNTErrors(index->GetDim(i) == onehot->GetDim(i + 1), "Illegal tensor order!");
}
}
#ifdef USE_CUDA #ifdef USE_CUDA
if(onehot->devID >= 0 && index->devID >= 0) { if(onehot->devID >= 0 && index->devID >= 0) {
_CudaOnehotToIndex(onehot, index, dim); _CudaOnehotToIndex(onehot, index, size);
return; return;
} }
#endif #endif
int blockNum = 1; int blockNum = index->unitNum;
int blockSize = 1; int stride = size;
int dimSize = 1;
int stride = 1;
for (int i = 0; i < dim; i++)
blockNum *= onehot->GetDim(i);
blockSize = onehot->unitNum / blockNum;
dimSize = onehot->GetDim(dim);
for (int i = dim + 1; i < onehot->order; i++)
stride *= onehot->GetDim(i);
int * onehotData = (int *)onehot->data; int * onehotData = (int *)onehot->data;
int * indexData = (int *)index->data; int * indexData = (int *)index->data;
for (int i = 0; i < blockNum; i++) { for (int i = 0; i < blockNum; i++) {
int * od = onehotData + i * stride;
int record = -1;
for (int j = 0; j < stride; j++) { for (int j = 0; j < stride; j++) {
int * od = onehotData + i * blockSize + j; if (od[j] != 0) {
int * index = indexData + i * stride + j; if (record == -1)
record = j;
int record = -1; else
for (int j = 0; j < dimSize; j++) { ShowNTErrors("The value of onehot tensor is illegal!");
if (od[j*stride] != 0) {
if (record == -1)
record = j;
else
ShowNTErrors("The value of onehot tensor is illegal!");
}
} }
*index = record;
} }
indexData[i] = record;
} }
} }
/* /*
...@@ -101,7 +79,7 @@ make a new tensor to keep the result and return it ...@@ -101,7 +79,7 @@ make a new tensor to keep the result and return it
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
<< return - the index tensor << return - the index tensor
*/ */
XTensor OnehotToIndex(XTensor & onehot, int size) XTensor OnehotToIndex(const XTensor & onehot, int size)
{ {
CheckNTErrors(onehot.GetDim(-1) == size, "Illegal tensor dimension!"); CheckNTErrors(onehot.GetDim(-1) == size, "Illegal tensor dimension!");
CheckNTErrors(onehot.dataType == X_INT, "The onehot tensor must be in X_INT!") CheckNTErrors(onehot.dataType == X_INT, "The onehot tensor must be in X_INT!")
...@@ -123,10 +101,9 @@ convert index tensor to onehot tensor ...@@ -123,10 +101,9 @@ convert index tensor to onehot tensor
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
*/ */
void _IndexToOnehot(const XTensor * index, XTensor * onehot, void _IndexToOnehot(const XTensor * index, XTensor * onehot,
float labelSmoothingP) int size, float labelSmoothingP)
{ {
int size = onehot->GetDim(-1); CheckNTErrors(onehot->GetDim(-1) == size, "Illegal tensor dimension!");
CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!"); CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!");
//CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!") //CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!") CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!")
...@@ -171,7 +148,7 @@ make a new tensor to keep the result and return it ...@@ -171,7 +148,7 @@ make a new tensor to keep the result and return it
>> confidence - labelsmoothing >> confidence - labelsmoothing
<< return - the onehot tensor << return - the onehot tensor
*/ */
XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP) XTensor IndexToOnehot(const XTensor & index, int size, float labelSmoothingP)
{ {
CheckNTErrors(index.dataType == X_INT, "The onehot tensor must be in X_INT!") CheckNTErrors(index.dataType == X_INT, "The onehot tensor must be in X_INT!")
...@@ -184,11 +161,11 @@ XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP) ...@@ -184,11 +161,11 @@ XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP)
dim[order] = size; dim[order] = size;
InitTensor(&onehot, index.order + 1, dim, X_FLOAT, 1.0F, index.devID, index.mem); InitTensor(&onehot, index.order + 1, dim, X_FLOAT, 1.0F, index.devID, index.mem);
_IndexToOnehot(&index, &onehot, labelSmoothingP); _IndexToOnehot(&index, &onehot, size, labelSmoothingP);
delete[] dim; delete[] dim;
return onehot; return onehot;
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -61,7 +61,7 @@ convert onehot tensor to index tensor (cuda version) ...@@ -61,7 +61,7 @@ convert onehot tensor to index tensor (cuda version)
>> index - index tensor, which value is an integer num >> index - index tensor, which value is an integer num
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
*/ */
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size) void _CudaOnehotToIndex(const XTensor * onehot, XTensor * index, int size)
{ {
int devID = onehot->devID; int devID = onehot->devID;
...@@ -153,4 +153,4 @@ void _CudaIndexToOnehot(const XTensor * index, XTensor * onehot, ...@@ -153,4 +153,4 @@ void _CudaIndexToOnehot(const XTensor * index, XTensor * onehot,
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -27,10 +27,11 @@ ...@@ -27,10 +27,11 @@
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
/* convert onehot tensor to index tensor (cuda version) */ /* convert onehot tensor to index tensor (cuda version) */
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size); void _CudaOnehotToIndex(const XTensor * onehot, XTensor * index, int size);
/* convert index tensor to onehot tensor (cuda version) */ /* convert index tensor to onehot tensor (cuda version) */
void _CudaIndexToOnehot(const XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence); void _CudaIndexToOnehot(const XTensor * index, XTensor * onehot,
int size, float confidence, float lowconfidence);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -27,19 +27,18 @@ ...@@ -27,19 +27,18 @@
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
/* convert onehot tensor to index tensor */ /* convert onehot tensor to index tensor */
void _OnehotToIndex(XTensor * onehot, XTensor * index, int dim); void _OnehotToIndex(const XTensor * onehot, XTensor * index, int size);
/* convert onehot tensor to index tensor (return an XTensor structure) /* convert onehot tensor to index tensor (return an XTensor structure)
make a new tensor to keep the result and return it */ make a new tensor to keep the result and return it */
XTensor OnehotToIndex(XTensor & onehot, int size); XTensor OnehotToIndex(const XTensor & onehot, int num);
/* convert index tensor to onehot tensor */ /* convert index tensor to onehot tensor */
void _IndexToOnehot(const XTensor * index, XTensor * onehot, void _IndexToOnehot(const XTensor * index, XTensor * onehot, int size, float labelSmoothingP);
float labelSmoothingP = 0.0F);
/* convert index tensor to onehot tensor (return an XTensor structure) /* convert index tensor to onehot tensor (return an XTensor structure)
make a new tensor to keep the result and return it */ make a new tensor to keep the result and return it */
XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP = 0.0F); XTensor IndexToOnehot(const XTensor & index, int num, float labelSmoothingP);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high) ...@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
_SelectRange(&a, &c, dim, low, high); _SelectRange(&a, &c, dim, low, high);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT); if (a.enableGrad) {
XLink::AddParamToHeadInt(&c, dim); XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHeadInt(&c, low); XLink::AddParamToHeadInt(&c, dim);
XLink::AddParamToHeadInt(&c, high); XLink::AddParamToHeadInt(&c, low);
XLink::AddParamToHeadInt(&c, high);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -526,6 +526,43 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper) ...@@ -526,6 +526,43 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
} }
} }
/* generate data items with a range by start, end and the step
>> tensor - the tensor whose data array would be initialized
>> start - the begin of the array
>> end - the end of the array (not included self)
>> step - the step of two items
*/
void _SetDataRange(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE step)
{
CheckNTErrors((tensor->order == 1), "Tensor must be 1 dimension!");
/* compute the true length according to the (start, end, step) */
DTYPE size = fabs(upper - lower);
int num = ceil(size / fabs(step));
CheckNTErrors((tensor->unitNum == num), "Unit number of the tensor is not matched.");
/* init a integer array to store the sequence */
void * data = NULL;
if (tensor->dataType == X_INT) {
data = new int[num];
for (int i = 0; i < num; i++)
*((int*)data + i) = lower + i * step;
}
else if (tensor->dataType == X_FLOAT) {
data = new float[num];
for (int i = 0; i < num; i++)
*((float*)data + i) = lower + i * step;
}
else {
ShowNTErrors("TODO!");
}
/* set the data from the array */
tensor->SetData(data, num);
delete[] data;
}
/* /*
generate data items with a uniform distribution in [lower, upper] and set generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise the item to a pre-defined value if the item >= p, set the item to 0 otherwise
......
...@@ -69,6 +69,9 @@ void _SetDataRand(XTensor * tensor, int rNum, int cNum); ...@@ -69,6 +69,9 @@ void _SetDataRand(XTensor * tensor, int rNum, int cNum);
/* generate data items with a uniform distribution in [lower, upper] */ /* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper); void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a range by start, end and the step */
void _SetDataRange(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE step);
/* generate data items with a uniform distribution in [lower, upper] and set /* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */ the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _SetDataRandP(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value); void _SetDataRandP(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
......
...@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num) ...@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num)
XTensor b(&a); \ XTensor b(&a); \
b.SetTMPFlag(); \ b.SetTMPFlag(); \
_funcName(&a, &b, num); \ _funcName(&a, &b, num); \
XLink::MakeLink(&a, NULL, &b, operationId); \ if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
XLink::AddParamToHead(&b, num); \ XLink::AddParamToHead(&b, num); \
return b; \ return b; \
} \ } \
...@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num) ...@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
InitTensor(&b, &a); \ InitTensor(&b, &a); \
} \ } \
_funcName(&a, &b, num); \ _funcName(&a, &b, num); \
if (b.enableGrad) { \ if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \ XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, num); \ XLink::AddParamToHead(&b, num); \
} \ } \
......
...@@ -36,26 +36,26 @@ set every entry to its clip value ...@@ -36,26 +36,26 @@ set every entry to its clip value
void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper) void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
{ {
#ifdef USE_CUDA #ifdef USE_CUDA
/* run it on GPUs */ /* run it on GPUs */
if (a->devID >= 0) { if (a->devID >= 0) {
_CudaClip(a, b, lower, upper); _CudaClip(a, b, lower, upper);
return; return;
} }
#endif #endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!"); CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data; DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data; DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++) { for (int i = 0; i < a->unitNum; i++) {
if (d[i] > upper) if (d[i] > upper)
db[i] = upper; db[i] = upper;
else if (d[i] < lower) else if (d[i] < lower)
db[i] = lower; db[i] = lower;
else else
db[i] = d[i]; db[i] = d[i];
} }
} }
/* /*
...@@ -99,9 +99,11 @@ XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper) ...@@ -99,9 +99,11 @@ XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
_Clip(&a, &b, lower, upper); _Clip(&a, &b, lower, upper);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_CLIP); if (a.enableGrad) {
XLink::AddParamToHead(&b, lower); XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, upper); XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
}
return b; return b;
} }
...@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper) ...@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
/* call _Clip function */ /* call _Clip function */
_Clip(&a, &b, lower, upper); _Clip(&a, &b, lower, upper);
if (b.enableGrad) { /* tensor connections */
/* tensor connections */ if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, MATH_CLIP); XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower); XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper); XLink::AddParamToHead(&b, upper);
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h" #include "../../XName.h"
#include "Compare.h" #include "Compare.h"
#include "Compare.cuh" #include "Compare.cuh"
...@@ -123,4 +124,95 @@ SIMPLE_COMPARE_FUNCTION_ME(NotEqualMe, _NotEqual) ...@@ -123,4 +124,95 @@ SIMPLE_COMPARE_FUNCTION_ME(NotEqualMe, _NotEqual)
SIMPLE_COMPARE_FUNCTION(NotEqual, _NotEqual, MATH_NOTEQUAL) SIMPLE_COMPARE_FUNCTION(NotEqual, _NotEqual, MATH_NOTEQUAL)
SIMPLE_COMPARE_FUNCTION_VOID(NotEqual, _NotEqual, MATH_NOTEQUAL) SIMPLE_COMPARE_FUNCTION_VOID(NotEqual, _NotEqual, MATH_NOTEQUAL)
/* define three marco separately, specify the respective function names */
#ifdef USE_CUDA
#define _SIMPLE_MAX_MIN_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, const XTensor * b, XTensor * c) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b, c)), \
"Input and output tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
CheckDev(a->devID, b->devID); \
CheckDev(a->devID, c->devID); \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, c); \
return; \
} \
DTYPE * da = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
DTYPE * dc = (DTYPE*)c->data; \
for (int i = 0; i < a->unitNum; i++) \
dc[i] = (DTYPE)origFunc(da[i], db[i]); \
}
#else
#define _SIMPLE_MAX_MIN_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, const XTensor * b, XTensor *c) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b, c)), \
"Input and output tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
CheckDev(a, b); \
CheckDev(a, c); \
/* run it on GPUs */ \
if (a->devID >= 0) { \
ShowNTErrors("No GPU devices support!") \
} \
DTYPE * da = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
DTYPE * dc = (DTYPE*)c->data; \
for (int i = 0; i < a->unitNum; i++) \
dc[i] = (DTYPE)origFunc(da[i], db[i]); \
}
#endif
#define _SIMPLE_MAX_MIN_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a, const XTensor * b) \
{ \
_funcName(a, b, a); \
}
#define SIMPLE_MAX_MIN_FUNCTION_ME(funcNameMe, _funcName) \
void funcNameMe(XTensor & a, const XTensor & b) \
{ \
_funcName(&a, &b, &a); \
}
#define SIMPLE_MAX_MIN_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor & a, const XTensor & b) \
{ \
XTensor c(&a); \
c.SetTMPFlag(); \
_funcName(&a, &b, &c); \
return c; \
}
#define SIMPLE_MAX_MIN_FUNCTION_VOID(funcName, _funcName, operationId) \
void funcName(const XTensor &a, const XTensor &b, XTensor c) \
{ \
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) { \
InitTensor(&c, &a); \
} \
_funcName(&a, &b, &c); \
}
#ifdef USE_CUDA
_SIMPLE_MAX_MIN_FUNCTION(_Max, _CudaMax, max)
_SIMPLE_MAX_MIN_FUNCTION(_Min, _CudaMin, min)
#else
_SIMPLE_MAX_MIN_FUNCTION(_Max, max)
_SIMPLE_MAX_MIN_FUNCTION(_Min, min)
#endif
_SIMPLE_MAX_MIN_FUNCTION_ME(_MaxMe, _Max)
SIMPLE_MAX_MIN_FUNCTION_ME(MaxMe, _Max)
SIMPLE_MAX_MIN_FUNCTION(Max, _Max, MATH_MAX)
SIMPLE_MAX_MIN_FUNCTION_VOID(Max, _Max, MATH_MAX)
_SIMPLE_MAX_MIN_FUNCTION_ME(_MinMe, _Min)
SIMPLE_MAX_MIN_FUNCTION_ME(MinMe, _Min)
SIMPLE_MAX_MIN_FUNCTION(Min, _Min, MATH_MIN)
SIMPLE_MAX_MIN_FUNCTION_VOID(Min, _Min, MATH_MIN)
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -89,6 +89,53 @@ void _Cuda##funcName(const XTensor * a, XTensor * b, DTYPE number) \ ...@@ -89,6 +89,53 @@ void _Cuda##funcName(const XTensor * a, XTensor * b, DTYPE number) \
SIMPLE_COMPARE_FUNCTION_GPU(Equal, cudaIsEqual) SIMPLE_COMPARE_FUNCTION_GPU(Equal, cudaIsEqual)
SIMPLE_COMPARE_FUNCTION_GPU(NotEqual, cudaIsNotEqual) SIMPLE_COMPARE_FUNCTION_GPU(NotEqual, cudaIsNotEqual)
#define SIMPLE_MAX_MIN_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, DTYPE * c, int size) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
c[i] = (DTYPE)origFunc(a[i], b[i]); \
} \
__global__ \
void Kernel##funcName(__half * a, __half * b, __half * c, int size) \
{ \
return; \
} \
void _Cuda##funcName(const XTensor * a, const XTensor * b, XTensor * c) \
{ \
\
int gridSize[3]; \
int blockSize[3]; \
\
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize); \
\
dim3 blocks(gridSize[0]); \
dim3 threads(blockSize[0]); \
\
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == DEFAULT_DTYPE) { \
Kernel##funcName<<<blocks, threads>>> \
((DTYPE*)a->data, (DTYPE*)b->data, \
(DTYPE*)c->data, a->unitNum); \
} \
else if (a->dataType == X_FLOAT16) { \
Kernel##funcName<<<blocks, threads>>> \
((__half*)a->data, (__half*)b->data, \
(__half*)c->data, a->unitNum); \
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
}
SIMPLE_MAX_MIN_FUNCTION_GPU(Max, max)
SIMPLE_MAX_MIN_FUNCTION_GPU(Min, min)
#endif // USE_CUDA #endif // USE_CUDA
......
...@@ -34,6 +34,12 @@ void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value); ...@@ -34,6 +34,12 @@ void _CudaEqual(const XTensor * a, XTensor * b, DTYPE value);
/* check whether every entry is not equal to the given value (cuda version) */ /* check whether every entry is not equal to the given value (cuda version) */
void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value); void _CudaNotEqual(const XTensor * a, XTensor * b, DTYPE value);
/* return maximum of two tensor for each items (cuda version) */
void _CudaMax(const XTensor * a, const XTensor * b, XTensor *c);
/* return minimum of two tensor for each items (cuda version) */
void _CudaMin(const XTensor * a, const XTensor * b, XTensor *c);
#endif // USE_CUDA #endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -56,6 +56,36 @@ XTensor NotEqual(const XTensor & a, DTYPE value); ...@@ -56,6 +56,36 @@ XTensor NotEqual(const XTensor & a, DTYPE value);
/* check whether every entry is not equal to the given value */ /* check whether every entry is not equal to the given value */
void NotEqual(const XTensor & a, XTensor & b, DTYPE value); void NotEqual(const XTensor & a, XTensor & b, DTYPE value);
/* return maximum of two tensor for each items */
void _Max(const XTensor * a, const XTensor * b, XTensor * c);
/* return maximum of two tensor for each items (do it on site) */
void _MaxMe(XTensor * a, const XTensor * b);
/* return maximum of two tensor for each items (do it on site) */
void MaxMe(XTensor & a, const XTensor & b);
/* return maximum of two tensor for each items (return an XTensor structure) */
XTensor Max(const XTensor & a, const XTensor & b);
/* return maximum of two tensor for each items */
void Max(const XTensor & a, const XTensor & b, XTensor & c);
/* return minimum of two tensor for each items */
void _Min(const XTensor * a, const XTensor * b, XTensor * c);
/* return minimum of two tensor for each items (do it on site) */
void _MinMe(XTensor * a, const XTensor * b);
/* return minimum of two tensor for each items (do it on site) */
void MinMe(XTensor & a, const XTensor & b);
/* return minimum of two tensor for each items (return an XTensor structure) */
XTensor Min(const XTensor & a, const XTensor & b);
/* return minimum of two tensor for each items */
void Min(const XTensor & a, const XTensor & b, XTensor & c);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
#endif // end __COMPARE_H__ #endif // end __COMPARE_H__
\ No newline at end of file
...@@ -46,7 +46,7 @@ void _Normalize(const XTensor * input, XTensor * output, int dim, ...@@ -46,7 +46,7 @@ void _Normalize(const XTensor * input, XTensor * output, int dim,
const XTensor * mean, const XTensor * var, const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon) const XTensor * a, const XTensor * b, DTYPE epsilon)
{ {
int dimRDI = input->order - dim - 1; int dimRDI = input->order - dim - 1;
CheckNTErrors((XTensor::IsSameShaped(input, output)), "Unmatched input tensors!"); CheckNTErrors((XTensor::IsSameShaped(input, output)), "Unmatched input tensors!");
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Unmatched input tensors"); CheckNTErrors((XTensor::IsSameShaped(a, b)), "Unmatched input tensors");
CheckNTErrors((XTensor::IsSameShaped(mean, var)), "Unmatched input tensors"); CheckNTErrors((XTensor::IsSameShaped(mean, var)), "Unmatched input tensors");
...@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim, ...@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim,
list.Add((XTensor*)&var); list.Add((XTensor*)&var);
list.Add((XTensor*)&a); list.Add((XTensor*)&a);
list.Add((XTensor*)&b); list.Add((XTensor*)&b);
XLink::MakeLink(&list, &output, MATH_NORMALIZE); if (input.enableGrad) {
XLink::AddParamToHeadInt(&output, dim); XLink::MakeLink(&list, &output, MATH_NORMALIZE);
XLink::AddParamToHead(&output, epsilon); XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, epsilon);
}
return output; return output;
} }
...@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim, ...@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim,
/* call _Normalize function */ /* call _Normalize function */
_Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon); _Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
if (output.enableGrad == true) { if (input.enableGrad == true) {
/* tensor connections */ /* tensor connections */
TensorList list(5); TensorList list(5);
list.Add((XTensor*)&input); list.Add((XTensor*)&input);
......
...@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift) ...@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
_ScaleAndShift(&a, &b, scale, shift); _ScaleAndShift(&a, &b, scale, shift);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT); if (a.enableGrad) {
XLink::AddParamToHead(&b, scale); XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, shift); XLink::AddParamToHead(&b, scale);
XLink::AddParamToHead(&b, shift);
}
return b; return b;
} }
...@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift) ...@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift)
/* call _ScaleAndShift function */ /* call _ScaleAndShift function */
_ScaleAndShift(&a, &b, scale, shift); _ScaleAndShift(&a, &b, scale, shift);
if (b.enableGrad) { if (a.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT); XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale); XLink::AddParamToHead(&b, scale);
......
...@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a) ...@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a)
XTensor b(&a); \ XTensor b(&a); \
b.SetTMPFlag(); \ b.SetTMPFlag(); \
_funcName(&a, &b); \ _funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \ if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
return b; \ return b; \
} }
...@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b) ...@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b)
InitTensor(&b, &a); \ InitTensor(&b, &a); \
} \ } \
_funcName(&a, &b); \ _funcName(&a, &b); \
if (b.enableGrad) { \ if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \ XLink::MakeLink(&a, NULL, &b, operationId); \
} \ } \
} }
......
...@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim, ...@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim,
list.Add((XTensor*)&tgtIndex); list.Add((XTensor*)&tgtIndex);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED); if (s.enableGrad) {
XLink::AddParamToHeadInt(&t, dim); XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, copyNum); XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadInt(&t, copyNum);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in ...@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int)); memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int));
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED); if (s.enableGrad) {
XLink::AddParamToHeadInt(&t, dim); XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadPointer(&t, saveSrcIndex); XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadInt(&t, indexSize); XLink::AddParamToHeadPointer(&t, saveSrcIndex);
XLink::AddParamToHeadPointer(&t, saveTgtIndex); XLink::AddParamToHeadInt(&t, indexSize);
XLink::AddParamToHeadInt(&t, copyNum); XLink::AddParamToHeadPointer(&t, saveTgtIndex);
XLink::AddParamToHeadInt(&t, copyNum);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream) ...@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream)
_CopyValues(&s, &t, stream); _CopyValues(&s, &t, stream);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES); if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
}
return t; return t;
} }
......
...@@ -93,7 +93,9 @@ XTensor Gather(XTensor &s, XTensor &index) ...@@ -93,7 +93,9 @@ XTensor Gather(XTensor &s, XTensor &index)
_Gather(&s, &t, &index); _Gather(&s, &t, &index);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER); if (s.enableGrad) {
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
}
return t; return t;
} }
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "VectorBuffer.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMax.cuh" #include "ReduceMax.cuh"
...@@ -41,8 +43,8 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -41,8 +43,8 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!"); CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!");
CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!"); CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!");
CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!"); CheckNTErrors((input->dataType == output->dataType), "Unmatched data types!");
int dimRDI = input->order - dim - 1; int dimRDI = input->order - dim - 1;
CheckNTErrors(dimRDI >= 0, "Wrong dimension!"); CheckNTErrors(dimRDI >= 0, "Wrong dimension!");
for(int i = 0; i < input->order; i++){ for(int i = 0; i < input->order; i++){
...@@ -76,18 +78,75 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -76,18 +78,75 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
} }
blockSize = stride * strideNum; blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){ if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
DTYPE * ip = (DTYPE*)input->data + blockSize * k; int vecBufLength = 32 / sizeof(DTYPE);
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){ if(dimRDI == 0){
DTYPE max = FLOAT_MIN; //data is contiguous in dim 0
DTYPE * ipe = ip + blockSize; for(int i = 0; i < blockNum; i++){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE v = *ipb; DTYPE * op = (DTYPE*)output->data + i;
if(max < v) VectorBuffer vecBuf[4];
max = v; for(int j = 0; j < 4; j++){
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength);
}
for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
vecBuf[0] = vecBuf[0].maxData(vecBuf[1]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[2]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[3]);
DTYPE maxN = DTYPE_MIN;
for(int k = 0; k < vecBufLength; k++){
maxN = MAX(maxN,vecBuf[0][k]);
}
*op = maxN;
}
} else{
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE));
}
for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
for(int k = 0; k < 4; k++){
for(int l = 0; l < vecBufLength; l++)
*(op + j * 32 + 8 * k + l) = vecBuf[k][l];
}
}
}
}
}//run vector buffer
else{
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){
DTYPE max = DTYPE_MIN;
DTYPE * ipe = ip + blockSize;
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE v = *ipb;
if(max < v)
max = v;
}
*(op + i) = max;
} }
*(op + i) = max;
} }
} }
} }
...@@ -104,7 +163,7 @@ make a new tensor to keep the result and return it ...@@ -104,7 +163,7 @@ make a new tensor to keep the result and return it
XTensor ReduceMax(const XTensor &input, int dim) XTensor ReduceMax(const XTensor &input, int dim)
{ {
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1; int order = input.order - 1;
int * dimSize = new int[order]; int * dimSize = new int[order];
for(int i = 0; i < order; i++){ for(int i = 0; i < order; i++){
...@@ -122,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim) ...@@ -122,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim)
_ReduceMax(&input, &output, dim); _ReduceMax(&input, &output, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX); if (input.enableGrad) {
XLink::AddParamToHeadInt(&output, dim); XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -162,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim) ...@@ -162,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMax function */ /* call _ReduceMax function */
_ReduceMax(&input, &output, dim); _ReduceMax(&input, &output, dim);
if (output.enableGrad) { if (input.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX); XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim); XLink::AddParamToHeadInt(&output, dim);
......
...@@ -39,7 +39,7 @@ void _ReduceMean(const XTensor * input, XTensor * output, int dim) ...@@ -39,7 +39,7 @@ void _ReduceMean(const XTensor * input, XTensor * output, int dim)
{ {
CheckNTErrors((input->order > dim), "Illegal dimension specified!"); CheckNTErrors((input->order > dim), "Illegal dimension specified!");
int dimRDI = input->order - dim - 1; int dimRDI = input->order - dim - 1;
int num = input->dimSizeRDI[dimRDI]; int num = input->dimSizeRDI[dimRDI];
_ReduceSum(input, output, dim); _ReduceSum(input, output, dim);
...@@ -59,7 +59,7 @@ For a 1-dimensional data array a, mean = (1/n) * sum_i input_i ...@@ -59,7 +59,7 @@ For a 1-dimensional data array a, mean = (1/n) * sum_i input_i
XTensor ReduceMean(const XTensor &input, int dim) XTensor ReduceMean(const XTensor &input, int dim)
{ {
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1; int order = input.order - 1;
int * dimSize = new int[order]; int * dimSize = new int[order];
for(int i = 0; i < order; i++){ for(int i = 0; i < order; i++){
...@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim) ...@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim)
_ReduceMean(&input, &output, dim); _ReduceMean(&input, &output, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN); if (input.enableGrad) {
XLink::AddParamToHeadInt(&output, dim); XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim) ...@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMean function */ /* call _ReduceMean function */
_ReduceMean(&input, &output, dim); _ReduceMean(&input, &output, dim);
if (output.enableGrad) { if (input.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN); XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim); XLink::AddParamToHeadInt(&output, dim);
......
...@@ -55,7 +55,7 @@ For a 1-dimensional data array a, sum = \sum_i (a_i - shift)^2 ...@@ -55,7 +55,7 @@ For a 1-dimensional data array a, sum = \sum_i (a_i - shift)^2
XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift) XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
{ {
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1; int order = input.order - 1;
int * dimSize = new int[order]; int * dimSize = new int[order];
for(int i = 0; i < order; i++){ for(int i = 0; i < order; i++){
...@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift) ...@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
_ReduceSumSquared(&input, &output, dim, &shift); _ReduceSumSquared(&input, &output, dim, &shift);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED); if (input.enableGrad) {
XLink::AddParamToHeadInt(&output, dim); XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTen ...@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTen
/* call _ReduceSumSquared function */ /* call _ReduceSumSquared function */
_ReduceSumSquared(&input, &output, dim, &shift); _ReduceSumSquared(&input, &output, dim, &shift);
if (output.enableGrad) { if (input.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED); XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim); XLink::AddParamToHeadInt(&output, dim);
......
...@@ -38,7 +38,7 @@ For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2 ...@@ -38,7 +38,7 @@ For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2
*/ */
void _ReduceVariance(const XTensor * input, XTensor * output, int dim, const XTensor * mean) void _ReduceVariance(const XTensor * input, XTensor * output, int dim, const XTensor * mean)
{ {
int dimRDI = input->order - dim - 1; int dimRDI = input->order - dim - 1;
int num = input->dimSizeRDI[dimRDI]; int num = input->dimSizeRDI[dimRDI];
_ReduceSum(input, output, dim, mean, 2.0F); _ReduceSum(input, output, dim, mean, 2.0F);
_ScaleAndShiftMe(output, (DTYPE)1 / num, 0); _ScaleAndShiftMe(output, (DTYPE)1 / num, 0);
...@@ -58,7 +58,7 @@ For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2 ...@@ -58,7 +58,7 @@ For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2
XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean) XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
{ {
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!"); CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1; int order = input.order - 1;
int * dimSize = new int[order]; int * dimSize = new int[order];
for(int i = 0; i < order; i++){ for(int i = 0; i < order; i++){
...@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean) ...@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
_ReduceVariance(&input, &output, dim, &mean); _ReduceVariance(&input, &output, dim, &mean);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE); if (input.enableGrad) {
XLink::AddParamToHeadInt(&output, dim); XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTenso ...@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTenso
/* call _ReduceVariance function */ /* call _ReduceVariance function */
_ReduceVariance(&input, &output, dim, &mean); _ReduceVariance(&input, &output, dim, &mean);
if (output.enableGrad) { if (input.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE); XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim); XLink::AddParamToHeadInt(&output, dim);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
#include "VectorBuffer.h"
#include <cstring>
namespace nts {
/* number of DTYPE elements held by one 32-byte buffer */
int VectorBuffer::size()
{
    return (int)(32 / sizeof(DTYPE));
}
/* default constructor: elements are zero-initialized by the in-class initializer */
VectorBuffer::VectorBuffer() = default;
/*
constructor
fill every element of the buffer with val
*/
VectorBuffer::VectorBuffer(DTYPE val)
{
    const int n = size();
    for (int i = 0; i < n; ++i)
        values[i] = val;
}
/*
load size() elements from memory into a new buffer, applying an optional
element-wise transform in this order:
  1. subtract bias[i]              (only if bias != NULL)
  2. raise to 'power'              (1.0 = identity, 2.0 = square, 0.5 = sqrt, else pow)
  3. exponentiate with exp()       (only if isExp)
>> ptr - pointer to at least size() contiguous DTYPE values
>> isExp - whether to apply exp() after the power step
>> power - exponent applied to each (possibly bias-shifted) element
>> bias - optional per-element shift subtracted before the power step
<< return - the filled buffer
*/
VectorBuffer VectorBuffer::loadu(const DTYPE* ptr, bool isExp , DTYPE power , DTYPE* bias )
{
    int count = 32 / sizeof(DTYPE);
    VectorBuffer vec;

    /* fast path: a plain contiguous load is just a memcpy */
    if (!isExp && bias == NULL && power == (DTYPE)1.0) {
        memcpy(vec.values, ptr, count * sizeof(DTYPE));
        return vec;
    }

    for (int i = 0; i != count; i++) {
        /* step 1: optional bias shift */
        DTYPE v = (bias == NULL) ? *(ptr + i) : *(ptr + i) - bias[i];
        /* step 2: keep the cheap special cases for the common exponents */
        if (power == (DTYPE)2.0)
            v = v * v;
        else if (power == (DTYPE)0.5)
            v = (DTYPE)sqrt(v);
        else if (power != (DTYPE)1.0)
            v = (DTYPE)pow(v, power);
        /* step 3: optional exponentiation */
        if (isExp)
            v = (DTYPE)exp(v);
        vec.values[i] = v;
    }
    return vec;
}
/* read-only element access; idx must be in [0, size()) — no bounds checking */
const DTYPE& VectorBuffer::operator[](int idx)const
{
    return this->values[idx];
}
/* element-wise addition.
   NOTE: accumulates a into *this (mutates the left operand) and returns a copy,
   which matches the existing `x = x + y` usage in the reduction kernels. */
VectorBuffer VectorBuffer::operator+(const VectorBuffer &a)
{
    const int n = size();
    for (int i = 0; i < n; ++i)
        values[i] += a[i];
    return *this;
}
/* compute the element-wise maximum of two buffers.
   NOTE: keeps the larger value in *this (mutates the left operand) and returns a copy. */
VectorBuffer VectorBuffer::maxData(const VectorBuffer &a) {
    const int n = size();
    for (int i = 0; i < n; ++i) {
        if (values[i] < a[i])
            values[i] = a[i];
    }
    return *this;
}
}/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
//#include <cstring>
#include <math.h>
#include "../../XGlobal.h"
namespace nts {
/*
A fixed-size 32-byte buffer of DTYPE values used to emulate SIMD-style
vectorized operations in the CPU reduction kernels.
NOTE(review): operator+ and maxData modify *this in place and return a copy;
callers use the `x = x + y` / `x = x.maxData(y)` pattern, so the mutation is
currently redundant but relied upon being harmless.
*/
class VectorBuffer {
private:
    /* raw storage: 32 bytes worth of DTYPE elements, zero-initialized */
    DTYPE values[32 / sizeof(DTYPE)] = { 0 };
public:
    /* number of DTYPE elements per buffer, i.e. 32 / sizeof(DTYPE) */
    static int size();
    /* default constructor: elements are zero-initialized */
    VectorBuffer();
    /* constructor: fill every element with val */
    VectorBuffer(DTYPE val);
    /* load size() elements from ptr, applying per element (in order):
       subtract bias[i] (if bias != NULL), raise to 'power'
       (1.0/2.0/0.5 are special-cased), then exp() (if isExp) */
    static VectorBuffer loadu(const DTYPE* ptr, bool isExp = false, DTYPE power = (DTYPE)1.0F, DTYPE* bias = NULL);
    /* read-only element access (no bounds checking) */
    const DTYPE& operator[](int idx)const;
    /* element-wise addition; mutates *this in place and returns a copy */
    VectorBuffer operator+(const VectorBuffer &a);
    /* element-wise maximum of two buffers; mutates *this in place */
    VectorBuffer maxData(const VectorBuffer &a);
};
}
\ No newline at end of file
...@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim) ...@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_Merge(&smalls, &big, dim); _Merge(&smalls, &big, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, dim); XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim) ...@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_ConcatenateSolely(&smalls, &big, dim); _ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, dim); XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim) ...@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_Merge(&smalls, &big, dim); _Merge(&smalls, &big, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, dim); XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim) ...@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_ConcatenateSolely(&smalls, &big, dim); _ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, dim); XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim) ...@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
_Merge(&s, &t, whereToMerge, leadingDim); _Merge(&s, &t, whereToMerge, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE); if (s.enableGrad) {
XLink::AddParamToHeadInt(&t, whereToMerge); XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, leadingDim); XLink::AddParamToHeadInt(&t, whereToMerge);
XLink::AddParamToHeadInt(&t, leadingDim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim) ...@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
/* call _Merge function */ /* call _Merge function */
_Merge(&s, &t, whereToMerge, leadingDim); _Merge(&s, &t, whereToMerge, leadingDim);
if (t.enableGrad) { if (s.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE); XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, whereToMerge); XLink::AddParamToHeadInt(&t, whereToMerge);
...@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge) ...@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge)
_Merge(&smalls, &big, whereToMerge); _Merge(&smalls, &big, whereToMerge);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, whereToMerge); XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge) ...@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
_Merge(&smalls, &big, whereToMerge); _Merge(&smalls, &big, whereToMerge);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST); if (smallA.enableGrad) {
XLink::AddParamToHeadInt(&big, whereToMerge); XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -43,9 +43,11 @@ XTensor Reshape(XTensor &s, int order, int * dimSize) ...@@ -43,9 +43,11 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
t.Reshape(order, dimSize); t.Reshape(order, dimSize);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE); if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
}
return t; return t;
} }
void Reshape(XTensor &s, XTensor &t, int order, int * dimSize) void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
...@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize) ...@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
/* call Reshape function */ /* call Reshape function */
t.Reshape(order, dimSize); t.Reshape(order, dimSize);
if (t.enableGrad) { if (s.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE); XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
} }
......
...@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum) ...@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
_Split(&s, &t, whereToSplit, splitNum); _Split(&s, &t, whereToSplit, splitNum);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT); if (s.enableGrad) {
XLink::AddParamToHeadInt(&t, whereToSplit); XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, splitNum); XLink::AddParamToHeadInt(&t, whereToSplit);
XLink::AddParamToHeadInt(&t, splitNum);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum) ...@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum)
/* call _Split function */ /* call _Split function */
_Split(&s, &t, whereToSplit, splitNum); _Split(&s, &t, whereToSplit, splitNum);
if (t.enableGrad) { if (s.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT); XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, whereToSplit); XLink::AddParamToHeadInt(&t, whereToSplit);
...@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNu ...@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNu
/* tensor connections */ /* tensor connections */
for(int i = 0; i < smalls.count; i++){ for(int i = 0; i < smalls.count; i++){
XTensor * s = (XTensor*)smalls.Get(i); XTensor * s = (XTensor*)smalls.Get(i);
XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
XLink::AddParamToHeadInt(s, whereToSplit);
/* it is tricky here that we keep the id of each if (s->enableGrad) {
block, rather than the total number of the splits */ XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
XLink::AddParamToHeadInt(s, i); XLink::AddParamToHeadInt(s, whereToSplit);
/* it is tricky here that we keep the id of each
block, rather than the total number of the splits */
XLink::AddParamToHeadInt(s, i);
}
} }
} }
......
...@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim) ...@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim)
_Squeeze(&source, &target, leadingDim); _Squeeze(&source, &target, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE); if (source.enableGrad) {
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
}
return target; return target;
} }
...@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim) ...@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim)
/* call _Squeeze function */ /* call _Squeeze function */
_Squeeze(&source, &target, leadingDim); _Squeeze(&source, &target, leadingDim);
if (target.enableGrad) { if (source.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE); XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
} }
......
...@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j) ...@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
_Transpose(&a, &b, i, j); _Transpose(&a, &b, i, j);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE); if (a.enableGrad) {
XLink::AddParamToHeadInt(&b, i); XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
XLink::AddParamToHeadInt(&b, j); XLink::AddParamToHeadInt(&b, i);
XLink::AddParamToHeadInt(&b, j);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
......
...@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize) ...@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
_Unsqueeze(&a, &b, dim, dSize); _Unsqueeze(&a, &b, dim, dSize);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE); if (a.enableGrad) {
XLink::AddParamToHeadInt(&b, dim); XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dSize); XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
...@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize) ...@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
/* call _Unsqueeze function */ /* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize); _Unsqueeze(&a, &b, dim, dSize);
if (b.enableGrad) { if (a.enableGrad) {
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE); XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim); XLink::AddParamToHeadInt(&b, dim);
......
...@@ -377,8 +377,8 @@ get the top-k items ...@@ -377,8 +377,8 @@ get the top-k items
template<class T> __global__ template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index) void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{ {
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 1024 * sizeof(T)) / sizeof(CudaHeapNode<T>)]; __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 512 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[1024]; __shared__ T eachHeapMaxValue[512];
/*optimization k size the parameter must more than half of k*/ /*optimization k size the parameter must more than half of k*/
int parameter = 0; int parameter = 0;
...@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
} }
__syncthreads(); __syncthreads();
/*to merge the heap use another way*/ /* to merge the heap use another way */
T minData = minValue; T minData = minValue;
int heapLimit = heap.count / 2; int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1; if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
...@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
minData = heap.items[counter].value; minData = heap.items[counter].value;
} }
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData; eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimation //need more optimation
if (i == 0) { if (i == 0) {
int threadLimit = (threadIdx.y + 1) * blockDim.x; int threadLimit = threadIdx.y * blockDim.x + min(blockDim.x,strideNum);
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y)); CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x; int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x + k; ++counter) { for (; counter < threadIdx.y * blockDim.x + min(k, blockDim.x); ++counter) {
chooseHeap.Push(counter, eachHeapMaxValue[counter]); chooseHeap.Push(counter, eachHeapMaxValue[counter]);
} }
for (; counter < threadLimit; ++counter) { for (; counter < threadLimit; ++counter) {
...@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]); chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
} }
} }
int heapNum = chooseHeap.count;
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index); CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter; int miss = parameter;
for (counter = 1; counter < k; ++counter) { for (counter = 1; counter < heapNum; ++counter) {
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1]; chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--; chooseHeap.count--;
chooseHeap.Down(0); chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index); CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0; int cmpHeapLimit = 0;
if (counter + heapLimit <= k - parameter){ if (counter + heapLimit <= k - parameter && heapNum == k){
cmpHeapLimit = heapLimit; cmpHeapLimit = heapLimit;
} }
/* take the max data from the minHeap,so start search from the leaf node */ /* take the max data from the minHeap,so start search from the leaf node */
...@@ -770,22 +772,22 @@ void KernelTopKRadixSelect(unsigned int * input, int stride, int strideNum, ...@@ -770,22 +772,22 @@ void KernelTopKRadixSelect(unsigned int * input, int stride, int strideNum,
/* /*
if (idx == 0) if (idx == 0)
{ {
unsigned int* uintOutput = new unsigned int; unsigned int* uintOutput = new unsigned int;
int* tmpIndex = new int; int* tmpIndex = new int;
//*******************something worng*************************** //*******************something worng***************************
cudaMalloc((void **)&uintOutput, sizeof(unsigned int)* k); cudaMalloc((void **)&uintOutput, sizeof(unsigned int)* k);
cudaMalloc((void **)&tmpIndex, sizeof(unsigned int)*k); cudaMalloc((void **)&tmpIndex, sizeof(unsigned int)*k);
//************************************************************* //*************************************************************
collectNumberOld(input, limit, k, desire, uintOutput, tmpIndex, stride, strideNum); collectNumberOld(input, limit, k, desire, uintOutput, tmpIndex, stride, strideNum);
int blockIndex = idy / stride; int blockIndex = idy / stride;
int offsetInBlock = idy% stride; int offsetInBlock = idy% stride;
for (int i = stride * k * blockIndex + offsetInBlock, j = 0; j < k; j++, i += stride) for (int i = stride * k * blockIndex + offsetInBlock, j = 0; j < k; j++, i += stride)
{ {
//for(int i = ) //for(int i = )
output[i] = deconvert(uintOutput[j]); output[i] = deconvert(uintOutput[j]);
index[i] = tmpIndex[j]; index[i] = tmpIndex[j];
} }
} }
__syncthreads(); __syncthreads();
*/ */
...@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
/* we run the kernel if the heaps can fit into the shared memory */ /* we run the kernel if the heaps can fit into the shared memory */
cudaGrids[1] *= cudaBlocks[1]; cudaGrids[1] *= cudaBlocks[1];
cudaBlocks[1] = 1; cudaBlocks[1] = 1;
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) { if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) + (512 * sizeof(int))< SHARED_MEMORY_SIZE) {
if (a->dataType == DEFAULT_DTYPE) { if (a->dataType == DEFAULT_DTYPE) {
KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>> KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN, ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
...@@ -869,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -869,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
//delete indexA; //delete indexA;
int workerNum = WORKERSNUM; int workerNum = WORKERSNUM;
GDevs.GetCudaThread2D(a->mem->devID, GDevs.GetCudaThread2D(a->devID,
workerNum, stride * blockNum, MAX_INT, workerNum, stride * blockNum, MAX_INT,
cudaGrids, cudaBlocks); cudaGrids, cudaBlocks);
if (a->dataType == DEFAULT_DTYPE) { if (a->dataType == DEFAULT_DTYPE) {
......
...@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale) ...@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
_ScaleAndShiftMe(&c, scale); _ScaleAndShiftMe(&c, scale);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX); if (x.enableGrad) {
XLink::AddParamToHead(&c, scale); XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
XLink::AddParamToHead(&c, scale);
}
return c; return c;
} }
......
...@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x) ...@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x)
_HardTanH(&x, &y); _HardTanH(&x, &y);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH); if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
}
return y; return y;
} }
...@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y) ...@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y)
/* call _HardTanH function */ /* call _HardTanH function */
_HardTanH(&x, &y); _HardTanH(&x, &y);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH); XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
} }
......
...@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x) ...@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x)
_Identity(&x, &y); _Identity(&x, &y);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY); if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
}
return y; return y;
} }
...@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y) ...@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y)
/* call _Identity function */ /* call _Identity function */
_Identity(&x, &y); _Identity(&x, &y);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY); XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
} }
......
...@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim) ...@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
_LogSoftmax(&x, &y, ld); _LogSoftmax(&x, &y, ld);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX); if (x.enableGrad) {
XLink::AddParamToHeadInt(&y, ld); XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
return y; return y;
} }
...@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim) ...@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim)
/* call _LogSoftmax function */ /* call _LogSoftmax function */
_LogSoftmax(&x, &y, ld); _LogSoftmax(&x, &y, ld);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX); XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld); XLink::AddParamToHeadInt(&y, ld);
......
...@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x) ...@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x)
_Rectify(&x, &y); _Rectify(&x, &y);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY); if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
}
return y; return y;
} }
...@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y) ...@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y)
/* call _Rectify function */ /* call _Rectify function */
_Rectify(&x, &y); _Rectify(&x, &y);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY); XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
} }
......
...@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x) ...@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x)
_Sigmoid(&x, &y); _Sigmoid(&x, &y);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID); if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
}
return y; return y;
} }
...@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y) ...@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y)
/* call _Sigmoid function */ /* call _Sigmoid function */
_Sigmoid(&x, &y); _Sigmoid(&x, &y);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID); XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
} }
......
...@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim) ...@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim)
_Softmax(&x, &y, ld); _Softmax(&x, &y, ld);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX); if (x.enableGrad) {
XLink::AddParamToHeadInt(&y, ld); XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
return y; return y;
} }
...@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim) ...@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim)
/* call _Softmax function */ /* call _Softmax function */
_Softmax(&x, &y, ld); _Softmax(&x, &y, ld);
if (y.enableGrad) { if (x.enableGrad) {
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX); XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld); XLink::AddParamToHeadInt(&y, ld);
......
...@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold, ...@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails.Add((XTensor*)&gold); tails.Add((XTensor*)&gold);
tails.Add(weight); tails.Add(weight);
tails.Add(padding); tails.Add(padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim); if (output.enableGrad) {
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
}
return loss; return loss;
} }
...@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold, ...@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails.Add((XTensor*)&gold); tails.Add((XTensor*)&gold);
tails.Add(weight); tails.Add(weight);
tails.Add((XTensor*)&padding); tails.Add((XTensor*)&padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim); if (output.enableGrad) {
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
}
return loss; return loss;
} }
...@@ -677,4 +683,4 @@ void _CrossEntropyBackward(XTensor * dedy, const XTensor * output, ...@@ -677,4 +683,4 @@ void _CrossEntropyBackward(XTensor * dedy, const XTensor * output,
} }
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -406,6 +406,68 @@ bool TestSetData5() ...@@ -406,6 +406,68 @@ bool TestSetData5()
#endif // USE_CUDA #endif // USE_CUDA
} }
/*
case 6: test SetDataRange function.
generate data items with a range by start, end and the step
*/
bool TestSetData6()
{
/* a input tensor of size (5) */
int order = 1;
int * dimSize = new int[order];
dimSize[0] = 5;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE answer[5] = {5.2F, 3.2F, 1.2F, -0.8F, -2.8F};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(order, dimSize);
/* initialize variables */
s->SetZeroAll();
/* call _SetDataRange function */
_SetDataRange(s, 5.2, -3.2, -2);
/* check results */
cpuTest = s->CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* initialize variables */
sGPU->SetZeroAll();
/* call _SetDataRange function */
_SetDataRange(sGPU, 5.2, -3.2, -2);
gpuTest = sGPU->CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete s;
delete sGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */ /* other cases */
/* /*
TODO!! TODO!!
...@@ -462,6 +524,15 @@ bool TestSetData() ...@@ -462,6 +524,15 @@ bool TestSetData()
else else
XPRINT(0, stdout, ">> case 5 passed!\n"); XPRINT(0, stdout, ">> case 5 passed!\n");
/* case 6 test */
caseFlag = TestSetData6();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 6 failed!\n");
}
else
XPRINT(0, stdout, ">> case 6 passed!\n");
/* other cases test */ /* other cases test */
/* /*
TODO!! TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论