Commit 666d51e9 by liyinqiao

Merge with Yuhao branch.

parent ddbb77b6
...@@ -26,183 +26,9 @@
 *
 */
#ifdef WIN32
#include <wtypes.h>
#endif
#include <stdlib.h>
#include <stdio.h>
#include "XBLAS.h"
#include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
#ifdef WIN32
HINSTANCE hBLASDll;
#endif
/* single-precision floating matrix-matrix multiplication */
void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */
void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
double *, OPENBLAS_CONST BLASINT);
/* single-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int);
/* get the number of threads */
//int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
int (*XBLAS_GET_CORE_NUM)();
/* get the CPU corename */
//char * (*XBLAS_GET_CORE_NAME)();
/* get the parallelization type used by OpenBLAS */
//int (*XBLAS_GET_PARALLEL_TYPE)(void);
#if defined(USE_BLAS)
/* load some stuff for BLAS */
void LoadBLAS(const char * dllFileName)
{
#ifndef CUDA_BLAS
#ifdef _WIN32
#if defined(OPENBLAS)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplication */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
//(FARPROC&)XBLAS_GET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_get_num_threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "openblas_get_num_procs");
//(FARPROC&)XBLAS_GET_CORE_NAME = GetProcAddress(hBLASDll, "openblas_get_corename");
//(FARPROC&)XBLAS_GET_PARALLEL_TYPE = GetProcAddress(hBLASDll, "openblas_get_parallel");
delete[] fn;
#endif // defined(OPENBLAS)
#if defined(MKL)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplication */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
#endif // defined(OPENBLAS)
#if defined(MKL)
XBLAS_SET_THREAD_NUM = &mkl_set_num_threads;
XBLAS_GET_CORE_NUM = &mkl_get_max_num_threads;
#endif // defined(MKL)
#endif // _WIN32
XBLAS_SET_THREAD_NUM(1);
#endif // ndef(CUDA_BLAS)
}
/* unload the libs */
void UnloadBLAS()
{
#ifdef _WIN32
if(!FreeLibrary(hBLASDll)){
XPRINT(0, stderr, "[UnloadBLAS] Error! Cannot free the BLAS dll!\n");
exit(1);
}
#else
#endif
}
#else // undefined(USE_BLAS) || undefined(OPENBLAS)
void LoadBLAS(const char * dllFileName)
{
XPRINT(0, stderr, "[LoadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
void UnloadBLAS()
{
XPRINT(0, stderr, "[UnloadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
#endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -34,7 +34,6 @@ namespace nts{
/* some of the code below is from OpenBLAS (https://github.com/xianyi/OpenBLAS) */
//#define OPENBLAS
#define OPENBLAS_CONST const
typedef int BLASINT;
...@@ -46,7 +45,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#if defined(USE_BLAS)
#ifdef OPENBLAS
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM openblas_set_num_threads
#define XBLAS_GET_CORE_NUM openblas_get_num_procs
#endif
#ifdef MKL
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM MKL_Set_Num_Threads
#define XBLAS_GET_CORE_NUM MKL_Get_Max_Threads
#endif
/*
single/double-precision floating matrix-matrix multiplication (rank-3)
- SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
...@@ -62,14 +80,14 @@ where A, B and C are matrices,
LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program,
and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program.
*/
extern "C" void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_SGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */ /* double-precision floating matrix-matrix multiplication */
extern "C" void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_DGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
...@@ -88,24 +106,33 @@ where X and Y are vectors with m and n elements respectively, ...@@ -88,24 +106,33 @@ where X and Y are vectors with m and n elements respectively,
E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A. E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A.
*/ */
extern "C" void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha, extern "C" void XBLAS_SGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */ /* double-precision floating vector-vector multiplication (rank-1) */
extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha, extern "C" void XBLAS_DGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT); double *, OPENBLAS_CONST BLASINT);
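As a quick illustration of the parameter layout described above, a minimal row-major SGEMM call (plain CBLAS, not repository code; the sizes are arbitrary constants chosen for the example) could look like this:

#include <cblas.h>

/* Illustrative only: C = 1.0 * A * B + 0.0 * C with a row-major 2x3 A and 3x2 B,
   so M = 2, N = 2, K = 3, lda = 3, ldb = 2, ldc = 2 (leading dimension = number of columns). */
void SgemmExample()
{
    float A[6] = {1, 2, 3, 4, 5, 6};    /* 2 x 3 */
    float B[6] = {1, 0, 0, 1, 1, 1};    /* 3 x 2 */
    float C[4] = {0};                   /* 2 x 2 */
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                2, 2, 3, 1.0f, A, 3, B, 2, 0.0f, C, 2);
}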
/*
single-precision floating AXPY: y = a * x + y
- SAXPY (N, ALPHA, X, INCX, Y, INCY)
*/
extern "C" void XBLAS_SAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* double-precision floating AXPY: y = a * x + y */
extern "C" void XBLAS_DAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */
extern "C" void XBLAS_SET_THREAD_NUM(int);
/* get the number of threads */
//extern "C" int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
extern "C" int XBLAS_GET_CORE_NUM();
/* get the CPU corename */
//extern "C" char * (*XBLAS_GET_CORE_NAME)();
...@@ -113,58 +140,6 @@ extern "C" int (*XBLAS_GET_CORE_NUM)();
/* get the parallelization type used by OpenBLAS */
//extern "C" int (*XBLAS_GET_PARALLEL_TYPE)(void);
/* linux systems */
#ifndef _WIN32
/* cblas functions that are imported from the lib. See cblas.h in OpenBlas for more information */
extern "C" void cblas_sgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST BLASINT K, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *A, OPENBLAS_CONST BLASINT lda,
OPENBLAS_CONST float *B, OPENBLAS_CONST BLASINT ldb,
OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST BLASINT ldc);
extern "C" void cblas_dgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST BLASINT K, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *A, OPENBLAS_CONST BLASINT lda,
OPENBLAS_CONST double *B, OPENBLAS_CONST BLASINT ldb,
OPENBLAS_CONST double beta, double *C, OPENBLAS_CONST BLASINT ldc);
extern "C" void cblas_sger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT incY,
float *A, OPENBLAS_CONST BLASINT lda);
extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda);
#if defined(OPENBLAS)
/* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads);
extern "C" void goto_set_num_threads(int num_threads);
//extern "C" int openblas_get_num_threads(void);
extern "C" int openblas_get_num_procs(void);
//extern "C" char* openblas_get_config(void);
//extern "C" char* openblas_get_corename(void);
//extern "C" int openblas_get_parallel(void);
#endif
#endif
#if defined(MKL)
/* better control of multi-threading */
//_Mkl_Api(void,MKL_Set_Num_Threads,(int nth))
//_Mkl_Api(int,MKL_Get_Max_Threads,(void))
extern "C" void MKL_Set_Num_Threads(int num_threads);
extern "C" int MKL_Get_Max_Threads();
#define mkl_set_num_threads MKL_Set_Num_Threads
#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void mkl_set_num_threads(int num_threads);
//extern "C" void omp_set_num_threads(int num_threads);
//extern "C" int mkl_get_max_num_threads();
#endif
#if defined(CUDA_BLAS)
...@@ -186,24 +161,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int
#endif
#endif
#ifdef _WIN32
#include "windows.h"
extern HINSTANCE hBLASDll;
#else
#endif
/* load some stuff for BLAS */
extern void LoadBLAS(const char * dllFileName);
/* unload the libs */
extern void UnloadBLAS();
} /* end of the nts (NiuTrans.Tensor) namespace */
#endif
...@@ -160,8 +160,10 @@ extern bool useCUDA;
/* BLAS interfaces */
#ifdef DOUBELPRICSION
#define GEMM XBLAS_DGEMM
#define AXPY XBLAS_DAXPY
#else
#define GEMM XBLAS_SGEMM
#define AXPY XBLAS_SAXPY
#endif
extern void InitGlobalAll();
......
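A minimal sketch (not repository code; it assumes USE_BLAS together with OPENBLAS or MKL, and the DTYPE type from XGlobal.h) of how the GEMM/AXPY switch above is typically consumed, so one call site serves both precisions:

/* y = alpha * x + y, dispatched to cblas_saxpy or cblas_daxpy at compile time
   depending on DOUBELPRICSION */
void ScaledAdd(DTYPE * y, const DTYPE * x, int n, DTYPE alpha)
{
    AXPY(n, alpha, x, 1, y, 1);
}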
...@@ -82,10 +82,11 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
        b->dataType == DEFAULT_DTYPE &&
        c->dataType == DEFAULT_DTYPE)
    {
#if defined(USE_BLAS)
        _MatrixMULCPU(a, transposedA, b, transposedB, c, alpha, beta);
#else
        _MatrixMul2DParallel(a, transposedA, b, transposedB, c, alpha, beta, parallelRunner);
#endif
    }
    else {
        // TODO!!
......
...@@ -199,10 +199,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
            bi->data = (char*)b->data + i * bRealBlockSize;
            ci->data = (char*)c->data + i * cRealBlockSize;
#ifdef USE_BLAS
            _MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else
            _MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif
...@@ -262,10 +259,7 @@ void _MatrixMulBatchedCPU(const TensorList * a, MATRIX_TRANS_TYPE transposedA,
            CheckNTErrors((bi->order == 2), "2d tensor (i.e., matrix) is required!");
            CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!");
#ifdef USE_BLAS
            _MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else
            _MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif
......
...@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
...@@ -84,29 +85,57 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
        DTYPE * ap = (DTYPE*)a->data;
        DTYPE * bp = (DTYPE*)b->data;
        DTYPE * cp = (DTYPE*)c->data;
        /* when c != a, OpenBLAS needs to copy a to c first. This operation
           slows things down, so OpenBLAS is used only when c == a */
#if defined(USE_BLAS)
        if (c == a) {
            AXPY(a->unitNum, beta, bp, 1, cp, 1);
        }
        else {
            int num = a->unitNum;
            if (num % 4 == 0) {
                for (int i = 0; i < num; i += 4) {
                    cp[i] = ap[i] + bp[i] * beta;
                    cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                    cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
                    cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
                }
            }
            else if (num % 2 == 0) {
                for (int i = 0; i < num; i += 2) {
                    cp[i] = ap[i] + bp[i] * beta;
                    cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                }
            }
            else {
                for (int i = 0; i < num; i++) {
                    cp[i] = ap[i] + bp[i] * beta;
                }
            }
        }
#else
        /* unrolling */
        int num = a->unitNum;
        if (num % 4 == 0) {
            for (int i = 0; i < num; i += 4) {
                cp[i] = ap[i] + bp[i] * beta;
                cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
                cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
                cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
            }
        }
        else if (num % 2 == 0) {
            for (int i = 0; i < num; i += 2) {
                cp[i] = ap[i] + bp[i] * beta;
                cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
            }
        }
        else {
            for (int i = 0; i < num; i++) {
                cp[i] = ap[i] + bp[i] * beta;
            }
        }
#endif
    }
    else {
        // TODO!!
        ShowNTErrors("TODO!");
......
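The c == a restriction above follows from AXPY's in-place semantics: cblas_saxpy computes y = a * x + y, so the output buffer must already hold the first operand. A minimal sketch of the two cases (plain CBLAS, not repository code):

#include <cblas.h>
#include <cstring>

/* c = a + beta * b */
void SumWithAxpy(const float * a, const float * b, float * c, int n, float beta)
{
    if (c == a) {
        /* c already holds a, so one in-place AXPY is enough: c += beta * b */
        cblas_saxpy(n, beta, b, 1, c, 1);
    }
    else {
        /* otherwise a must be copied into c first, which costs an extra pass over the data */
        std::memcpy(c, a, n * sizeof(float));
        cblas_saxpy(n, beta, b, 1, c, 1);
    }
}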
...@@ -21,6 +21,8 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "VectorBuffer.h"
#include "ReduceMax.h"
#include "ReduceMax.cuh"
...@@ -76,18 +78,75 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
    }
    blockSize = stride * strideNum;

    if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
        int vecBufLength = 32 / sizeof(DTYPE);

        if(dimRDI == 0){
            //data is contiguous in dim 0
            for(int i = 0; i < blockNum; i++){
                DTYPE * ip = (DTYPE*)input->data + blockSize * i;
                DTYPE * op = (DTYPE*)output->data + i;
                VectorBuffer vecBuf[4];
                for(int j = 0; j < 4; j++){
                    vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength);
                }
                for(int j = 1; j < strideNum / 32; j++){
                    const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
                    vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
                    vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
                    vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
                    vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
                }
                vecBuf[0] = vecBuf[0].maxData(vecBuf[1]);
                vecBuf[0] = vecBuf[0].maxData(vecBuf[2]);
                vecBuf[0] = vecBuf[0].maxData(vecBuf[3]);
                DTYPE maxN = DTYPE_MIN;
                for(int k = 0; k < vecBufLength; k++){
                    maxN = MAX(maxN, vecBuf[0][k]);
                }
                *op = maxN;
            }
        } else{
            //data is separated
            for(int i = 0; i < blockNum; i++){
                for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
                    DTYPE * ip = (DTYPE*)input->data + blockSize * i;
                    DTYPE * op = (DTYPE*)output->data + stride * i;
                    VectorBuffer vecBuf[4];
                    for(int k = 0; k < 4; k++){
                        vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE));
                    }
                    for(int k = 1; k < strideNum; k++){
                        DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
                        vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
                        vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
                        vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
                        vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
                    }
                    for(int k = 0; k < 4; k++){
                        for(int l = 0; l < vecBufLength; l++)
                            *(op + j * 32 + 8 * k + l) = vecBuf[k][l];
                    }
                }
            }
        }
    }//run vector buffer
    else{
        for(int k = 0; k < blockNum; k++){
            DTYPE * ip = (DTYPE*)input->data + blockSize * k;
            DTYPE * op = (DTYPE*)output->data + stride * k;
            for(int i = 0; i < stride; i++){
                DTYPE max = DTYPE_MIN;
                DTYPE * ipe = ip + blockSize;
                for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
                    DTYPE v = *ipb;
                    if(max < v)
                        max = v;
                }
                *(op + i) = max;
            }
        }
    }
}
......
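A reference sketch of the layout the reduction above assumes (illustrative only; the names mirror the blockNum/strideNum/stride variables, but the function itself is not repository code). The tensor is viewed as [blockNum][strideNum][stride] and the reduction runs over the middle axis, so element (b, s, t) lives at offset (b * strideNum + s) * stride + t:

#include <algorithm>
#include <limits>
#include <vector>

std::vector<float> ReduceMaxRef(const std::vector<float> & in,
                                int blockNum, int strideNum, int stride)
{
    /* reduced tensor has shape [blockNum][stride] */
    std::vector<float> out(blockNum * stride, std::numeric_limits<float>::lowest());
    for (int b = 0; b < blockNum; b++)
        for (int s = 0; s < strideNum; s++)
            for (int t = 0; t < stride; t++)
                out[b * stride + t] = std::max(out[b * stride + t],
                                               in[(b * strideNum + s) * stride + t]);
    return out;
}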
...@@ -41,19 +41,19 @@ float shflDownReduceMax(float input)
        "{"
        ".reg .f32 r0;"
        ".reg .pred p;"
        "shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
        "setp.lt.f32 p,%1,r0;"
        "@p mov.f32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
        "setp.lt.f32 p,%1,r0;"
        "@p mov.f32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
        "setp.lt.f32 p,%1,r0;"
        "@p mov.f32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
        "setp.lt.f32 p,%1,r0;"
        "@p mov.f32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
        "setp.lt.f32 p, %1, r0; "
        "@p mov.f32 %1,r0;"
        "mov.f32 %0,%1;"
...@@ -73,19 +73,19 @@ int shflDownReduceMax(int input)
        "{"
        ".reg .s32 r0;"
        ".reg .pred p;"
        "shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
        "setp.lt.s32 p,%1,r0;"
        "@p mov.s32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
        "setp.lt.s32 p,%1,r0;"
        "@p mov.s32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
        "setp.lt.s32 p,%1,r0;"
        "@p mov.s32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
        "setp.lt.s32 p,%1,r0;"
        "@p mov.s32 %1,r0;"
        "shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
        "setp.lt.s32 p, %1, r0; "
        "@p mov.s32 %1,r0;"
        "mov.s32 %0,%1;"
......
...@@ -37,15 +37,15 @@ float shflDownReduceSum(float input)
    asm volatile(
        "{"
        ".reg .f32 r0;"
        "shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
        "add.f32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
        "add.f32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
        "add.f32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
        "add.f32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
        "add.f32 %0, r0, %1;"
        "}"
        : "=f"(output) : "f"(input));
...@@ -62,15 +62,15 @@ int shflDownReduceSum(int input)
    asm volatile(
        "{"
        ".reg .s32 r0;"
        "shfl.sync.down.b32 r0, %1, 0x10, 0x1f,0xffffffff;"
        "add.s32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x8, 0xf,0xffffffff;"
        "add.s32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff;"
        "add.s32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff;"
        "add.s32 %1, r0, %1;"
        "shfl.sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff;"
        "add.s32 %0, r0, %1;"
        "}"
        : "=r"(output) : "r"(input));
......
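The PTX changes above replace the deprecated shfl.down with its synchronizing variant, which takes an explicit member mask (0xffffffff for a full warp). A minimal sketch (illustrative, not repository code) of the same butterfly reductions written with the __shfl_down_sync intrinsic available since CUDA 9 instead of inline PTX:

__device__ __forceinline__ float WarpReduceMaxSketch(float val)
{
    /* halve the active range each step: 16, 8, 4, 2, 1 */
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    return val;
}

__device__ __forceinline__ float WarpReduceSumSketch(float val)
{
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}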
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
#include "VectorBuffer.h"
namespace nts {
/* data size for each buffer */
int VectorBuffer::size()
{
return 32 / sizeof(DTYPE);
}
/* constructor */
VectorBuffer::VectorBuffer()
{
}
/*
constructor
initialize each value with val
*/
VectorBuffer::VectorBuffer(DTYPE val)
{
for (int i = 0; i != size(); i++) {
values[i] = val;
}
}
/* load data */
VectorBuffer VectorBuffer::loadu(const DTYPE* ptr, bool isExp, DTYPE power, DTYPE* bias)
{
int count = 32 / sizeof(DTYPE);
VectorBuffer vec;
if (isExp) {
if (bias == NULL) {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(*(ptr + i));
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp((*(ptr + i)) * (*(ptr + i)));
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(sqrt(*(ptr + i)));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(pow(*(ptr + i), power));
}
}
}/* end of the bias == NULL branch */
else {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(*(ptr + i) - bias[i]);
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = (DTYPE)exp(value * value);
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(sqrt(*(ptr + i) - bias[i]));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(pow(*(ptr + i) - bias[i], power));
}
}
}
}//isExp
else {
if (bias == NULL) {
if (power == (DTYPE)1.0) {
memcpy(vec.values, ptr, count * sizeof(DTYPE));
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (*(ptr + i)) * (*(ptr + i));
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)sqrt(*(ptr + i));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)pow(*(ptr + i), power);
}
}
}// if bias == NULL
else {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = *(ptr + i) - bias[i];
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = value * value;
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)sqrt(*(ptr + i) - bias[i]);
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)pow(*(ptr + i) - bias[i], power);
}
}
}
}
return vec;
}
/* overloading [] */
const DTYPE& VectorBuffer::operator[](int idx)const
{
return values[idx];
}
/* overloading + */
VectorBuffer VectorBuffer::operator+(const VectorBuffer &a)
{
for (int i = 0; i != a.size(); i++) {
this->values[i] = a[i] + this->values[i];
}
return *this;
}
/* compute the element-wise max of two buffers */
VectorBuffer VectorBuffer::maxData(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = MAX(a[i], this->values[i]);
}
return *this;
}
}/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
//#include <cstring>
#include <math.h>
#include "../../XGlobal.h"
namespace nts {
class VectorBuffer {
private:
/* data buffer for vectorized computation */
DTYPE values[32 / sizeof(DTYPE)] = { 0 };
public:
/* data size for each buffer */
static int size();
/* constructor */
VectorBuffer();
/* constructor */
VectorBuffer(DTYPE val);
/* load data */
static VectorBuffer loadu(const DTYPE* ptr, bool isExp = false, DTYPE power = (DTYPE)1.0F, DTYPE* bias = NULL);
/* overloading [] */
const DTYPE& operator[](int idx)const;
/* overloading + */
VectorBuffer operator+(const VectorBuffer &a);
/* compute the element-wise max of two buffers */
VectorBuffer maxData(const VectorBuffer &a);
};
}
\ No newline at end of file
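A short usage sketch of the buffer (illustrative only; it assumes the nts namespace plus the DTYPE, DTYPE_MIN and MAX definitions from XGlobal.h, and that data holds at least 2 * VectorBuffer::size() elements):

/* reduce two consecutive 32-byte groups to a single maximum */
DTYPE BufferedMax(const DTYPE * data)
{
    VectorBuffer buf = VectorBuffer::loadu(data);
    buf = buf.maxData(VectorBuffer::loadu(data + VectorBuffer::size()));
    DTYPE maxN = DTYPE_MIN;
    for (int i = 0; i < VectorBuffer::size(); i++)
        maxN = MAX(maxN, buf[i]);
    return maxN;
}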
...@@ -377,8 +377,8 @@ get the top-k items
template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
    __shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 512 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
    __shared__ T eachHeapMaxValue[512];
    /* optimization for k: the parameter must be more than half of k */
    int parameter = 0;
...@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
    }
    __syncthreads();

    /* to merge the heaps, use another way */
    T minData = minValue;
    int heapLimit = heap.count / 2;
    if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
...@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
            minData = heap.items[counter].value;
    }
    eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;

    //need more optimization
    if (i == 0) {
        int threadLimit = threadIdx.y * blockDim.x + min(blockDim.x, strideNum);
        CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
        int counter = threadIdx.y * blockDim.x;
        for (; counter < threadIdx.y * blockDim.x + min(k, blockDim.x); ++counter) {
            chooseHeap.Push(counter, eachHeapMaxValue[counter]);
        }
        for (; counter < threadLimit; ++counter) {
...@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
                chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
            }
        }
        int heapNum = chooseHeap.count;
        CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
        int miss = parameter;
        for (counter = 1; counter < heapNum; ++counter) {
            chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
            chooseHeap.count--;
            chooseHeap.Down(0);
            CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
            int cmpHeapLimit = 0;
            if (counter + heapLimit <= k - parameter && heapNum == k){
                cmpHeapLimit = heapLimit;
            }
            /* take the max data from the min-heap, so start searching from the leaf nodes */
...@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
    /* we run the kernel if the heaps can fit into the shared memory */
    cudaGrids[1] *= cudaBlocks[1];
    cudaBlocks[1] = 1;
    if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) + (512 * sizeof(int)) < SHARED_MEMORY_SIZE) {
        if (a->dataType == DEFAULT_DTYPE) {
            KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
                ((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
...@@ -869,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
    //delete indexA;

    int workerNum = WORKERSNUM;

    GDevs.GetCudaThread2D(a->devID,
                          workerNum, stride * blockNum, MAX_INT,
                          cudaGrids, cudaBlocks);
    if (a->dataType == DEFAULT_DTYPE) {
......
...@@ -171,7 +171,7 @@ float broadcast(float input)
    float output;
    asm(
        "{"
        "shfl.sync.idx.b32 %0,%1,0x0,0x1f,0xffffffff;"
        "}"
        :"=f"(output) : "f"(input)
    );
......