Commit f98396a9 by 张裕浩

1.update OpenBLAS and MKL usage , test in Linux and Windows environment

2.update  SumMe operation
parent d0c6a99c
......@@ -26,183 +26,9 @@
*
*/
#ifdef WIN32
#include <wtypes.h>
#endif
#include <stdlib.h>
#include <stdio.h>
#include "XBLAS.h"
#include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
#ifdef WIN32
HINSTANCE hBLASDll;
#endif
/* single-precision floating matrix-matrix multiplication */
void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */
void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
double *, OPENBLAS_CONST BLASINT);
/* single-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int);
/* get the number of threads */
//int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
int (*XBLAS_GET_CORE_NUM)();
/* get the CPU corename */
//char * (*XBLAS_GET_CORE_NAME)();
/* get the parallelization type used by OpenBLAS */
//int (*XBLAS_GET_PARALLEL_TYPE)(void);
#if defined(USE_BLAS)
/* load some stuff for BLAS */
void LoadBLAS(const char * dllFileName)
{
#ifndef CUDA_BLAS
#ifdef _WIN32
#if defined(OPENBLAS)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
//(FARPROC&)XBLAS_GET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_get_num_threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "openblas_get_num_procs");
//(FARPROC&)XBLAS_GET_CORE_NAME = GetProcAddress(hBLASDll, "openblas_get_corename");
//(FARPROC&)XBLAS_GET_PARALLEL_TYPE = GetProcAddress(hBLASDll, "openblas_get_parallel");
delete[] fn;
#endif // defined(OPENBLAS)
#if defined(MKL)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
#endif // defined(OPENBLAS)
#if defined(MKL)
XBLAS_SET_THREAD_NUM = &mkl_set_num_threads;
XBLAS_GET_CORE_NUM = &mkl_get_max_num_threads;
#endif // defined(MKL)
#endif // _WIN32
XBLAS_SET_THREAD_NUM(1);
#endif // ndef(CUDA_BLAS)
}
/* unload the libs */
void UnloadBLAS()
{
#ifdef _WIN32
if(!FreeLibrary(hBLASDll)){
XPRINT(0, stderr, "[UnloadBLAS] Error! Cannot free the BLAS dll!\n");
exit(1);
}
#else
#endif
}
#else // undefined(USE_BLAS) || undefined(OPENBLAS)
void LoadBLAS(const char * dllFileName)
{
XPRINT(0, stderr, "[LoadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
void UnloadBLAS()
{
XPRINT(0, stderr, "[UnloadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
#endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -46,7 +46,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#if defined(USE_BLAS)
#ifdef OPENBLAS
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM openblas_set_num_threads
#define XBLAS_GET_CORE_NUM openblas_get_num_procs
#endif
#ifdef MKL
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM MKL_Set_Num_Threads
#define XBLAS_GET_CORE_NUM MKL_Get_Max_Threads
#endif
/*
single/double-precision floating matrix-matrix multiplication (rank-3)
- SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
......@@ -62,14 +81,14 @@ where A, B and C are matrices,
LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program,
and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program.
*/
extern "C" void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
extern "C" void XBLAS_SGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */
extern "C" void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
extern "C" void XBLAS_DGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
......@@ -88,24 +107,33 @@ where X and Y are vectors with m and n elements respectively,
E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A.
*/
extern "C" void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
extern "C" void XBLAS_SGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */
extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
extern "C" void XBLAS_DGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/*
some description
*/
extern "C" void XBLAS_SAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* double-precision floating sumMe function */
extern "C" void XBLAS_DAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */
extern "C" void (*XBLAS_SET_THREAD_NUM)(int);
extern "C" void XBLAS_SET_THREAD_NUM(int);
/* get the number of threads */
//extern "C" int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
extern "C" int (*XBLAS_GET_CORE_NUM)();
extern "C" int XBLAS_GET_CORE_NUM();
/* get the CPU corename */
//extern "C" char * (*XBLAS_GET_CORE_NAME)();
......@@ -133,19 +161,25 @@ extern "C" void cblas_sger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS
extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda);
extern "C" void cblas_saxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
extern "C" void cblas_daxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
#endif
#if defined(OPENBLAS)
/* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads);
extern "C" void goto_set_num_threads(int num_threads);
//extern "C" int openblas_get_num_threads(void);
extern "C" int openblas_get_num_procs(void);
extern "C" void cblas_saxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
extern "C" void cblas_daxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
//extern "C" char* openblas_get_config(void);
//extern "C" char* openblas_get_corename(void);
//extern "C" int openblas_get_parallel(void);
#endif
#endif
#if defined(MKL)
......@@ -153,12 +187,12 @@ extern "C" int openblas_get_num_procs(void);
/* better control of multi-threading */
//_Mkl_Api(void,MKL_Set_Num_Threads,(int nth))
//_Mkl_Api(int,MKL_Get_Max_Threads,(void))
extern "C" void MKL_Set_Num_Threads(int num_threads);
extern "C" int MKL_Get_Max_Threads();
//extern "C" void MKL_Set_Num_Threads(int num_threads);
//extern "C" int MKL_Get_Max_Threads();
#define mkl_set_num_threads MKL_Set_Num_Threads
#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void cblas_saxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
//#define mkl_set_num_threads MKL_Set_Num_Threads
//#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void mkl_set_num_threads(int num_threads);
//extern "C" void omp_set_num_threads(int num_threads);
......@@ -186,24 +220,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int
#endif
#endif
#ifdef _WIN32
#include "windows.h"
extern HINSTANCE hBLASDll;
#else
#endif
/* load some stuff for BLAS */
extern void LoadBLAS(const char * dllFileName);
/* unload the libs */
extern void UnloadBLAS();
} /* end of the nts (NiuTrans.Tensor) namespace */
#endif
......@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
......@@ -82,7 +83,34 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
/* when c != a, OpenBLAS needs to copy a to c first. This operation
slow down the speed, so just use OpenBLAS when c == a */
#if defined(USE_BLAS)
if( c == a){
AXPY(a->unitNum,beta,bp,1,cp,1);
} else{
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] + bp[i] * beta;
}
}
}
#else
/* unrolling */
int num = a->unitNum;
if (num % 4 == 0) {
......@@ -104,6 +132,7 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
cp[i] = ap[i] + bp[i] * beta;
}
}
#endif
}
else {
// TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论