mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-05-31 00:45:48 +08:00
Add FP16 support for RISCV
This commit is contained in:
@@ -1889,6 +1889,7 @@ export TARGET_CORE
|
||||
export NO_AVX512
|
||||
export NO_AVX2
|
||||
export BUILD_BFLOAT16
|
||||
export BUILD_HFLOAT16
|
||||
export NO_LSX
|
||||
export NO_LASX
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
# Profiled variants of the half-precision (FP16) BLAS objects.
# The name must be SHBLASOBJS_P: the target-specific CFLAGS rules in this
# file are attached to $(SHBLASOBJS_P), and an undefined name there would
# silently expand to nothing and leave the profiled objects without their flags.
SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
# Misspelled legacy name, kept as an alias because BLASOBJS_P still refers to it.
SHBLASPBJS_P = $(SHBLASOBJS_P)
|
||||
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
|
||||
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
|
||||
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
|
||||
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
|
||||
|
||||
ifdef EXPRECISION
|
||||
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
|
||||
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
|
||||
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
|
||||
endif
|
||||
|
||||
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
|
||||
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
|
||||
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
|
||||
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
|
||||
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
|
||||
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
|
||||
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
|
||||
|
||||
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
|
||||
4
cblas.h
4
cblas.h
@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
|
||||
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
|
||||
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
|
||||
|
||||
/*** FLOAT16 extensions */
|
||||
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
|
||||
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif /* __cplusplus */
|
||||
|
||||
@@ -640,6 +640,9 @@ endif()
|
||||
if (BUILD_BFLOAT16)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
|
||||
endif()
|
||||
if (BUILD_HFLOAT16)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
|
||||
endif()
|
||||
if(NOT MSVC)
|
||||
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
|
||||
endif()
|
||||
@@ -647,14 +650,14 @@ endif()
|
||||
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
|
||||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
|
||||
|
||||
if ("${F_COMPILER}" STREQUAL "FLANG")
|
||||
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
|
||||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
|
||||
endif ()
|
||||
endif ()
|
||||
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
|
||||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
|
||||
endif ()
|
||||
if ("${F_COMPILER}" STREQUAL "FLANG")
|
||||
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
|
||||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
|
||||
endif ()
|
||||
endif ()
|
||||
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
|
||||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
|
||||
|
||||
9
common.h
9
common.h
@@ -266,6 +266,11 @@ typedef uint16_t bfloat16;
|
||||
#define BFLOAT16CONVERSION 1
|
||||
#endif
|
||||
|
||||
/* Fallback storage type for half-precision values: a 16-bit unsigned integer.
 * Note that #ifndef only sees preprocessor macros, not typedef names, so this
 * guard skips the typedef only when a macro called hfloat16 was defined
 * beforehand; it does not protect against an earlier typedef of the same name. */
