Commit 4641030e by Tianzhi

Add headers to the .cpp files and fix the MAX_LEN bug in XLink.h

parent ba8bc234
......@@ -65,7 +65,7 @@ ifeq ($(USE_MKL), 1)
$(MKL_LIB_DIR)/libmkl_core.a \
$(MKL_LIB_DIR)/libmkl_intel_thread.a \
$(INTEL_ROOT)/lib/intel64/libiomp5.a
DYNAMIC_DEPLIB += -liomp5 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core
#DYNAMIC_DEPLIB += -liomp5 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core
endif
ifeq ($(USE_OPENBLAS), 1)
STATIC_DEPLIB += $(OPENBLAS_LIB_DIR)/libopenblas.a
......
......@@ -67,6 +67,13 @@ void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OP
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* function pointers to single-precision vector routines of the BLAS library;
   LoadBLAS binds each pointer to the cblas_* routine of the same name */
float (*XBLAS_SASUM)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): the standard CBLAS interface declares cblas_isamax as returning an index
   (CBLAS_INDEX), not a float -- confirm this return type against the linked BLAS headers */
float (*XBLAS_ISAMAX)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
float (*XBLAS_SNRM2)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): x is const-qualified here although standard cblas_sscal scales x in place -- confirm */
void (*XBLAS_SSCAL)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): y is const-qualified here although standard cblas_scopy writes the copy into y -- confirm */
void (*XBLAS_SCOPY)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
/* NOTE(review): y is const-qualified here although standard cblas_saxpy accumulates a*x into y -- confirm */
void (*XBLAS_SAXPY)(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int);
......@@ -115,6 +122,13 @@ void LoadBLAS(const char * dllFileName)
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
(FARPROC&)XBLAS_SASUM = GetProcAddress(hBLASDll, "cblas_sasum");
(FARPROC&)XBLAS_ISAMAX = GetProcAddress(hBLASDll, "cblas_isamax");
(FARPROC&)XBLAS_SNRM2 = GetProcAddress(hBLASDll, "cblas_snrm2");
(FARPROC&)XBLAS_SSCAL = GetProcAddress(hBLASDll, "cblas_sscal");
(FARPROC&)XBLAS_SCOPY = GetProcAddress(hBLASDll, "cblas_scopy");
(FARPROC&)XBLAS_SAXPY = GetProcAddress(hBLASDll, "cblas_saxpy");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
......@@ -148,17 +162,31 @@ void LoadBLAS(const char * dllFileName)
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
(FARPROC&)XBLAS_SASUM = GetProcAddress(hBLASDll, "cblas_sasum");
(FARPROC&)XBLAS_ISAMAX = GetProcAddress(hBLASDll, "cblas_isamax");
(FARPROC&)XBLAS_SNRM2 = GetProcAddress(hBLASDll, "cblas_snrm2");
(FARPROC&)XBLAS_SSCAL = GetProcAddress(hBLASDll, "cblas_sscal");
(FARPROC&)XBLAS_SCOPY = GetProcAddress(hBLASDll, "cblas_scopy");
(FARPROC&)XBLAS_SAXPY = GetProcAddress(hBLASDll, "cblas_saxpy");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger;
XBLAS_SASUM = &cblas_sasum;
XBLAS_ISAMAX = &cblas_isamax;
XBLAS_SNRM2 = &cblas_snrm2;
XBLAS_SSCAL = &cblas_sscal;
XBLAS_SCOPY = &cblas_scopy;
XBLAS_SAXPY = &cblas_saxpy;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
......@@ -205,4 +233,4 @@ void UnloadBLAS()
#endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -97,6 +97,15 @@ extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BL
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* extern declarations of the function pointers to single-precision vector routines
   of the BLAS library; the signatures here must match the pointer definitions exactly */
extern "C" float (*XBLAS_SASUM)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): the standard CBLAS interface declares cblas_isamax as returning an index
   (CBLAS_INDEX), not a float -- confirm this return type against the linked BLAS headers */
