From 0a967797a15617239523053633bf14be7895b25a Mon Sep 17 00:00:00 2001 From: Srangrang Date: Tue, 27 May 2025 14:34:57 +0800 Subject: [PATCH] Add FP16 support for RISCV --- Makefile.system | 1 + Makefile.tail | 7 ++- cblas.h | 4 ++ cmake/system.cmake | 19 ++++---- common.h | 9 +++- common_interface.h | 2 + common_level3.h | 19 +++++++- common_macro.h | 45 ++++++++++++++++++ common_param.h | 94 +++++++++++++++++++++++++++++--------- common_sh.h | 72 +++++++++++++++++++++++++++++ openblas_config_template.h | 5 ++ param.h | 28 ++++++++++++ 12 files changed, 270 insertions(+), 35 deletions(-) create mode 100644 common_sh.h diff --git a/Makefile.system b/Makefile.system index 38646c3c6..cdc39982d 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1889,6 +1889,7 @@ export TARGET_CORE export NO_AVX512 export NO_AVX2 export BUILD_BFLOAT16 +export BUILD_HFLOAT16 export NO_LSX export NO_LASX diff --git a/Makefile.tail b/Makefile.tail index 54ba649db..ed2c0e507 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -1,4 +1,5 @@ SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) +SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) @@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) -BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) +BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) +BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX @@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX $(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX +$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) diff --git a/cblas.h b/cblas.h index 83686f743..ebf066ce7 100644 --- a/cblas.h +++ b/cblas.h @@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); +/*** FLOAT16 extensions */ +void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); + #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/cmake/system.cmake b/cmake/system.cmake index 14b2c65b1..bac756901 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -640,6 +640,9 @@ endif() if (BUILD_BFLOAT16) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16") endif() +if (BUILD_HFLOAT16) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16") +endif() if(NOT MSVC) set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}") endif() @@ -647,14 +650,14 @@ endif() set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}") if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") -if ("${F_COMPILER}" STREQUAL "FLANG") -if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) - set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") -endif () -endif () -if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") - set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") -endif () + if ("${F_COMPILER}" STREQUAL "FLANG") + if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) + set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") + endif () + endif () + if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") + set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") + endif () endif () diff --git a/common.h b/common.h index 21f6dd0f9..deb61b35b 100644 --- a/common.h +++ b/common.h @@ -266,6 +266,11 @@ typedef uint16_t bfloat16; #define BFLOAT16CONVERSION 1 #endif +#ifndef hfloat16 +#include +typedef uint16_t hfloat16; +#endif + #ifdef USE64BITINT typedef BLASLONG blasint; #if defined(OS_WINDOWS) && defined(__64BIT__) @@ -313,8 +318,8 @@ typedef int blasint; #define SIZE 2 #define BASE_SHIFT 1 #define ZBASE_SHIFT 2 -#elif defined(FLOAT16) -#define IFLOAT float16 +#elif defined(HFLOAT16) +#define IFLOAT hfloat16 #define XFLOAT IFLOAT #define FLOAT float #define SIZE 2 diff --git a/common_interface.h b/common_interface.h index efd3c6649..23d86871f 100644 --- a/common_interface.h +++ b/common_interface.h @@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint /* Level 3 routines */ +void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, diff --git a/common_level3.h b/common_level3.h index d370c1f96..1838b4bf6 100644 --- a/common_level3.h +++ b/common_level3.h @@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); - +int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, + hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, @@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); #endif +int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); @@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); +int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); +int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); + int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); #endif +int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); + int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); #ifdef __CUDACC__ } diff --git a/common_macro.h b/common_macro.h index 820cb472a..b967c2e60 100644 --- a/common_macro.h +++ b/common_macro.h @@ -39,6 +39,7 @@ #ifndef COMMON_MACRO #define COMMON_MACRO +#include "common_sh.h" #include "common_sb.h" #include "common_s.h" #include "common_d.h" @@ -656,6 +657,50 @@ #define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT #define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN #define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT +#elif defined(HFLOAT16) +#define GEMM_BETA SHGEMM_BETA +#define GEMM_KERNEL_N SHGEMM_KERNEL +#define GEMM_KERNEL_L SHGEMM_KERNEL +#define GEMM_KERNEL_R SHGEMM_KERNEL +#define GEMM_KERNEL_B SHGEMM_KERNEL +#define GEMM_NN SHGEMM_NN +#define GEMM_CN SHGEMM_TN +#define GEMM_TN SHGEMM_TN +#define GEMM_NC SHGEMM_NT +#define GEMM_NT SHGEMM_NT +#define GEMM_CC SHGEMM_TT +#define GEMM_CT SHGEMM_TT +#define GEMM_TC SHGEMM_TT +#define GEMM_TT SHGEMM_TT +#define GEMM_NR SHGEMM_NN +#define GEMM_TR SHGEMM_TN +#define GEMM_CR SHGEMM_TN +#define GEMM_RN SHGEMM_NN +#define GEMM_RT SHGEMM_NT +#define GEMM_RC SHGEMM_NT +#define GEMM_RR SHGEMM_NN +#define GEMM_ONCOPY SHGEMM_ONCOPY +#define GEMM_OTCOPY SHGEMM_OTCOPY +#define GEMM_INCOPY SHGEMM_INCOPY +#define GEMM_ITCOPY SHGEMM_ITCOPY + +#define GEMM_THREAD_NN SHGEMM_THREAD_NN +#define GEMM_THREAD_CN SHGEMM_THREAD_TN +#define GEMM_THREAD_TN SHGEMM_THREAD_TN +#define GEMM_THREAD_NC SHGEMM_THREAD_NT +#define GEMM_THREAD_NT SHGEMM_THREAD_NT +#define GEMM_THREAD_CC SHGEMM_THREAD_TT +#define GEMM_THREAD_CT SHGEMM_THREAD_TT +#define GEMM_THREAD_TC SHGEMM_THREAD_TT +#define GEMM_THREAD_TT SHGEMM_THREAD_TT +#define GEMM_THREAD_NR SHGEMM_THREAD_NN +#define GEMM_THREAD_TR SHGEMM_THREAD_TN +#define GEMM_THREAD_CR SHGEMM_THREAD_TN +#define GEMM_THREAD_RN SHGEMM_THREAD_NN +#define GEMM_THREAD_RT SHGEMM_THREAD_NT +#define GEMM_THREAD_RC SHGEMM_THREAD_NT +#define GEMM_THREAD_RR SHGEMM_THREAD_NN + #elif defined(BFLOAT16) diff --git a/common_param.h b/common_param.h index 2d771a27d..f82b73a72 100644 --- a/common_param.h +++ b/common_param.h @@ -48,6 +48,21 @@ typedef struct { int dtb_entries; int switch_ratio; int offsetA, offsetB, align; +#if BUILD_HFLOAT16 == 1 +int shgemm_p, shgemm_q, shgemm_r; +int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + +int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); +int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); + +int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); + + +#endif + #if BUILD_BFLOAT16 == 1 int sbgemm_p, sbgemm_q, sbgemm_r; @@ -64,10 +79,10 @@ typedef struct { float (*sbamin_k) (BLASLONG, float *, BLASLONG); float (*sbmax_k) (BLASLONG, float *, BLASLONG); float (*sbmin_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); float (*sbnrm2_k) (BLASLONG, float *, BLASLONG); float (*sbasum_k) (BLASLONG, float *, BLASLONG); @@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); #endif #if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1) -BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); #endif #if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) -BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); float (*snrm2_k) (BLASLONG, float *, BLASLONG); float (*sasum_k) (BLASLONG, float *, BLASLONG); #endif @@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); double (*damin_k) (BLASLONG, double *, BLASLONG); double (*dmax_k) (BLASLONG, double *, BLASLONG); double (*dmin_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); + BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); + BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); double (*dnrm2_k) (BLASLONG, double *, BLASLONG); double (*dasum_k) (BLASLONG, double *, BLASLONG); @@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG); @@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); float (*camax_k) (BLASLONG, float *, BLASLONG); float (*camin_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); float (*cnrm2_k) (BLASLONG, float *, BLASLONG); float (*casum_k) (BLASLONG, float *, BLASLONG); @@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); double (*zamax_k) (BLASLONG, double *, BLASLONG); double (*zamin_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); double (*znrm2_k) (BLASLONG, double *, BLASLONG); double (*zasum_k) (BLASLONG, double *, BLASLONG); @@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG); @@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#if (BUILD_HFLOAT16==1) +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn +#endif + #if (BUILD_BFLOAT16==1) #define SBGEMM_P gotoblas -> sbgemm_p #define SBGEMM_Q gotoblas -> sbgemm_q @@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#if (BUILD_HFLOAT16 == 1) +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif +#endif + #if (BUILD_BFLOAT16 == 1) #define SBGEMM_P SBGEMM_DEFAULT_P #define SBGEMM_Q SBGEMM_DEFAULT_Q @@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas; #endif + #endif #ifndef COMPLEX @@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HFLOAT16) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #elif defined(BFLOAT16) #define GEMM_P SBGEMM_P #define GEMM_Q SBGEMM_Q diff --git a/common_sh.h b/common_sh.h new file mode 100644 index 000000000..69734d1dc --- /dev/null +++ b/common_sh.h @@ -0,0 +1,72 @@ +#ifndef COMMON_SH_H +#define COMMON_SH_H + +#ifndef DYNAMIC_ARCH + +#define SHGEMM_ONCOPY shgemm_oncopy +#define SHGEMM_OTCOPY shgemm_otcopy + +#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY shgemm_oncopy +#define SHGEMM_ITCOPY shgemm_otcopy +#else +#define SHGEMM_INCOPY shgemm_incopy +#define SHGEMM_ITCOPY shgemm_itcopy +#endif + +#define SHGEMM_BETA shgemm_beta +#define SHGEMM_KERNEL shgemm_kernel + + +#else // #DYNAMIC_ARCH + +#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy +#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy +#else +#define SHGEMM_INCOPY gotoblas -> shgemm_incopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy +#endif + +#define SHGEMM_BETA gotoblas -> shgemm_beta +#define SHGEMM_KERNEL gotoblas -> shgemm_kernel +#endif // #DYNAMIC_ARCH + +#define SHGEMM_NN shgemm_nn +#define SHGEMM_CN shgemm_tn +#define SHGEMM_TN shgemm_tn +#define SHGEMM_NC shgemm_nt +#define SHGEMM_NT shgemm_nt +#define SHGEMM_CC shgemm_tt +#define SHGEMM_CT shgemm_tt +#define SHGEMM_TC shgemm_tt +#define SHGEMM_TT shgemm_tt +#define SHGEMM_NR shgemm_nn +#define SHGEMM_TR shgemm_tn +#define SHGEMM_CR shgemm_tn +#define SHGEMM_RN shgemm_nn +#define SHGEMM_RT shgemm_nt +#define SHGEMM_RC shgemm_nt +#define SHGEMM_RR shgemm_nn + +#define SHGEMM_THREAD_NN shgemm_thread_nn +#define SHGEMM_THREAD_CN shgemm_thread_tn +#define SHGEMM_THREAD_TN shgemm_thread_tn +#define SHGEMM_THREAD_NC shgemm_thread_nt +#define SHGEMM_THREAD_NT shgemm_thread_nt +#define SHGEMM_THREAD_CC shgemm_thread_tt +#define SHGEMM_THREAD_CT shgemm_thread_tt +#define SHGEMM_THREAD_TC shgemm_thread_tt +#define SHGEMM_THREAD_TT shgemm_thread_tt +#define SHGEMM_THREAD_NR shgemm_thread_nn +#define SHGEMM_THREAD_TR shgemm_thread_tn +#define SHGEMM_THREAD_CR shgemm_thread_tn +#define SHGEMM_THREAD_RN shgemm_thread_nn +#define SHGEMM_THREAD_RT shgemm_thread_nt +#define SHGEMM_THREAD_RC shgemm_thread_nt +#define SHGEMM_THREAD_RR shgemm_thread_nn + + +#endif // #COMMON_SH_H \ No newline at end of file diff --git a/openblas_config_template.h b/openblas_config_template.h index 6a7382108..5a7c01a24 100644 --- a/openblas_config_template.h +++ b/openblas_config_template.h @@ -39,6 +39,11 @@ typedef unsigned long BLASULONG; typedef uint16_t bfloat16; #endif +#ifndef HFLOAT16 +#include +typedef uint16_t hfloat16; +#endif + #ifdef OPENBLAS_USE64BITINT typedef BLASLONG blasint; #else diff --git a/param.h b/param.h index 48b64fd2a..220f847e9 100644 --- a/param.h +++ b/param.h @@ -72,6 +72,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef PARAM_H #define PARAM_H +#define SHGEMM_DEFAULT_UNROLL_N 8 +#define SHGEMM_DEFAULT_UNROLL_M 8 +#define SHGEMM_DEFAULT_P 128 +#define SHGEMM_DEFAULT_R 240 +#define SHGEMM_DEFAULT_Q 12288 #define SBGEMM_DEFAULT_UNROLL_N 4 #define SBGEMM_DEFAULT_UNROLL_M 8 @@ -3138,10 +3143,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endif #ifdef RISCV64_ZVL128B + #define GEMM_DEFAULT_OFFSET_A 0 #define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL +#undef SHGEMM_DEFAULT_UNROLL_M +#undef SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_DEFAULT_UNROLL_M 8 +#define SHGEMM_DEFAULT_UNROLL_N 8 + #define SGEMM_DEFAULT_UNROLL_M 8 #define SGEMM_DEFAULT_UNROLL_N 8 @@ -3154,16 +3165,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define ZGEMM_DEFAULT_UNROLL_M 4 #define ZGEMM_DEFAULT_UNROLL_N 4 +#undef SHGEMM_DEFAULT_P +#define SHGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128 #define DGEMM_DEFAULT_P 128 #define CGEMM_DEFAULT_P 96 #define ZGEMM_DEFAULT_P 64 +#undef SHGEMM_DEFAULT_Q +#define SHGEMM_DEFAULT_Q 240 #define SGEMM_DEFAULT_Q 240 #define DGEMM_DEFAULT_Q 120 #define CGEMM_DEFAULT_Q 120 #define ZGEMM_DEFAULT_Q 120 +#undef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R 12288 #define SGEMM_DEFAULT_R 12288 #define DGEMM_DEFAULT_R 8192 #define CGEMM_DEFAULT_R 4096 @@ -3181,6 +3198,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_ALIGN 0x03fffUL +#undef SHGEMM_DEFAULT_UNROLL_M +#undef SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_DEFAULT_UNROLL_M 16 +#define SHGEMM_DEFAULT_UNROLL_N 8 + #define SGEMM_DEFAULT_UNROLL_M 16 #define SGEMM_DEFAULT_UNROLL_N 8 @@ -3193,16 +3215,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define ZGEMM_DEFAULT_UNROLL_M 8 #define ZGEMM_DEFAULT_UNROLL_N 4 +#undef SHGEMM_DEFAULT_P +#define SHGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128 #define DGEMM_DEFAULT_P 64 #define CGEMM_DEFAULT_P 64 #define ZGEMM_DEFAULT_P 64 +#undef SHGEMM_DEFAULT_Q +#define SHGEMM_DEFAULT_Q 128 #define SGEMM_DEFAULT_Q 128 #define DGEMM_DEFAULT_Q 128 #define CGEMM_DEFAULT_Q 128 #define ZGEMM_DEFAULT_Q 64 +#undef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R 16384 #define SGEMM_DEFAULT_R 16384 #define DGEMM_DEFAULT_R 8192 #define CGEMM_DEFAULT_R 8192