Commit 4641030e by Tianzhi

Add the missing headers to the .cpp files and fix the MAX_LEN bug in XLink.h

parent ba8bc234
...@@ -65,7 +65,7 @@ ifeq ($(USE_MKL), 1) ...@@ -65,7 +65,7 @@ ifeq ($(USE_MKL), 1)
$(MKL_LIB_DIR)/libmkl_core.a \ $(MKL_LIB_DIR)/libmkl_core.a \
$(MKL_LIB_DIR)/libmkl_intel_thread.a \ $(MKL_LIB_DIR)/libmkl_intel_thread.a \
$(INTEL_ROOT)/lib/intel64/libiomp5.a $(INTEL_ROOT)/lib/intel64/libiomp5.a
DYNAMIC_DEPLIB += -liomp5 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core #DYNAMIC_DEPLIB += -liomp5 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core
endif endif
ifeq ($(USE_OPENBLAS), 1) ifeq ($(USE_OPENBLAS), 1)
STATIC_DEPLIB += $(OPENBLAS_LIB_DIR)/libopenblas.a STATIC_DEPLIB += $(OPENBLAS_LIB_DIR)/libopenblas.a
......
...@@ -67,6 +67,13 @@ void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OP ...@@ -67,6 +67,13 @@ void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OP
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT); double *, OPENBLAS_CONST BLASINT);
float (*XBLAS_SASUM)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
float (*XBLAS_ISAMAX)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
float (*XBLAS_SNRM2)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
void (*XBLAS_SSCAL)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
void (*XBLAS_SCOPY)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
void (*XBLAS_SAXPY)(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */ /* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int); void (*XBLAS_SET_THREAD_NUM)(int);
...@@ -115,6 +122,13 @@ void LoadBLAS(const char * dllFileName) ...@@ -115,6 +122,13 @@ void LoadBLAS(const char * dllFileName)
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger"); (FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger"); (FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
(FARPROC&)XBLAS_SASUM = GetProcAddress(hBLASDll, "cblas_sasum");
(FARPROC&)XBLAS_ISAMAX = GetProcAddress(hBLASDll, "cblas_isamax");
(FARPROC&)XBLAS_SNRM2 = GetProcAddress(hBLASDll, "cblas_snrm2");
(FARPROC&)XBLAS_SSCAL = GetProcAddress(hBLASDll, "cblas_sscal");
(FARPROC&)XBLAS_SCOPY = GetProcAddress(hBLASDll, "cblas_scopy");
(FARPROC&)XBLAS_SAXPY = GetProcAddress(hBLASDll, "cblas_saxpy");
/* multi-threading */ /* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads"); (FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads"); //(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
...@@ -148,17 +162,31 @@ void LoadBLAS(const char * dllFileName) ...@@ -148,17 +162,31 @@ void LoadBLAS(const char * dllFileName)
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger"); (FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger"); (FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
(FARPROC&)XBLAS_SASUM = GetProcAddress(hBLASDll, "cblas_sasum");
(FARPROC&)XBLAS_ISAMAX = GetProcAddress(hBLASDll, "cblas_isamax");
(FARPROC&)XBLAS_SNRM2 = GetProcAddress(hBLASDll, "cblas_snrm2");
(FARPROC&)XBLAS_SSCAL = GetProcAddress(hBLASDll, "cblas_sscal");
(FARPROC&)XBLAS_SCOPY = GetProcAddress(hBLASDll, "cblas_scopy");
(FARPROC&)XBLAS_SAXPY = GetProcAddress(hBLASDll, "cblas_saxpy");
/* multi-threading */ /* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads"); (FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads"); (FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL) #endif // defined(MKL)
#else // _WIN32 #else // _WIN32
XBLAS_SGEMM = &cblas_sgemm; XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm; XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger; XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger; XBLAS_DGER = &cblas_dger;
XBLAS_SASUM = &cblas_sasum;
XBLAS_ISAMAX = &cblas_isamax;
XBLAS_SNRM2 = &cblas_snrm2;
XBLAS_SSCAL = &cblas_sscal;
XBLAS_SCOPY = &cblas_scopy;
XBLAS_SAXPY = &cblas_saxpy;
#if defined(OPENBLAS) #if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads; XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs; XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
...@@ -205,4 +233,4 @@ void UnloadBLAS() ...@@ -205,4 +233,4 @@ void UnloadBLAS()
#endif // defined(USE_BLAS) && defined(OPENBLAS) #endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -97,6 +97,15 @@ extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BL ...@@ -97,6 +97,15 @@ extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BL
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT); double *, OPENBLAS_CONST BLASINT);
extern "C" float (*XBLAS_SASUM)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float (*XBLAS_ISAMAX)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float (*XBLAS_ISAMIN)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float (*XBLAS_SNRM2)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" void (*XBLAS_SSCAL)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" void (*XBLAS_SCOPY)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
extern "C" void (*XBLAS_SAXPY)(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */ /* set the number of threads */
extern "C" void (*XBLAS_SET_THREAD_NUM)(int); extern "C" void (*XBLAS_SET_THREAD_NUM)(int);
...@@ -134,6 +143,14 @@ extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS ...@@ -134,6 +143,14 @@ extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY, OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda); double *A, OPENBLAS_CONST BLASINT lda);
extern "C" float cblas_sasum (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float cblas_isamax (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float cblas_isamin (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float cblas_snrm2 (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" void cblas_sscal (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" void cblas_scopy (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
extern "C" void cblas_saxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
#if defined(OPENBLAS) #if defined(OPENBLAS)
/* better control of multi-threading */ /* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads); extern "C" void openblas_set_num_threads(int num_threads);
......
...@@ -33,7 +33,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -33,7 +33,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* cross reference */ /* cross reference */
struct XTensor; struct XTensor;
#define MAX_OP_NAME_LENGTH 16 #define MAX_OP_NAME_LENGTH 32
#define PARAM_UNTI_SIZE 64 #define PARAM_UNTI_SIZE 64
/* /*
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XDevice.h" #include "../../XDevice.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "MatrixMulBatched.h" #include "MatrixMulBatched.h"
#include "XTensorBLAS.h" #include "XTensorBLAS.h"
#include "MatrixMul2D.h" #include "MatrixMul2D.h"
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "../movement/CopyValues.h" #include "../movement/CopyValues.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
......
...@@ -49,12 +49,12 @@ void _MatrixMULCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, ...@@ -49,12 +49,12 @@ void _MatrixMULCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
#if defined(USE_BLAS) #if defined(USE_BLAS)
int an = a->dimSize[0]; int an = a->dimSize[0];
int am = a->dimSize[1]; int am = a->dimSize[1];
int bn = b->dimSize[0]; int bn = b->dimSize[0];
int bm = b->dimSize[1]; int bm = b->dimSize[1];
int cn = c->dimSize[0]; int cn = c->dimSize[0];
int cm = c->dimSize[1]; int cm = c->dimSize[1];
printf("4\n");
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS) if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, cn, cm, am, alpha, (DTYPE*)a->data, am, (DTYPE*)b->data, bm, beta, (DTYPE*)c->data, cm); GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, cn, cm, am, alpha, (DTYPE*)a->data, am, (DTYPE*)b->data, bm, beta, (DTYPE*)c->data, cm);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS) else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
......
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "../movement/CopyValues.h"
#include "ScaleAndShift.h" #include "ScaleAndShift.h"
#include "ScaleAndShift.cuh" #include "ScaleAndShift.cuh"
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMax.cuh" #include "ReduceMax.cuh"
...@@ -77,12 +79,12 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -77,12 +79,12 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
blockSize = stride * strideNum; blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){ for(int k = 0; k < blockNum; k++){
if(useBLAS){ DTYPE * ip = (DTYPE*)input->data + blockSize * k;
*(op + i) = *(ip + i + cblas_isamax(strideNum, ip + i, stride)); DTYPE * op = (DTYPE*)output->data + stride * k;
} else{ for(int i = 0; i < stride; i++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k; if(useBLAS){
DTYPE * op = (DTYPE*)output->data + stride * k; *(op + i) = cblas_isamax(strideNum, ip + i, stride);
for(int i = 0; i < stride; i++){ } else{
DTYPE max = FLOAT_MIN; DTYPE max = FLOAT_MIN;
DTYPE * ipe = ip + blockSize; DTYPE * ipe = ip + blockSize;
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "ReduceSum.h" #include "ReduceSum.h"
#include "ReduceSum.cuh" #include "ReduceSum.cuh"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论