Commit f98396a9 by 张裕浩

1.update OpenBLAS and MKL usage , test in Linux and Windows environment

2.update  SumMe operation
parent d0c6a99c
...@@ -26,183 +26,9 @@ ...@@ -26,183 +26,9 @@
* *
*/ */
#ifdef WIN32
#include <wtypes.h>
#endif
#include <stdlib.h>
#include <stdio.h>
#include "XBLAS.h"
#include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts{ namespace nts{
#ifdef WIN32
HINSTANCE hBLASDll;
#endif
/* single-precision floating matrix-matrix multiplication */
void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */
void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
double *, OPENBLAS_CONST BLASINT);
/* single-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */
void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT);
/* set the number of threads */
void (*XBLAS_SET_THREAD_NUM)(int);
/* get the number of threads */
//int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
int (*XBLAS_GET_CORE_NUM)();
/* get the CPU corename */
//char * (*XBLAS_GET_CORE_NAME)();
/* get the parallelization type used by OpenBLAS */
//int (*XBLAS_GET_PARALLEL_TYPE)(void);
#if defined(USE_BLAS)
/* load some stuff for BLAS */
void LoadBLAS(const char * dllFileName)
{
#ifndef CUDA_BLAS
#ifdef _WIN32
#if defined(OPENBLAS)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_set_num_threads");
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
//(FARPROC&)XBLAS_GET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_get_num_threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "openblas_get_num_procs");
//(FARPROC&)XBLAS_GET_CORE_NAME = GetProcAddress(hBLASDll, "openblas_get_corename");
//(FARPROC&)XBLAS_GET_PARALLEL_TYPE = GetProcAddress(hBLASDll, "openblas_get_parallel");
delete[] fn;
#endif // defined(OPENBLAS)
#if defined(MKL)
/* non-ascii characters are not supported yet */
wchar_t * fn = new wchar_t[strlen(dllFileName) + 1];
memset(fn, 0, sizeof(wchar_t) * (strlen(dllFileName) + 1));
for(int i = 0; i < strlen(dllFileName); i++)
fn[i] = dllFileName[i];
hBLASDll = LoadLibrary((LPCWSTR)fn);
if(!hBLASDll){
XPRINT1(0, stderr, "[LoadBLAS] Error! Cannot load dll %s!\n", dllFileName);
exit(1);
}
/* matrix-matrix multiplicatoin */
(FARPROC&)XBLAS_SGEMM = GetProcAddress(hBLASDll, "cblas_sgemm");
(FARPROC&)XBLAS_DGEMM = GetProcAddress(hBLASDll, "cblas_dgemm");
/* vector-vector multiplication */
(FARPROC&)XBLAS_SGER = GetProcAddress(hBLASDll, "cblas_sger");
(FARPROC&)XBLAS_DGER = GetProcAddress(hBLASDll, "cblas_dger");
/* multi-threading */
(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "MKL_Set_Num_Threads");
(FARPROC&)XBLAS_GET_CORE_NUM = GetProcAddress(hBLASDll, "MKL_Get_Max_Threads");
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM = &cblas_sgemm;
XBLAS_DGEMM = &cblas_dgemm;
XBLAS_SGER = &cblas_sger;
XBLAS_DGER = &cblas_dger;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM = &openblas_set_num_threads;
XBLAS_GET_CORE_NUM = &openblas_get_num_procs;
#endif // defined(OPENBLAS)
#if defined(MKL)
XBLAS_SET_THREAD_NUM = &mkl_set_num_threads;
XBLAS_GET_CORE_NUM = &mkl_get_max_num_threads;
#endif // defined(MKL)
#endif // _WIN32
XBLAS_SET_THREAD_NUM(1);
#endif // ndef(CUDA_BLAS)
}
/* unload the libs */
void UnloadBLAS()
{
#ifdef _WIN32
if(!FreeLibrary(hBLASDll)){
XPRINT(0, stderr, "[UnloadBLAS] Error! Cannot free the BLAS dll!\n");
exit(1);
}
#else
#endif
}
#else // undefined(USE_BLAS) || undefined(OPENBLAS)
void LoadBLAS(const char * dllFileName)
{
XPRINT(0, stderr, "[LoadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
void UnloadBLAS()
{
XPRINT(0, stderr, "[UnloadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!\n");
exit(1);
}
#endif // defined(USE_BLAS) && defined(OPENBLAS)
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -46,7 +46,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE; ...@@ -46,7 +46,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#if defined(USE_BLAS) #if defined(USE_BLAS)
#ifdef OPENBLAS
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM openblas_set_num_threads
#define XBLAS_GET_CORE_NUM openblas_get_num_procs
#endif
#ifdef MKL
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM MKL_Set_Num_Threads
#define XBLAS_GET_CORE_NUM MKL_Get_Max_Threads
#endif
/* /*
single/double-precision floating matrix-matrix multiplication (rank-3) single/double-precision floating matrix-matrix multiplication (rank-3)
- SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) - SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
...@@ -62,14 +81,14 @@ where A, B and C are matrices, ...@@ -62,14 +81,14 @@ where A, B and C are matrices,
LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program, LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program,
and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program. and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program.
*/ */
extern "C" void (*XBLAS_SGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_SGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating matrix-matrix multiplication */ /* double-precision floating matrix-matrix multiplication */
extern "C" void (*XBLAS_DGEMM)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE, extern "C" void XBLAS_DGEMM(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST enum CBLAS_TRANSPOSE, OPENBLAS_CONST enum CBLAS_TRANSPOSE,
OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double,
...@@ -88,24 +107,33 @@ where X and Y are vectors with m and n elements respectively, ...@@ -88,24 +107,33 @@ where X and Y are vectors with m and n elements respectively,
E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A. E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A.
*/ */
extern "C" void (*XBLAS_SGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha, extern "C" void XBLAS_SGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST float alpha,
OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST float *, OPENBLAS_CONST BLASINT,
float *, OPENBLAS_CONST BLASINT); float *, OPENBLAS_CONST BLASINT);
/* double-precision floating vector-vector multiplication (rank-1) */ /* double-precision floating vector-vector multiplication (rank-1) */
extern "C" void (*XBLAS_DGER)(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha, extern "C" void XBLAS_DGER(OPENBLAS_CONST enum CBLAS_ORDER, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT, OPENBLAS_CONST double *, OPENBLAS_CONST BLASINT,
double *, OPENBLAS_CONST BLASINT); double *, OPENBLAS_CONST BLASINT);
/*
some description
*/
extern "C" void XBLAS_SAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
/* double-precision floating sumMe function */
extern "C" void XBLAS_DAXPY(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
/* set the number of threads */ /* set the number of threads */
extern "C" void (*XBLAS_SET_THREAD_NUM)(int); extern "C" void XBLAS_SET_THREAD_NUM(int);
/* get the number of threads */ /* get the number of threads */
//extern "C" int (*XBLAS_GET_THREAD_NUM)(); //extern "C" int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/ /* get the number of physical processors (cores).*/
extern "C" int (*XBLAS_GET_CORE_NUM)(); extern "C" int XBLAS_GET_CORE_NUM();
/* get the CPU corename */ /* get the CPU corename */
//extern "C" char * (*XBLAS_GET_CORE_NAME)(); //extern "C" char * (*XBLAS_GET_CORE_NAME)();
...@@ -133,19 +161,25 @@ extern "C" void cblas_sger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS ...@@ -133,19 +161,25 @@ extern "C" void cblas_sger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONS
extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha, extern "C" void cblas_dger (OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST BLASINT M, OPENBLAS_CONST BLASINT N, OPENBLAS_CONST double alpha,
OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY, OPENBLAS_CONST double *X, OPENBLAS_CONST BLASINT incX, OPENBLAS_CONST double *Y, OPENBLAS_CONST BLASINT incY,
double *A, OPENBLAS_CONST BLASINT lda); double *A, OPENBLAS_CONST BLASINT lda);
extern "C" void cblas_saxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
extern "C" void cblas_daxpy (OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
#endif
#if defined(OPENBLAS) #if defined(OPENBLAS)
/* better control of multi-threading */ /* better control of multi-threading */
extern "C" void openblas_set_num_threads(int num_threads); extern "C" void openblas_set_num_threads(int num_threads);
extern "C" void goto_set_num_threads(int num_threads); extern "C" void goto_set_num_threads(int num_threads);
//extern "C" int openblas_get_num_threads(void); //extern "C" int openblas_get_num_threads(void);
extern "C" int openblas_get_num_procs(void); extern "C" int openblas_get_num_procs(void);
extern "C" void cblas_saxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
extern "C" void cblas_daxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST double a, OPENBLAS_CONST double *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST double *y, OPENBLAS_CONST BLASINT incy);
//extern "C" char* openblas_get_config(void); //extern "C" char* openblas_get_config(void);
//extern "C" char* openblas_get_corename(void); //extern "C" char* openblas_get_corename(void);
//extern "C" int openblas_get_parallel(void); //extern "C" int openblas_get_parallel(void);
#endif #endif
#endif
#if defined(MKL) #if defined(MKL)
...@@ -153,12 +187,12 @@ extern "C" int openblas_get_num_procs(void); ...@@ -153,12 +187,12 @@ extern "C" int openblas_get_num_procs(void);
/* better control of multi-threading */ /* better control of multi-threading */
//_Mkl_Api(void,MKL_Set_Num_Threads,(int nth)) //_Mkl_Api(void,MKL_Set_Num_Threads,(int nth))
//_Mkl_Api(int,MKL_Get_Max_Threads,(void)) //_Mkl_Api(int,MKL_Get_Max_Threads,(void))
extern "C" void MKL_Set_Num_Threads(int num_threads); //extern "C" void MKL_Set_Num_Threads(int num_threads);
extern "C" int MKL_Get_Max_Threads(); //extern "C" int MKL_Get_Max_Threads();
#define mkl_set_num_threads MKL_Set_Num_Threads //extern "C" void cblas_saxpy(OPENBLAS_CONST BLASINT n, OPENBLAS_CONST float a, OPENBLAS_CONST float *x, OPENBLAS_CONST BLASINT incx, OPENBLAS_CONST float *y, OPENBLAS_CONST BLASINT incy);
#define mkl_get_max_num_threads MKL_Get_Max_Threads //#define mkl_set_num_threads MKL_Set_Num_Threads
//#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void mkl_set_num_threads(int num_threads); //extern "C" void mkl_set_num_threads(int num_threads);
//extern "C" void omp_set_num_threads(int num_threads); //extern "C" void omp_set_num_threads(int num_threads);
...@@ -186,24 +220,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int ...@@ -186,24 +220,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int
#endif #endif
#endif
#ifdef _WIN32
#include "windows.h"
extern HINSTANCE hBLASDll;
#else
#endif #endif
/* load some stuff for BLAS */
extern void LoadBLAS(const char * dllFileName);
/* unload the libs */
extern void UnloadBLAS();
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
#endif #endif
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XUtility.h" #include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h" #include "../movement/CopyValues.h"
#include "Sum.h" #include "Sum.h"
#include "Sum.cuh" #include "Sum.cuh"
...@@ -82,7 +83,34 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) ...@@ -82,7 +83,34 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
DTYPE * ap = (DTYPE*)a->data; DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data; DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data; DTYPE * cp = (DTYPE*)c->data;
/* when c != a, OpenBLAS needs to copy a to c first. This operation
slow down the speed, so just use OpenBLAS when c == a */
#if defined(USE_BLAS)
if( c == a){
AXPY(a->unitNum,beta,bp,1,cp,1);
} else{
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] + bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] + bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] + bp[i] * beta;
cp[i + 1] = ap[i + 1] + bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] + bp[i] * beta;
}
}
}
#else
/* unrolling */ /* unrolling */
int num = a->unitNum; int num = a->unitNum;
if (num % 4 == 0) { if (num % 4 == 0) {
...@@ -104,6 +132,7 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) ...@@ -104,6 +132,7 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
cp[i] = ap[i] + bp[i] * beta; cp[i] = ap[i] + bp[i] * beta;
} }
} }
#endif
} }
else { else {
// TODO!! // TODO!!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论