extern "C" float (*XBLAS_ISAMAX)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): no definition or binding of XBLAS_ISAMIN is visible in this change --
   confirm one exists before the pointer is used */
extern "C" float (*XBLAS_ISAMIN)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float (*XBLAS_SNRM2)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): x is const-qualified here although standard cblas_sscal scales x in place -- confirm */
extern "C" void (*XBLAS_SSCAL)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): y is const-qualified here although standard cblas_scopy writes the copy into y -- confirm */
extern "C" void (*XBLAS_SCOPY)(OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
/* NOTE(review): y is const-qualified here although standard cblas_saxpy accumulates a*x into y -- confirm */
extern "C" void (*XBLAS_SAXPY)(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */
extern "C" void (*XBLAS_SET_THREAD_NUM)(int);
......@@ -134,6 +143,14 @@ extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda);
/* hand-written prototypes of the single-precision vector routines exported by the
   linked BLAS library; they must agree with the library's own cblas header */
extern "C" float cblas_sasum (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): standard CBLAS declares cblas_isamax as returning an index (CBLAS_INDEX),
   not a float; reading the result as float would be wrong -- confirm against the library header */
extern "C" float cblas_isamax (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): cblas_isamin is not part of the reference CBLAS interface; presumably a
   vendor extension that also returns an index -- confirm availability and return type */
extern "C" float cblas_isamin (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
extern "C" float cblas_snrm2 (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): x is const-qualified here although standard cblas_sscal scales x in place -- confirm */
extern "C" void cblas_sscal (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float a,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx);
/* NOTE(review): y is const-qualified here although standard cblas_scopy writes the copy into y -- confirm */
extern "C" void cblas_scopy (OPENBLAS_CONST BLASINT n,OPENBLAS_CONST float *x,OPENBLAS_CONST BLASINT incx,OPENBLAS_CONST float *y,OPENBLAS_CONST BLASINT incy);
/* NOTE(review): y is const-qualified here although standard cblas_saxpy accumulates a*x into y -- confirm */
extern "C" void cblas_saxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
#if defined(OPENBLAS)
/* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads);
......
......@@ -33,7 +33,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* cross reference */
struct XTensor;
#define MAX_OP_NAME_LENGTH 16
#define MAX_OP_NAME_LENGTH 32
#define PARAM_UNTI_SIZE 64
/*
......
......@@ -22,6 +22,8 @@
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "MatrixMulBatched.h"
#include "XTensorBLAS.h"
#include "MatrixMul2D.h"
......
......@@ -22,6 +22,8 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
......
......@@ -49,12 +49,12 @@ void _MatrixMULCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
#if defined(USE_BLAS)
int an = a->dimSize[0];
int am = a->dimSize[1];
int am = a->dimSize[1];
int bn = b->dimSize[0];
int bm = b->dimSize[1];
int cn = c->dimSize[0];
int cm = c->dimSize[1];
printf("4\n");
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, cn, cm, am, alpha, (DTYPE*)a->data, am, (DTYPE*)b->data, bm, beta, (DTYPE*)c->data, cm);
else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
......
......@@ -22,6 +22,9 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "../movement/CopyValues.h"
#include "ScaleAndShift.h"
#include "ScaleAndShift.cuh"
......
......@@ -21,6 +21,8 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
#include "ReduceMax.h"
#include "ReduceMax.cuh"
......@@ -77,12 +79,12 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){
if(useBLAS){
*(op + i) = *(ip + i + cblas_isamax(strideNum, ip + i, stride));
} else{
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){
if(useBLAS){
*(op + i) = cblas_isamax(strideNum, ip + i, stride);
} else{
DTYPE max = FLOAT_MIN;
DTYPE * ipe = ip + blockSize;
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
......
......@@ -23,6 +23,8 @@
#include "ReduceSum.h"
#include "ReduceSum.cuh"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "../arithmetic/XTensorBLAS.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论