#ifndef hfloat16
|
||||
#include <stdint.h>
|
||||
typedef uint16_t hfloat16;
|
||||
#endif
|
||||
|
||||
#ifdef USE64BITINT
|
||||
typedef BLASLONG blasint;
|
||||
#if defined(OS_WINDOWS) && defined(__64BIT__)
|
||||
@@ -313,8 +318,8 @@ typedef int blasint;
|
||||
#define SIZE 2
|
||||
#define BASE_SHIFT 1
|
||||
#define ZBASE_SHIFT 2
|
||||
#elif defined(FLOAT16)
|
||||
#define IFLOAT float16
|
||||
#elif defined(HFLOAT16)
|
||||
#define IFLOAT hfloat16
|
||||
#define XFLOAT IFLOAT
|
||||
#define FLOAT float
|
||||
#define SIZE 2
|
||||
|
||||
@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
|
||||
|
||||
/* Level 3 routines */
|
||||
|
||||
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
|
||||
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
|
||||
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
|
||||
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
|
||||
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
|
||||
|
||||
@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
|
||||
|
||||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
|
||||
|
||||
|
||||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
|
||||
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
|
||||
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
|
||||
#endif
|
||||
|
||||
int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
|
||||
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
|
||||
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
|
||||
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
|
||||
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
|
||||
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
|
||||
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
|
||||
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
|
||||
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
|
||||
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
|
||||
|
||||
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
|
||||
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
|
||||
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
|
||||
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
|
||||
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
|
||||
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
|
||||
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
|
||||
|
||||
int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
|
||||
int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
|
||||
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
|
||||
#endif
|
||||
|
||||
int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
|
||||
|
||||
int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
|
||||
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
|
||||
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
|
||||
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
|
||||
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
|
||||
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
|
||||
|
||||
#ifdef __CUDACC__
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
#ifndef COMMON_MACRO
|
||||
#define COMMON_MACRO
|
||||
|
||||
#include "common_sh.h"
|
||||
#include "common_sb.h"
|
||||
#include "common_s.h"
|
||||
#include "common_d.h"
|
||||
@@ -656,6 +657,50 @@
|
||||
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
|
||||
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
|
||||
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
|
||||
#elif defined(HFLOAT16)
|
||||
#define GEMM_BETA SHGEMM_BETA
|
||||
#define GEMM_KERNEL_N SHGEMM_KERNEL
|
||||
#define GEMM_KERNEL_L SHGEMM_KERNEL
|
||||
#define GEMM_KERNEL_R SHGEMM_KERNEL
|
||||
#define GEMM_KERNEL_B SHGEMM_KERNEL
|
||||
#define GEMM_NN SHGEMM_NN
|
||||
#define GEMM_CN SHGEMM_TN
|
||||
#define GEMM_TN SHGEMM_TN
|
||||
#define GEMM_NC SHGEMM_NT
|
||||
#define GEMM_NT SHGEMM_NT
|
||||
#define GEMM_CC SHGEMM_TT
|
||||
#define GEMM_CT SHGEMM_TT
|
||||
#define GEMM_TC SHGEMM_TT
|
||||
#define GEMM_TT SHGEMM_TT
|
||||
#define GEMM_NR SHGEMM_NN
|
||||
#define GEMM_TR SHGEMM_TN
|
||||
#define GEMM_CR SHGEMM_TN
|
||||
#define GEMM_RN SHGEMM_NN
|
||||
#define GEMM_RT SHGEMM_NT
|
||||
#define GEMM_RC SHGEMM_NT
|
||||
#define GEMM_RR SHGEMM_NN
|
||||
#define GEMM_ONCOPY SHGEMM_ONCOPY
|
||||
#define GEMM_OTCOPY SHGEMM_OTCOPY
|
||||
#define GEMM_INCOPY SHGEMM_INCOPY
|
||||
#define GEMM_ITCOPY SHGEMM_ITCOPY
|
||||
|
||||
#define GEMM_THREAD_NN SHGEMM_THREAD_NN
|
||||
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
|
||||
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
|
||||
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
|
||||
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
|
||||
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
|
||||
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
|
||||
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
|
||||
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
|
||||
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
|
||||
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
|
||||
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
|
||||
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
|
||||
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
|
||||
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
|
||||
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
|
||||
|
||||
|
||||
#elif defined(BFLOAT16)
|
||||
|
||||
|
||||
@@ -48,6 +48,21 @@ typedef struct {
|
||||
int dtb_entries;
|
||||
int switch_ratio;
|
||||
int offsetA, offsetB, align;
|
||||
#if BUILD_HFLOAT16 == 1
|
||||
int shgemm_p, shgemm_q, shgemm_r;
|
||||
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
|
||||
|
||||
int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
|
||||
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
|
||||
|
||||
int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#if BUILD_BFLOAT16 == 1
|
||||
int sbgemm_p, sbgemm_q, sbgemm_r;
|
||||
@@ -64,10 +79,10 @@ typedef struct {
|
||||
float (*sbamin_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*sbmax_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*sbmin_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
|
||||
|
||||
float (*sbnrm2_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*sbasum_k) (BLASLONG, float *, BLASLONG);
|
||||
@@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
|
||||
#endif
|
||||
|
||||
#if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1)
|
||||
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
|
||||
#endif
|
||||
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
|
||||
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*snrm2_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*sasum_k) (BLASLONG, float *, BLASLONG);
|
||||
#endif
|
||||
@@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
|
||||
double (*damin_k) (BLASLONG, double *, BLASLONG);
|
||||
double (*dmax_k) (BLASLONG, double *, BLASLONG);
|
||||
double (*dmin_k) (BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
|
||||
|
||||
double (*dnrm2_k) (BLASLONG, double *, BLASLONG);
|
||||
double (*dasum_k) (BLASLONG, double *, BLASLONG);
|
||||
@@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
|
||||
xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
|
||||
xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
@@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
float (*camax_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*camin_k) (BLASLONG, float *, BLASLONG);
|
||||
|
||||
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
|
||||
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
|
||||
|
||||
float (*cnrm2_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*casum_k) (BLASLONG, float *, BLASLONG);
|
||||
@@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
|
||||
|
||||
double (*zamax_k) (BLASLONG, double *, BLASLONG);
|
||||
double (*zamin_k) (BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
|
||||
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
|
||||
|
||||
double (*znrm2_k) (BLASLONG, double *, BLASLONG);
|
||||
double (*zasum_k) (BLASLONG, double *, BLASLONG);
|
||||
@@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
|
||||
|
||||
xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
|
||||
|
||||
xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG);
|
||||
@@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas;
|
||||
|
||||
#define HAVE_EX_L2 gotoblas -> exclusive_cache
|
||||
|
||||
#if (BUILD_HFLOAT16==1)
|
||||
#define SHGEMM_P gotoblas -> shgemm_p
|
||||
#define SHGEMM_Q gotoblas -> shgemm_q
|
||||
#define SHGEMM_R gotoblas -> shgemm_r
|
||||
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
|
||||
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
|
||||
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
|
||||
#endif
|
||||
|
||||
#if (BUILD_BFLOAT16==1)
|
||||
#define SBGEMM_P gotoblas -> sbgemm_p
|
||||
#define SBGEMM_Q gotoblas -> sbgemm_q
|
||||
@@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas;
|
||||
#define HAVE_EX_L2 0
|
||||
#endif
|
||||
|
||||
#if (BUILD_HFLOAT16 == 1)
|
||||
#define SHGEMM_P SHGEMM_DEFAULT_P
|
||||
#define SHGEMM_Q SHGEMM_DEFAULT_Q
|
||||
#define SHGEMM_R SHGEMM_DEFAULT_R
|
||||
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
|
||||
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
|
||||
#ifdef SHGEMM_DEFAULT_UNROLL_MN
|
||||
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
|
||||
#else
|
||||
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if (BUILD_BFLOAT16 == 1)
|
||||
#define SBGEMM_P SBGEMM_DEFAULT_P
|
||||
#define SBGEMM_Q SBGEMM_DEFAULT_Q
|
||||
@@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas;
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef COMPLEX
|
||||
@@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas;
|
||||
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R
|
||||
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
|
||||
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
|
||||
#elif defined(HFLOAT16)
|
||||
#define GEMM_P SHGEMM_P
|
||||
#define GEMM_Q SHGEMM_Q
|
||||
#define GEMM_R SHGEMM_R
|
||||
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
|
||||
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
|
||||
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
|
||||
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
|
||||
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
|
||||
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
|
||||
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
|
||||
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
|
||||
#elif defined(BFLOAT16)
|
||||
#define GEMM_P SBGEMM_P
|
||||
#define GEMM_Q SBGEMM_Q
|
||||
|
||||
72
common_sh.h
Normal file
72
common_sh.h
Normal file
@@ -0,0 +1,72 @@
|
||||
#ifndef COMMON_SH_H
|
||||
#define COMMON_SH_H
|
||||
|
||||
#ifndef DYNAMIC_ARCH
|
||||
|
||||
#define SHGEMM_ONCOPY shgemm_oncopy
|
||||
#define SHGEMM_OTCOPY shgemm_otcopy
|
||||
|
||||
/* With equal M and N unroll factors the inner-copy macros are mapped onto the
 * outer-copy routines. The comparison must use the FP16 (SHGEMM) unroll
 * factors, not the single-precision ones, which may differ on a given target. */
#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N
|
||||
#define SHGEMM_INCOPY shgemm_oncopy
|
||||
#define SHGEMM_ITCOPY shgemm_otcopy
|
||||
#else
|
||||
#define SHGEMM_INCOPY shgemm_incopy
|
||||
#define SHGEMM_ITCOPY shgemm_itcopy
|
||||
#endif
|
||||
|
||||
#define SHGEMM_BETA shgemm_beta
|
||||
#define SHGEMM_KERNEL shgemm_kernel
|
||||
|
||||
|
||||
#else // #DYNAMIC_ARCH
|
||||
|
||||
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
|
||||
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
|
||||
/* DYNAMIC_ARCH variant: with equal M and N unroll factors the inner-copy macros
 * are mapped onto the outer-copy function pointers. The comparison must use
 * the FP16 (SHGEMM) unroll factors rather than the single-precision ones. */
#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N
|
||||
#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy
|
||||
#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy
|
||||
#else
|
||||
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
|
||||
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
|
||||
#endif
|
||||
|
||||
#define SHGEMM_BETA gotoblas -> shgemm_beta
|
||||
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
|
||||
#endif // #DYNAMIC_ARCH
|
||||
|
||||
#define SHGEMM_NN shgemm_nn
|
||||
#define SHGEMM_CN shgemm_tn
|
||||
#define SHGEMM_TN shgemm_tn
|
||||
#define SHGEMM_NC shgemm_nt
|
||||
#define SHGEMM_NT shgemm_nt
|
||||
#define SHGEMM_CC shgemm_tt
|
||||
#define SHGEMM_CT shgemm_tt
|
||||
#define SHGEMM_TC shgemm_tt
|
||||
#define SHGEMM_TT shgemm_tt
|
||||
#define SHGEMM_NR shgemm_nn
|
||||
#define SHGEMM_TR shgemm_tn
|
||||
#define SHGEMM_CR shgemm_tn
|
||||
#define SHGEMM_RN shgemm_nn
|
||||
#define SHGEMM_RT shgemm_nt
|
||||
#define SHGEMM_RC shgemm_nt
|
||||
#define SHGEMM_RR shgemm_nn
|
||||
|
||||
#define SHGEMM_THREAD_NN shgemm_thread_nn
|
||||
#define SHGEMM_THREAD_CN shgemm_thread_tn
|
||||
#define SHGEMM_THREAD_TN shgemm_thread_tn
|
||||
#define SHGEMM_THREAD_NC shgemm_thread_nt
|
||||
#define SHGEMM_THREAD_NT shgemm_thread_nt
|
||||
#define SHGEMM_THREAD_CC shgemm_thread_tt
|
||||
#define SHGEMM_THREAD_CT shgemm_thread_tt
|
||||
#define SHGEMM_THREAD_TC shgemm_thread_tt
|
||||
#define SHGEMM_THREAD_TT shgemm_thread_tt
|
||||
#define SHGEMM_THREAD_NR shgemm_thread_nn
|
||||
#define SHGEMM_THREAD_TR shgemm_thread_tn
|
||||
#define SHGEMM_THREAD_CR shgemm_thread_tn
|
||||
#define SHGEMM_THREAD_RN shgemm_thread_nn
|
||||
#define SHGEMM_THREAD_RT shgemm_thread_nt
|
||||
#define SHGEMM_THREAD_RC shgemm_thread_nt
|
||||
#define SHGEMM_THREAD_RR shgemm_thread_nn
|
||||
|
||||
|
||||
#endif // #COMMON_SH_H
|
||||
@@ -39,6 +39,11 @@ typedef unsigned long BLASULONG;
|
||||
typedef uint16_t bfloat16;
|
||||
#endif
|
||||
|
||||
#ifndef HFLOAT16
|
||||
#include <stdint.h>
|
||||
typedef uint16_t hfloat16;
|
||||
#endif
|
||||
|
||||
#ifdef OPENBLAS_USE64BITINT
|
||||
typedef BLASLONG blasint;
|
||||
#else
|
||||
|
||||
28
param.h
28
param.h
@@ -72,6 +72,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#ifndef PARAM_H
|
||||
#define PARAM_H
|
||||
|
||||
#define SHGEMM_DEFAULT_UNROLL_N 8
|
||||
#define SHGEMM_DEFAULT_UNROLL_M 8
|
||||
/* Generic SHGEMM blocking defaults (P, R, Q); target sections later in this
 * file #undef and redefine them.
 * NOTE(review): the RISCV64_ZVL128B section and the section after it use
 * Q = 240 / 128 and R = 12288 / 16384, so the generic R = 240 and Q = 12288
 * below look transposed — confirm the intended values. */
#define SHGEMM_DEFAULT_P 128
|
||||
#define SHGEMM_DEFAULT_R 240
|
||||
#define SHGEMM_DEFAULT_Q 12288
|
||||
|
||||
#define SBGEMM_DEFAULT_UNROLL_N 4
|
||||
#define SBGEMM_DEFAULT_UNROLL_M 8
|
||||
@@ -3138,10 +3143,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#endif
|
||||
|
||||
#ifdef RISCV64_ZVL128B
|
||||
|
||||
#define GEMM_DEFAULT_OFFSET_A 0
|
||||
#define GEMM_DEFAULT_OFFSET_B 0
|
||||
#define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL
|
||||
|
||||
#undef SHGEMM_DEFAULT_UNROLL_M
|
||||
#undef SHGEMM_DEFAULT_UNROLL_N
|
||||
#define SHGEMM_DEFAULT_UNROLL_M 8
|
||||
#define SHGEMM_DEFAULT_UNROLL_N 8
|
||||
|
||||
#define SGEMM_DEFAULT_UNROLL_M 8
|
||||
#define SGEMM_DEFAULT_UNROLL_N 8
|
||||
|
||||
@@ -3154,16 +3165,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#define ZGEMM_DEFAULT_UNROLL_M 4
|
||||
#define ZGEMM_DEFAULT_UNROLL_N 4
|
||||
|
||||
#undef SHGEMM_DEFAULT_P
|
||||
#define SHGEMM_DEFAULT_P 128
|
||||
#define SGEMM_DEFAULT_P 128
|
||||
#define DGEMM_DEFAULT_P 128
|
||||
#define CGEMM_DEFAULT_P 96
|
||||
#define ZGEMM_DEFAULT_P 64
|
||||
|
||||
#undef SHGEMM_DEFAULT_Q
|
||||
#define SHGEMM_DEFAULT_Q 240
|
||||
#define SGEMM_DEFAULT_Q 240
|
||||
#define DGEMM_DEFAULT_Q 120
|
||||
#define CGEMM_DEFAULT_Q 120
|
||||
#define ZGEMM_DEFAULT_Q 120
|
||||
|
||||
#undef SHGEMM_DEFAULT_R
|
||||
#define SHGEMM_DEFAULT_R 12288
|
||||
#define SGEMM_DEFAULT_R 12288
|
||||
#define DGEMM_DEFAULT_R 8192
|
||||
#define CGEMM_DEFAULT_R 4096
|
||||
@@ -3181,6 +3198,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#define GEMM_DEFAULT_OFFSET_B 0
|
||||
#define GEMM_DEFAULT_ALIGN 0x03fffUL
|
||||
|
||||
#undef SHGEMM_DEFAULT_UNROLL_M
|
||||
#undef SHGEMM_DEFAULT_UNROLL_N
|
||||
#define SHGEMM_DEFAULT_UNROLL_M 16
|
||||
#define SHGEMM_DEFAULT_UNROLL_N 8
|
||||
|
||||
#define SGEMM_DEFAULT_UNROLL_M 16
|
||||
#define SGEMM_DEFAULT_UNROLL_N 8
|
||||
|
||||
@@ -3193,16 +3215,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#define ZGEMM_DEFAULT_UNROLL_M 8
|
||||
#define ZGEMM_DEFAULT_UNROLL_N 4
|
||||
|
||||
#undef SHGEMM_DEFAULT_P
|
||||
#define SHGEMM_DEFAULT_P 128
|
||||
#define SGEMM_DEFAULT_P 128
|
||||
#define DGEMM_DEFAULT_P 64
|
||||
#define CGEMM_DEFAULT_P 64
|
||||
#define ZGEMM_DEFAULT_P 64
|
||||
|
||||
#undef SHGEMM_DEFAULT_Q
|
||||
#define SHGEMM_DEFAULT_Q 128
|
||||
#define SGEMM_DEFAULT_Q 128
|
||||
#define DGEMM_DEFAULT_Q 128
|
||||
#define CGEMM_DEFAULT_Q 128
|
||||
#define ZGEMM_DEFAULT_Q 64
|
||||
|
||||
#undef SHGEMM_DEFAULT_R
|
||||
#define SHGEMM_DEFAULT_R 16384
|
||||
#define SGEMM_DEFAULT_R 16384
|
||||
#define DGEMM_DEFAULT_R 8192
|
||||
#define CGEMM_DEFAULT_R 8192
|
||||
|
||||
Reference in New Issue
Block a user