Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
666d51e9
Commit
666d51e9
authored
Sep 22, 2019
by
liyinqiao
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Merge with Yuhao branch.
parent
ddbb77b6
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
613 行增加
和
333 行删除
+613
-333
source/tensor/XBLAS.cpp
+0
-175
source/tensor/XBLAS.h
+36
-77
source/tensor/XGlobal.h
+2
-0
source/tensor/core/arithmetic/MatrixMul2D.cpp
+3
-2
source/tensor/core/arithmetic/MatrixMulBatched.cpp
+1
-7
source/tensor/core/arithmetic/Sum.cpp
+48
-19
source/tensor/core/reduce/ReduceMax.cpp
+70
-11
source/tensor/core/reduce/ReduceMax.cu
+10
-10
source/tensor/core/reduce/ReduceSum.cpp
+0
-0
source/tensor/core/reduce/ReduceSum.cu
+10
-10
source/tensor/core/reduce/VectorBuffer.cpp
+172
-0
source/tensor/core/reduce/VectorBuffer.h
+54
-0
source/tensor/core/sort/TopK.cu
+11
-9
source/tensor/function/Softmax.cu
+1
-1
source/tensor/test/TTopK.cpp
+195
-12
没有找到文件。
source/tensor/XBLAS.cpp
查看文件 @
666d51e9
...
...
@@ -26,183 +26,9 @@
*
*/
#ifdef WIN32
#include <wtypes.h>
#endif
#include <stdlib.h>
#include <stdio.h>
#include "XBLAS.h"
#include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace
nts
{
#ifdef WIN32
HINSTANCE
hBLASDll
;
#endif
/* single-precision floating matrix-matrix multiplication */
void
(
*
XBLAS_SGEMM
)(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
,
float
*
,
OPENBLAS_CONST
BLASINT
);
/* double-precision floating matrix-matrix multiplication */
void
(
*
XBLAS_DGEMM
)(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
,
double
*
,
OPENBLAS_CONST
BLASINT
);
/* single-precision floating vector-vector multiplication (rank-1) */
void
(
*
XBLAS_SGER
)(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
float
alpha
,
OPENBLAS_CONST
float
*
Y
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
float
*
,
OPENBLAS_CONST
BLASINT
);
/* double-precision floating vector-vector multiplication (rank-1) */
void
(
*
XBLAS_DGER
)(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
double
alpha
,
OPENBLAS_CONST
double
*
Y
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
double
*
,
OPENBLAS_CONST
BLASINT
);
/* set the number of threads */
void
(
*
XBLAS_SET_THREAD_NUM
)(
int
);
/* get the number of threads */
//int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
int
(
*
XBLAS_GET_CORE_NUM
)();
/* get the CPU corename */
//char * (*XBLAS_GET_CORE_NAME)();
/* get the parallelization type used by OpenBLAS */
//int (*XBLAS_GET_PARALLEL_TYPE)(void);
#if defined(USE_BLAS)
/* load some stuff for BLAS */
void
LoadBLAS
(
const
char
*
dllFileName
)
{
#ifndef CUDA_BLAS
#ifdef _WIN32
#if defined(OPENBLAS)
/* non-ascii characters are not supported yet */
wchar_t
*
fn
=
new
wchar_t
[
strlen
(
dllFileName
)
+
1
];
memset
(
fn
,
0
,
sizeof
(
wchar_t
)
*
(
strlen
(
dllFileName
)
+
1
));
for
(
int
i
=
0
;
i
<
strlen
(
dllFileName
);
i
++
)
fn
[
i
]
=
dllFileName
[
i
];
hBLASDll
=
LoadLibrary
((
LPCWSTR
)
fn
);
if
(
!
hBLASDll
){
XPRINT1
(
0
,
stderr
,
"[LoadBLAS] Error! Cannot load dll %s!
\n
"
,
dllFileName
);
exit
(
1
);
}
/* matrix-matrix multiplicatoin */
(
FARPROC
&
)
XBLAS_SGEMM
=
GetProcAddress
(
hBLASDll
,
"cblas_sgemm"
);
(
FARPROC
&
)
XBLAS_DGEMM
=
GetProcAddress
(
hBLASDll
,
"cblas_dgemm"
);
/* vector-vector multiplication */
(
FARPROC
&
)
XBLAS_SGER
=
GetProcAddress
(
hBLASDll
,
"cblas_sger"
);
(
FARPROC
&
)
XBLAS_DGER
=
GetProcAddress
(
hBLASDll
,
"cblas_dger"
);
/* multi-threading */
(
FARPROC
&
)
XBLAS_SET_THREAD_NUM
=
GetProcAddress
(
hBLASDll
,
"openblas_set_num_threads"
);
//(FARPROC&)XBLAS_SET_THREAD_NUM = GetProcAddress(hBLASDll, "goto_set_num_threads");
//(FARPROC&)XBLAS_GET_THREAD_NUM = GetProcAddress(hBLASDll, "openblas_get_num_threads");
(
FARPROC
&
)
XBLAS_GET_CORE_NUM
=
GetProcAddress
(
hBLASDll
,
"openblas_get_num_procs"
);
//(FARPROC&)XBLAS_GET_CORE_NAME = GetProcAddress(hBLASDll, "openblas_get_corename");
//(FARPROC&)XBLAS_GET_PARALLEL_TYPE = GetProcAddress(hBLASDll, "openblas_get_parallel");
delete
[]
fn
;
#endif // defined(OPENBLAS)
#if defined(MKL)
/* non-ascii characters are not supported yet */
wchar_t
*
fn
=
new
wchar_t
[
strlen
(
dllFileName
)
+
1
];
memset
(
fn
,
0
,
sizeof
(
wchar_t
)
*
(
strlen
(
dllFileName
)
+
1
));
for
(
int
i
=
0
;
i
<
strlen
(
dllFileName
);
i
++
)
fn
[
i
]
=
dllFileName
[
i
];
hBLASDll
=
LoadLibrary
((
LPCWSTR
)
fn
);
if
(
!
hBLASDll
){
XPRINT1
(
0
,
stderr
,
"[LoadBLAS] Error! Cannot load dll %s!
\n
"
,
dllFileName
);
exit
(
1
);
}
/* matrix-matrix multiplicatoin */
(
FARPROC
&
)
XBLAS_SGEMM
=
GetProcAddress
(
hBLASDll
,
"cblas_sgemm"
);
(
FARPROC
&
)
XBLAS_DGEMM
=
GetProcAddress
(
hBLASDll
,
"cblas_dgemm"
);
/* vector-vector multiplication */
(
FARPROC
&
)
XBLAS_SGER
=
GetProcAddress
(
hBLASDll
,
"cblas_sger"
);
(
FARPROC
&
)
XBLAS_DGER
=
GetProcAddress
(
hBLASDll
,
"cblas_dger"
);
/* multi-threading */
(
FARPROC
&
)
XBLAS_SET_THREAD_NUM
=
GetProcAddress
(
hBLASDll
,
"MKL_Set_Num_Threads"
);
(
FARPROC
&
)
XBLAS_GET_CORE_NUM
=
GetProcAddress
(
hBLASDll
,
"MKL_Get_Max_Threads"
);
#endif // defined(MKL)
#else // _WIN32
XBLAS_SGEMM
=
&
cblas_sgemm
;
XBLAS_DGEMM
=
&
cblas_dgemm
;
XBLAS_SGER
=
&
cblas_sger
;
XBLAS_DGER
=
&
cblas_dger
;
#if defined(OPENBLAS)
XBLAS_SET_THREAD_NUM
=
&
openblas_set_num_threads
;
XBLAS_GET_CORE_NUM
=
&
openblas_get_num_procs
;
#endif // defined(OPENBLAS)
#if defined(MKL)
XBLAS_SET_THREAD_NUM
=
&
mkl_set_num_threads
;
XBLAS_GET_CORE_NUM
=
&
mkl_get_max_num_threads
;
#endif // defined(MKL)
#endif // _WIN32
XBLAS_SET_THREAD_NUM
(
1
);
#endif // ndef(CUDA_BLAS)
}
/* unload the libs */
void
UnloadBLAS
()
{
#ifdef _WIN32
if
(
!
FreeLibrary
(
hBLASDll
)){
XPRINT
(
0
,
stderr
,
"[UnloadBLAS] Error! Cannot free the BLAS dll!
\n
"
);
exit
(
1
);
}
#else
#endif
}
#else // undefined(USE_BLAS) || undefined(OPENBLAS)
void
LoadBLAS
(
const
char
*
dllFileName
)
{
XPRINT
(
0
,
stderr
,
"[LoadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!
\n
"
);
exit
(
1
);
}
void
UnloadBLAS
()
{
XPRINT
(
0
,
stderr
,
"[UnloadBLAS] Error! No Blas lib is available. Please use OPENBLAS or MKL!
\n
"
);
exit
(
1
);
}
#endif // defined(USE_BLAS) && defined(OPENBLAS)
}
/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
source/tensor/XBLAS.h
查看文件 @
666d51e9
...
...
@@ -34,7 +34,6 @@ namespace nts{
/* some of the code below is from OpenBLAS (https://github.com/xianyi/OpenBLAS) */
//#define OPENBLAS
#define OPENBLAS_CONST const
typedef
int
BLASINT
;
...
...
@@ -46,7 +45,26 @@ typedef enum CBLAS_SIDE {CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#if defined(USE_BLAS)
#ifdef OPENBLAS
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM openblas_set_num_threads
#define XBLAS_GET_CORE_NUM openblas_get_num_procs
#endif
#ifdef MKL
#define XBLAS_SGEMM cblas_sgemm
#define XBLAS_DGEMM cblas_dgemm
#define XBLAS_SGER cblas_sger
#define XBLAS_DGER cblas_dger
#define XBLAS_SAXPY cblas_saxpy
#define XBLAS_DAXPY cblas_daxpy
#define XBLAS_SET_THREAD_NUM MKL_Set_Num_Threads
#define XBLAS_GET_CORE_NUM MKL_Get_Max_Threads
#endif
/*
single/double-precision floating matrix-matrix multiplication (rank-3)
- SGEMM (ORDER, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
...
...
@@ -62,14 +80,14 @@ where A, B and C are matrices,
LDB(=N) specifies the size of the first dimension of B as declared in the calling (sub) program,
and LDC(=N) specifies the size of the first dimension of C as declared in the calling (sub) program.
*/
extern
"C"
void
(
*
XBLAS_SGEMM
)
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
extern
"C"
void
XBLAS_SGEMM
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
,
float
*
,
OPENBLAS_CONST
BLASINT
);
/* double-precision floating matrix-matrix multiplication */
extern
"C"
void
(
*
XBLAS_DGEMM
)
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
extern
"C"
void
XBLAS_DGEMM
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
,
...
...
@@ -88,24 +106,33 @@ where X and Y are vectors with m and n elements respectively,
E.g., if we are using CblasRowMajor, the leading dimension is the number of columns of A.
*/
extern
"C"
void
(
*
XBLAS_SGER
)
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
float
alpha
,
extern
"C"
void
XBLAS_SGER
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
float
alpha
,
OPENBLAS_CONST
float
*
Y
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
float
*
,
OPENBLAS_CONST
BLASINT
,
float
*
,
OPENBLAS_CONST
BLASINT
);
/* double-precision floating vector-vector multiplication (rank-1) */
extern
"C"
void
(
*
XBLAS_DGER
)
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
double
alpha
,
extern
"C"
void
XBLAS_DGER
(
OPENBLAS_CONST
enum
CBLAS_ORDER
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
double
alpha
,
OPENBLAS_CONST
double
*
Y
,
OPENBLAS_CONST
BLASINT
,
OPENBLAS_CONST
double
*
,
OPENBLAS_CONST
BLASINT
,
double
*
,
OPENBLAS_CONST
BLASINT
);
/*
some description
*/
extern
"C"
void
XBLAS_SAXPY
(
OPENBLAS_CONST
BLASINT
n
,
OPENBLAS_CONST
float
a
,
OPENBLAS_CONST
float
*
x
,
OPENBLAS_CONST
BLASINT
incx
,
OPENBLAS_CONST
float
*
y
,
OPENBLAS_CONST
BLASINT
incy
);
/* double-precision floating sumMe function */
extern
"C"
void
XBLAS_DAXPY
(
OPENBLAS_CONST
BLASINT
n
,
OPENBLAS_CONST
double
a
,
OPENBLAS_CONST
double
*
x
,
OPENBLAS_CONST
BLASINT
incx
,
OPENBLAS_CONST
double
*
y
,
OPENBLAS_CONST
BLASINT
incy
);
/* set the number of threads */
extern
"C"
void
(
*
XBLAS_SET_THREAD_NUM
)
(
int
);
extern
"C"
void
XBLAS_SET_THREAD_NUM
(
int
);
/* get the number of threads */
//extern "C" int (*XBLAS_GET_THREAD_NUM)();
/* get the number of physical processors (cores).*/
extern
"C"
int
(
*
XBLAS_GET_CORE_NUM
)
();
extern
"C"
int
XBLAS_GET_CORE_NUM
();
/* get the CPU corename */
//extern "C" char * (*XBLAS_GET_CORE_NAME)();
...
...
@@ -113,58 +140,6 @@ extern "C" int (*XBLAS_GET_CORE_NUM)();
/* get the parallelization type used by OpenBLAS */
//extern "C" int (*XBLAS_GET_PARALLEL_TYPE)(void);
/* linux systems */
#ifndef _WIN32
/* cblas functions that are imported from the lib. See cblas.h in OpenBlas for more information */
extern
"C"
void
cblas_sgemm
(
OPENBLAS_CONST
enum
CBLAS_ORDER
Order
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
TransA
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
TransB
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
BLASINT
K
,
OPENBLAS_CONST
float
alpha
,
OPENBLAS_CONST
float
*
A
,
OPENBLAS_CONST
BLASINT
lda
,
OPENBLAS_CONST
float
*
B
,
OPENBLAS_CONST
BLASINT
ldb
,
OPENBLAS_CONST
float
beta
,
float
*
C
,
OPENBLAS_CONST
BLASINT
ldc
);
extern
"C"
void
cblas_dgemm
(
OPENBLAS_CONST
enum
CBLAS_ORDER
Order
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
TransA
,
OPENBLAS_CONST
enum
CBLAS_TRANSPOSE
TransB
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
BLASINT
K
,
OPENBLAS_CONST
double
alpha
,
OPENBLAS_CONST
double
*
A
,
OPENBLAS_CONST
BLASINT
lda
,
OPENBLAS_CONST
double
*
B
,
OPENBLAS_CONST
BLASINT
ldb
,
OPENBLAS_CONST
double
beta
,
double
*
C
,
OPENBLAS_CONST
BLASINT
ldc
);
extern
"C"
void
cblas_sger
(
OPENBLAS_CONST
enum
CBLAS_ORDER
order
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
float
alpha
,
OPENBLAS_CONST
float
*
X
,
OPENBLAS_CONST
BLASINT
incX
,
OPENBLAS_CONST
float
*
Y
,
OPENBLAS_CONST
BLASINT
incY
,
float
*
A
,
OPENBLAS_CONST
BLASINT
lda
);
extern
"C"
void
cblas_dger
(
OPENBLAS_CONST
enum
CBLAS_ORDER
order
,
OPENBLAS_CONST
BLASINT
M
,
OPENBLAS_CONST
BLASINT
N
,
OPENBLAS_CONST
double
alpha
,
OPENBLAS_CONST
double
*
X
,
OPENBLAS_CONST
BLASINT
incX
,
OPENBLAS_CONST
double
*
Y
,
OPENBLAS_CONST
BLASINT
incY
,
double
*
A
,
OPENBLAS_CONST
BLASINT
lda
);
#if defined(OPENBLAS)
/* better control of multi-threading */
extern
"C"
void
openblas_set_num_threads
(
int
num_threads
);
extern
"C"
void
goto_set_num_threads
(
int
num_threads
);
//extern "C" int openblas_get_num_threads(void);
extern
"C"
int
openblas_get_num_procs
(
void
);
//extern "C" char* openblas_get_config(void);
//extern "C" char* openblas_get_corename(void);
//extern "C" int openblas_get_parallel(void);
#endif
#endif
#if defined(MKL)
/* better control of multi-threading */
//_Mkl_Api(void,MKL_Set_Num_Threads,(int nth))
//_Mkl_Api(int,MKL_Get_Max_Threads,(void))
extern
"C"
void
MKL_Set_Num_Threads
(
int
num_threads
);
extern
"C"
int
MKL_Get_Max_Threads
();
#define mkl_set_num_threads MKL_Set_Num_Threads
#define mkl_get_max_num_threads MKL_Get_Max_Threads
//extern "C" void mkl_set_num_threads(int num_threads);
//extern "C" void omp_set_num_threads(int num_threads);
//extern "C" int mkl_get_max_num_threads();
#endif
#if defined(CUDA_BLAS)
...
...
@@ -186,24 +161,8 @@ extern void BLASMatrixMULD(int deviceID, double * a, double * b, double * c, int
#endif
#endif
#ifdef _WIN32
#include "windows.h"
extern
HINSTANCE
hBLASDll
;
#else
#endif
/* load some stuff for BLAS */
extern
void
LoadBLAS
(
const
char
*
dllFileName
);
/* unload the libs */
extern
void
UnloadBLAS
();
}
/* end of the nts (NiuTrans.Tensor) namespace */
#endif
source/tensor/XGlobal.h
查看文件 @
666d51e9
...
...
@@ -160,8 +160,10 @@ extern bool useCUDA;
/* BLAS interfaces */
#ifdef DOUBELPRICSION
#define GEMM XBLAS_DGEMM
#define AXPY XBLAS_DAXPY
#else
#define GEMM XBLAS_SGEMM
#define AXPY XBLAS_SAXPY
#endif
extern
void
InitGlobalAll
();
...
...
source/tensor/core/arithmetic/MatrixMul2D.cpp
查看文件 @
666d51e9
...
...
@@ -82,10 +82,11 @@ void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
b
->
dataType
==
DEFAULT_DTYPE
&&
c
->
dataType
==
DEFAULT_DTYPE
)
{
if
(
use
BLAS
)
#if defined(USE_
BLAS)
_MatrixMULCPU
(
a
,
transposedA
,
b
,
transposedB
,
c
,
alpha
,
beta
);
else
#
else
_MatrixMul2DParallel
(
a
,
transposedA
,
b
,
transposedB
,
c
,
alpha
,
beta
,
parallelRunner
);
#endif
}
else
{
// TODO!!
...
...
source/tensor/core/arithmetic/MatrixMulBatched.cpp
查看文件 @
666d51e9
...
...
@@ -199,10 +199,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
bi
->
data
=
(
char
*
)
b
->
data
+
i
*
bRealBlockSize
;
ci
->
data
=
(
char
*
)
c
->
data
+
i
*
cRealBlockSize
;
#ifdef USE_BLAS
if
(
useBLAS
)
_MatrixMULCPU
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
else
_MatrixMul2D
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
_MatrixMULCPU
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
#else
_MatrixMul2D
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
#endif
...
...
@@ -262,10 +259,7 @@ void _MatrixMulBatchedCPU(const TensorList * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors
((
bi
->
order
==
2
),
"2d tensor (i.e., matrix) is required!"
);
CheckNTErrors
((
ci
->
order
==
2
),
"2d tensor (i.e., matrix) is required!"
);
#ifdef USE_BLAS
if
(
useBLAS
)
_MatrixMULCPU
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
else
_MatrixMul2D
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
#else
_MatrixMul2D
(
ai
,
transposedA
,
bi
,
transposedB
,
ci
,
alpha
,
beta
);
#endif
...
...
source/tensor/core/arithmetic/Sum.cpp
查看文件 @
666d51e9
...
...
@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../../XBLAS.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
...
...
@@ -84,29 +85,57 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
DTYPE
*
ap
=
(
DTYPE
*
)
a
->
data
;
DTYPE
*
bp
=
(
DTYPE
*
)
b
->
data
;
DTYPE
*
cp
=
(
DTYPE
*
)
c
->
data
;
/* unrolling */
int
num
=
a
->
unitNum
;
if
(
num
%
4
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
cp
[
i
+
2
]
=
ap
[
i
+
2
]
+
bp
[
i
+
2
]
*
beta
;
cp
[
i
+
3
]
=
ap
[
i
+
3
]
+
bp
[
i
+
3
]
*
beta
;
}
/* when c != a, OpenBLAS needs to copy a to c first. This operation
slow down the speed, so just use OpenBLAS when c == a */
#if defined(USE_BLAS)
if
(
c
==
a
){
AXPY
(
a
->
unitNum
,
beta
,
bp
,
1
,
cp
,
1
);
}
else
{
int
num
=
a
->
unitNum
;
if
(
num
%
4
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
cp
[
i
+
2
]
=
ap
[
i
+
2
]
+
bp
[
i
+
2
]
*
beta
;
cp
[
i
+
3
]
=
ap
[
i
+
3
]
+
bp
[
i
+
3
]
*
beta
;
}
}
else
if
(
num
%
2
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
2
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
}
}
}
else
if
(
num
%
2
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
2
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
#else
/* unrolling */
int
num
=
a
->
unitNum
;
if
(
num
%
4
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
4
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
cp
[
i
+
2
]
=
ap
[
i
+
2
]
+
bp
[
i
+
2
]
*
beta
;
cp
[
i
+
3
]
=
ap
[
i
+
3
]
+
bp
[
i
+
3
]
*
beta
;
}
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
else
if
(
num
%
2
==
0
)
{
for
(
int
i
=
0
;
i
<
num
;
i
+=
2
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
cp
[
i
+
1
]
=
ap
[
i
+
1
]
+
bp
[
i
+
1
]
*
beta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
cp
[
i
]
=
ap
[
i
]
+
bp
[
i
]
*
beta
;
}
}
#endif
}
}
else
{
// TODO!!
ShowNTErrors
(
"TODO!"
);
...
...
source/tensor/core/reduce/ReduceMax.cpp
查看文件 @
666d51e9
...
...
@@ -21,6 +21,8 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "VectorBuffer.h"
#include "ReduceMax.h"
#include "ReduceMax.cuh"
...
...
@@ -76,18 +78,75 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
}
blockSize
=
stride
*
strideNum
;
for
(
int
k
=
0
;
k
<
blockNum
;
k
++
){
DTYPE
*
ip
=
(
DTYPE
*
)
input
->
data
+
blockSize
*
k
;
DTYPE
*
op
=
(
DTYPE
*
)
output
->
data
+
stride
*
k
;
for
(
int
i
=
0
;
i
<
stride
;
i
++
){
DTYPE
max
=
FLOAT_MIN
;
DTYPE
*
ipe
=
ip
+
blockSize
;
for
(
DTYPE
*
ipb
=
ip
+
i
;
ipb
<
ipe
;
ipb
+=
stride
){
DTYPE
v
=
*
ipb
;
if
(
max
<
v
)
max
=
v
;
if
(
input
->
dimSizeRDI
[
0
]
%
(
4
*
32
/
sizeof
(
DTYPE
))
==
0
&&
input
->
dimSizeRDI
[
0
]
>=
32
){
int
vecBufLength
=
32
/
sizeof
(
DTYPE
);
if
(
dimRDI
==
0
){
//data is contiguous in dim 0
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
){
DTYPE
*
ip
=
(
DTYPE
*
)
input
->
data
+
blockSize
*
i
;
DTYPE
*
op
=
(
DTYPE
*
)
output
->
data
+
i
;
VectorBuffer
vecBuf
[
4
];
for
(
int
j
=
0
;
j
<
4
;
j
++
){
vecBuf
[
j
]
=
VectorBuffer
::
loadu
((
DTYPE
*
)(
ip
)
+
j
*
vecBufLength
);
}
for
(
int
j
=
1
;
j
<
strideNum
/
32
;
j
++
){
const
DTYPE
*
ptr
=
(
DTYPE
*
)(
ip
+
j
*
vecBufLength
);
vecBuf
[
0
]
=
vecBuf
[
0
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
0
*
vecBufLength
));
vecBuf
[
1
]
=
vecBuf
[
1
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
1
*
vecBufLength
));
vecBuf
[
2
]
=
vecBuf
[
2
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
2
*
vecBufLength
));
vecBuf
[
3
]
=
vecBuf
[
3
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
3
*
vecBufLength
));
}
vecBuf
[
0
]
=
vecBuf
[
0
].
maxData
(
vecBuf
[
1
]);
vecBuf
[
0
]
=
vecBuf
[
0
].
maxData
(
vecBuf
[
2
]);
vecBuf
[
0
]
=
vecBuf
[
0
].
maxData
(
vecBuf
[
3
]);
DTYPE
maxN
=
DTYPE_MIN
;
for
(
int
k
=
0
;
k
<
vecBufLength
;
k
++
){
maxN
=
MAX
(
maxN
,
vecBuf
[
0
][
k
]);
}
*
op
=
maxN
;
}
}
else
{
//data is separated
for
(
int
i
=
0
;
i
<
blockNum
;
i
++
){
for
(
int
j
=
0
;
j
<
input
->
dimSizeRDI
[
0
]
/
32
;
j
++
){
DTYPE
*
ip
=
(
DTYPE
*
)
input
->
data
+
blockSize
*
i
;
DTYPE
*
op
=
(
DTYPE
*
)
output
->
data
+
stride
*
i
;
VectorBuffer
vecBuf
[
4
];
for
(
int
k
=
0
;
k
<
4
;
k
++
){
vecBuf
[
k
]
=
VectorBuffer
::
loadu
((
DTYPE
*
)(
ip
)
+
(
j
*
4
+
k
)
*
32
/
sizeof
(
DTYPE
));
}
for
(
int
k
=
1
;
k
<
strideNum
;
k
++
){
DTYPE
*
ptr
=
ip
+
k
*
stride
+
(
j
*
4
)
*
vecBufLength
;
vecBuf
[
0
]
=
vecBuf
[
0
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
0
*
vecBufLength
));
vecBuf
[
1
]
=
vecBuf
[
1
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
1
*
vecBufLength
));
vecBuf
[
2
]
=
vecBuf
[
2
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
2
*
vecBufLength
));
vecBuf
[
3
]
=
vecBuf
[
3
].
maxData
(
VectorBuffer
::
loadu
(
ptr
+
3
*
vecBufLength
));
}
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
l
=
0
;
l
<
vecBufLength
;
l
++
)
*
(
op
+
j
*
32
+
8
*
k
+
l
)
=
vecBuf
[
k
][
l
];
}
}
}
}
}
//run vector buffer
else
{
for
(
int
k
=
0
;
k
<
blockNum
;
k
++
){
DTYPE
*
ip
=
(
DTYPE
*
)
input
->
data
+
blockSize
*
k
;
DTYPE
*
op
=
(
DTYPE
*
)
output
->
data
+
stride
*
k
;
for
(
int
i
=
0
;
i
<
stride
;
i
++
){
DTYPE
max
=
DTYPE_MIN
;
DTYPE
*
ipe
=
ip
+
blockSize
;
for
(
DTYPE
*
ipb
=
ip
+
i
;
ipb
<
ipe
;
ipb
+=
stride
){
DTYPE
v
=
*
ipb
;
if
(
max
<
v
)
max
=
v
;
}
*
(
op
+
i
)
=
max
;
}
*
(
op
+
i
)
=
max
;
}
}
}
...
...
source/tensor/core/reduce/ReduceMax.cu
查看文件 @
666d51e9
...
...
@@ -41,19 +41,19 @@ float shflDownReduceMax(float input)
"{"
".reg .f32 r0;"
".reg .pred p;"
"shfl.
down.b32 r0, %1, 0x10, 0x1
f;"
"shfl.
sync.down.b32 r0, %1, 0x10, 0x1f,0xfffffff
f;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x8, 0x
f;"
"shfl.
sync.down.b32 r0, %1, 0x8, 0xf,0xfffffff
f;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x4, 0x7
;"
"shfl.
sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff
;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x2, 0x3
;"
"shfl.
sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff
;"
"setp.lt.f32 p,%1,r0;"
"@p mov.f32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x1, 0x1
;"
"shfl.
sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff
;"
"setp.lt.f32 p, %1, r0; "
"@p mov.f32 %1,r0;"
"mov.f32 %0,%1;"
...
...
@@ -73,19 +73,19 @@ int shflDownReduceMax(int input)
"{"
".reg .s32 r0;"
".reg .pred p;"
"shfl.
down.b32 r0, %1, 0x10, 0x1
f;"
"shfl.
sync.down.b32 r0, %1, 0x10, 0x1f,0xfffffff
f;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x8, 0x
f;"
"shfl.
sync.down.b32 r0, %1, 0x8, 0xf,0xfffffff
f;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x4, 0x7
;"
"shfl.
sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff
;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x2, 0x3
;"
"shfl.
sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff
;"
"setp.lt.s32 p,%1,r0;"
"@p mov.s32 %1,r0;"
"shfl.
down.b32 r0, %1, 0x1, 0x1
;"
"shfl.
sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff
;"
"setp.lt.s32 p, %1, r0; "
"@p mov.s32 %1,r0;"
"mov.s32 %0,%1;"
...
...
source/tensor/core/reduce/ReduceSum.cpp
查看文件 @
666d51e9
差异被折叠。
点击展开。
source/tensor/core/reduce/ReduceSum.cu
查看文件 @
666d51e9
...
...
@@ -37,15 +37,15 @@ float shflDownReduceSum(float input)
asm volatile(
"{"
".reg .f32 r0;"
"shfl.
down.b32 r0, %1, 0x10, 0x1
f;"
"shfl.
sync.down.b32 r0, %1, 0x10, 0x1f,0xfffffff
f;"
"add.f32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x8, 0x
f;"
"shfl.
sync.down.b32 r0, %1, 0x8, 0xf,0xfffffff
f;"
"add.f32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x4, 0x7
;"
"shfl.
sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff
;"
"add.f32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x2, 0x3
;"
"shfl.
sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff
;"
"add.f32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x1, 0x1
;"
"shfl.
sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff
;"
"add.f32 %0, r0, %1;"
"}"
: "=f"(output) : "f"(input));
...
...
@@ -62,15 +62,15 @@ int shflDownReduceSum(int input)
asm volatile(
"{"
".reg .s32 r0;"
"shfl.
down.b32 r0, %1, 0x10, 0x1
f;"
"shfl.
sync.down.b32 r0, %1, 0x10, 0x1f,0xfffffff
f;"
"add.s32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x8, 0x
f;"
"shfl.
sync.down.b32 r0, %1, 0x8, 0xf,0xfffffff
f;"
"add.s32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x4, 0x7
;"
"shfl.
sync.down.b32 r0, %1, 0x4, 0x7,0xffffffff
;"
"add.s32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x2, 0x3
;"
"shfl.
sync.down.b32 r0, %1, 0x2, 0x3,0xffffffff
;"
"add.s32 %1, r0, %1;"
"shfl.
down.b32 r0, %1, 0x1, 0x1
;"
"shfl.
sync.down.b32 r0, %1, 0x1, 0x1,0xffffffff
;"
"add.s32 %0, r0, %1;"
"}"
: "=r"(output) : "r"(input));
...
...
source/tensor/core/reduce/VectorBuffer.cpp
0 → 100644
查看文件 @
666d51e9
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
#include "VectorBuffer.h"
namespace
nts
{
/* data size for each buffer */
int
VectorBuffer
::
size
()
{
return
32
/
sizeof
(
DTYPE
);
}
/* constructor */
VectorBuffer
::
VectorBuffer
()
{
}
/*
constructor
initial values with val
*/
VectorBuffer
::
VectorBuffer
(
DTYPE
val
)
{
for
(
int
i
=
0
;
i
!=
size
();
i
++
)
{
values
[
i
]
=
val
;
}
}
/* load data */
VectorBuffer
VectorBuffer
::
loadu
(
const
DTYPE
*
ptr
,
bool
isExp
,
DTYPE
power
,
DTYPE
*
bias
)
{
int
count
=
32
/
sizeof
(
DTYPE
);
VectorBuffer
vec
;
if
(
isExp
)
{
if
(
bias
==
NULL
)
{
if
(
power
==
(
DTYPE
)
1.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
*
(
ptr
+
i
));
}
}
else
if
(
power
==
(
DTYPE
)
2.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
((
*
(
ptr
+
i
))
*
(
*
(
ptr
+
i
)));
}
}
else
if
(
power
==
(
DTYPE
)
0.5
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
sqrt
(
*
(
ptr
+
i
)));
}
}
else
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
pow
(
*
(
ptr
+
i
),
power
));
}
}
}
/*is bias == NULL*/
else
{
if
(
power
==
(
DTYPE
)
1.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
*
(
ptr
+
i
)
-
bias
[
i
]);
}
}
else
if
(
power
==
(
DTYPE
)
2.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
DTYPE
value
=
*
(
ptr
+
i
)
-
bias
[
i
];
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
value
*
value
);
}
}
else
if
(
power
==
(
DTYPE
)
0.5
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
sqrt
(
*
(
ptr
+
i
)
-
bias
[
i
]));
}
}
else
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
exp
(
pow
(
*
(
ptr
+
i
)
-
bias
[
i
],
power
));
}
}
}
}
//isExp
else
{
if
(
bias
==
NULL
)
{
if
(
power
==
(
DTYPE
)
1.0
)
{
memcpy
(
vec
.
values
,
ptr
,
count
*
sizeof
(
DTYPE
));
}
else
if
(
power
==
(
DTYPE
)
2.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
*
(
ptr
+
i
))
*
(
*
(
ptr
+
i
));
}
}
else
if
(
power
==
(
DTYPE
)
0.5
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
sqrt
(
*
(
ptr
+
i
));
}
}
else
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
pow
(
*
(
ptr
+
i
),
power
);
}
}
}
// if bias == NULL
else
{
if
(
power
==
(
DTYPE
)
1.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
*
(
ptr
+
i
)
-
bias
[
i
];
}
}
else
if
(
power
==
(
DTYPE
)
2.0
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
DTYPE
value
=
*
(
ptr
+
i
)
-
bias
[
i
];
vec
.
values
[
i
]
=
value
*
value
;
}
}
else
if
(
power
==
(
DTYPE
)
0.5
)
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
sqrt
(
*
(
ptr
+
i
)
-
bias
[
i
]);
}
}
else
{
for
(
int
i
=
0
;
i
!=
count
;
i
++
)
{
vec
.
values
[
i
]
=
(
DTYPE
)
pow
(
*
(
ptr
+
i
)
-
bias
[
i
],
power
);
}
}
}
}
return
vec
;
}
/* overloading [] */
const
DTYPE
&
VectorBuffer
::
operator
[](
int
idx
)
const
{
return
values
[
idx
];
}
/* overloading + */
VectorBuffer
VectorBuffer
::
operator
+
(
const
VectorBuffer
&
a
)
{
for
(
int
i
=
0
;
i
!=
a
.
size
();
i
++
)
{
this
->
values
[
i
]
=
a
[
i
]
+
this
->
values
[
i
];
}
return
*
this
;
}
/* conculte the max of two buffer */
VectorBuffer
VectorBuffer
::
maxData
(
const
VectorBuffer
&
a
)
{
for
(
int
i
=
0
;
i
!=
a
.
size
();
i
++
)
{
this
->
values
[
i
]
=
MAX
(
a
[
i
],
this
->
values
[
i
]);
}
return
*
this
;
}
}
/* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
source/tensor/core/reduce/VectorBuffer.h
0 → 100644
查看文件 @
666d51e9
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: ZHANG Yuhao (email: zhangyuhao@stu.neu.edu.cn) 2019-07-23
*/
//#include <cstring>
#include <math.h>
#include "../../XGlobal.h"
namespace
nts
{
class
VectorBuffer
{
private
:
/* buffer for concluter */
DTYPE
values
[
32
/
sizeof
(
DTYPE
)]
=
{
0
};
public
:
/* data size for each buffer */
static
int
size
();
/* constructor */
VectorBuffer
();
/* constructor */
VectorBuffer
(
DTYPE
val
);
/* load data */
static
VectorBuffer
loadu
(
const
DTYPE
*
ptr
,
bool
isExp
=
false
,
DTYPE
power
=
(
DTYPE
)
1
.
0
F
,
DTYPE
*
bias
=
NULL
);
/* overloading [] */
const
DTYPE
&
operator
[](
int
idx
)
const
;
/* overloading + */
VectorBuffer
operator
+
(
const
VectorBuffer
&
a
);
/* conculte the max of two buffer */
VectorBuffer
maxData
(
const
VectorBuffer
&
a
);
};
}
\ No newline at end of file
source/tensor/core/sort/TopK.cu
查看文件 @
666d51e9
...
...
@@ -377,8 +377,8 @@ get the top-k items
template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE -
1024
* sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[
1024
];
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE -
512
* sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[
512
];
/*optimization k size the parameter must more than half of k*/
int parameter = 0;
...
...
@@ -429,7 +429,7 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
}
__syncthreads();
/*
to merge the heap use another way
*/
/*
to merge the heap use another way
*/
T minData = minValue;
int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
...
...
@@ -438,12 +438,13 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
minData = heap.items[counter].value;
}
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimation
if (i == 0) {
int threadLimit =
(threadIdx.y + 1) * blockDim.x
;
int threadLimit =
threadIdx.y * blockDim.x + min(blockDim.x,strideNum)
;
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x +
k
; ++counter) {
for (; counter < threadIdx.y * blockDim.x +
min(k, blockDim.x)
; ++counter) {
chooseHeap.Push(counter, eachHeapMaxValue[counter]);
}
for (; counter < threadLimit; ++counter) {
...
...
@@ -451,15 +452,16 @@ void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T mi
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
}
}
int heapNum = chooseHeap.count;
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter;
for (counter = 1; counter <
k
; ++counter) {
for (counter = 1; counter <
heapNum
; ++counter) {
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--;
chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0;
if (counter + heapLimit <= k - parameter){
if (counter + heapLimit <= k - parameter
&& heapNum == k
){
cmpHeapLimit = heapLimit;
}
/* take the max data from the minHeap,so start search from the leaf node */
...
...
@@ -840,7 +842,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
/* we run the kernel if the heaps can fit into the shared memory */
cudaGrids[1] *= cudaBlocks[1];
cudaBlocks[1] = 1;
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) {
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int))
+ (512 * sizeof(int))
< SHARED_MEMORY_SIZE) {
if (a->dataType == DEFAULT_DTYPE) {
KernelTopK3<DTYPE> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >>>
((DTYPE*)a->data, stride, strideNumA, blockNum, k, DTYPE_MIN,
...
...
@@ -869,7 +871,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
//delete indexA;
int workerNum = WORKERSNUM;
GDevs.GetCudaThread2D(a->
mem->
devID,
GDevs.GetCudaThread2D(a->devID,
workerNum, stride * blockNum, MAX_INT,
cudaGrids, cudaBlocks);
if (a->dataType == DEFAULT_DTYPE) {
...
...
source/tensor/function/Softmax.cu
查看文件 @
666d51e9
...
...
@@ -171,7 +171,7 @@ float broadcast(float input)
float output;
asm(
"{"
"shfl.
idx.b32 %0,%1,0x0,0x1
f;"
"shfl.
sync.idx.b32 %0,%1,0x0,0x1f,0xfffffff
f;"
"}"
:"=f"(output) : "f"(input)
);
...
...
source/tensor/test/TTopK.cpp
查看文件 @
666d51e9
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论