mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-05-31 00:45:48 +08:00
Add Infrastructure for SHGEMV
This adds the infrastructure needed for a `shgemv` path and lays the groundwork for future `hgemm`/`hgemv` paths, following the same model as the existing `sb` and `b` interfaces. I've also fixed a few issues around `shgemm` that caused build failures in some configurations.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -80,6 +80,7 @@ test/SBLAT3_3M.SUMM
|
||||
test/ZBLAT2.SUMM
|
||||
test/ZBLAT3.SUMM
|
||||
test/ZBLAT3_3M.SUMM
|
||||
test/SHBLAT2.SUMM
|
||||
test/SHBLAT3.SUMM
|
||||
test/SBBLAT2.SUMM
|
||||
test/SBBLAT3.SUMM
|
||||
@@ -98,6 +99,7 @@ test/sblat2
|
||||
test/sblat3
|
||||
test/sblat3_3m
|
||||
test/test_shgemm
|
||||
test/test_shgemv
|
||||
test/test_sbgemm
|
||||
test/test_sbgemv
|
||||
test/test_bgemm
|
||||
|
||||
@@ -175,6 +175,10 @@ if (BUILD_BFLOAT16)
|
||||
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
|
||||
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
|
||||
endif ()
|
||||
if (BUILD_HFLOAT16)
|
||||
SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c)
|
||||
SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c)
|
||||
endif ()
|
||||
endmacro ()
|
||||
|
||||
macro(SetDefaultL2)
|
||||
@@ -226,6 +230,8 @@ macro(SetDefaultL2)
|
||||
if (BUILD_BFLOAT16)
|
||||
SetFallback(BGEMVNKERNEL ../generic/gemv_n.c)
|
||||
SetFallback(BGEMVTKERNEL ../generic/gemv_t.c)
|
||||
SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c)
|
||||
SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c)
|
||||
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
|
||||
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
|
||||
SetFallback(SHGERKERNEL ../generic/ger.c)
|
||||
@@ -260,5 +266,16 @@ if (BUILD_BFLOAT16)
|
||||
SetFallback(SBGEMMONCOPYOBJ sbgemm_oncopy.o)
|
||||
SetFallback(SBGEMMOTCOPYOBJ sbgemm_otcopy.o)
|
||||
endif ()
|
||||
|
||||
if (BUILD_HFLOAT16)
|
||||
SetFallback(SHGEMMKERNEL ../generic/gemmkernel_2x2.c)
|
||||
SetFallback(SHGEMM_BETA ../generic/gemm_beta.c)
|
||||
SetFallback(SHGEMMINCOPY ../generic/gemm_ncopy_2.c)
|
||||
SetFallback(SHGEMMITCOPY ../generic/gemm_tcopy_2.c)
|
||||
SetFallback(SHGEMMONCOPY ../generic/gemm_ncopy_2.c)
|
||||
SetFallback(SHGEMMOTCOPY ../generic/gemm_tcopy_2.c)
|
||||
SetFallback(SHGEMMINCOPYOBJ shgemm_incopy.o)
|
||||
SetFallback(SHGEMMITCOPYOBJ shgemm_itcopy.o)
|
||||
SetFallback(SHGEMMONCOPYOBJ shgemm_oncopy.o)
|
||||
SetFallback(SHGEMMOTCOPYOBJ shgemm_otcopy.o)
|
||||
endif ()
|
||||
endmacro ()
|
||||
|
||||
@@ -375,9 +375,12 @@ function(GenerateNamedObjects sources_in)
|
||||
if (NOT no_float_type)
|
||||
string(SUBSTRING ${float_type} 0 1 float_char)
|
||||
string(TOLOWER ${float_char} float_char)
|
||||
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM")
|
||||
set (float_char "sb")
|
||||
endif ()
|
||||
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM")
|
||||
set (float_char "sb")
|
||||
endif ()
|
||||
if (${float_type} STREQUAL "HFLOAT16" AND NOT "${defines_in}" MATCHES "HGEM")
|
||||
set (float_char "sh")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if (NOT name_in)
|
||||
|
||||
@@ -261,6 +261,8 @@ void BLASFUNC(bgemv)(char *, blasint *, blasint *, bfloat16 *, bfloat16 *, blas
|
||||
bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
|
||||
void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *,
|
||||
bfloat16 *, blasint *, float *, float *, blasint *);
|
||||
void BLASFUNC(shgemv)(char *, blasint *, blasint *, float *, hfloat16 *, blasint *,
|
||||
hfloat16 *, blasint *, float *, float *, blasint *);
|
||||
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *,
|
||||
float *, blasint *, float *, float *, blasint *);
|
||||
void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *,
|
||||
|
||||
@@ -54,6 +54,10 @@ int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLO
|
||||
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
|
||||
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
|
||||
int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
|
||||
int shgemv_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
|
||||
int shgemv_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
|
||||
int shgemv_thread_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int);
|
||||
int shgemv_thread_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int);
|
||||
int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
|
||||
int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *);
|
||||
int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *);
|
||||
|
||||
@@ -703,6 +703,9 @@
|
||||
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
|
||||
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
|
||||
|
||||
#define SCAL_K SSCAL_K
|
||||
#define GEMV_N SHGEMV_N_K
|
||||
#define GEMV_T SHGEMV_T_K
|
||||
|
||||
#elif defined(BFLOAT16) && defined(BGEMM)
|
||||
#define SCAL_K BSCAL_K
|
||||
|
||||
@@ -60,7 +60,8 @@ int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
|
||||
|
||||
|
||||
int (*shgemv_n) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
|
||||
int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
35
common_sh.h
35
common_sh.h
@@ -1,3 +1,31 @@
|
||||
/***************************************************************************
|
||||
* Copyright (c) 2025, The OpenBLAS Project
|
||||
* All rights reserved.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
* 1. Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* 2. Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in
|
||||
* the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* 3. Neither the name of the OpenBLAS project nor the names of
|
||||
* its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
* POSSIBILITY OF SUCH DAMAGE.
|
||||
* *****************************************************************************/
|
||||
|
||||
#ifndef COMMON_SH_H
|
||||
#define COMMON_SH_H
|
||||
|
||||
@@ -17,6 +45,9 @@
|
||||
#define SHGEMM_BETA shgemm_beta
|
||||
#define SHGEMM_KERNEL shgemm_kernel
|
||||
|
||||
#define SHGEMV_N_K shgemv_n
|
||||
#define SHGEMV_T_K shgemv_t
|
||||
|
||||
|
||||
#else // #DYNAMIC_ARCH
|
||||
|
||||
@@ -32,6 +63,10 @@
|
||||
|
||||
#define SHGEMM_BETA gotoblas -> shgemm_beta
|
||||
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
|
||||
|
||||
#define SHGEMV_N_K gotoblas->shgemv_n
|
||||
#define SHGEMV_T_K gotoblas->shgemv_t
|
||||
|
||||
#endif // #DYNAMIC_ARCH
|
||||
|
||||
#define SHGEMM_NN shgemm_nn
|
||||
|
||||
@@ -450,6 +450,12 @@ XBLASOBJS += \
|
||||
xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \
|
||||
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX)
|
||||
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
SHBLASOBJS += \
|
||||
shgemv_thread_n$(TSUFFIX).$(SUFFIX) \
|
||||
shgemv_thread_t$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
ifeq ($(BUILD_BFLOAT16),1)
|
||||
BBLASOBJS += \
|
||||
bgemv_thread_n$(TSUFFIX).$(SUFFIX) \
|
||||
@@ -3737,6 +3743,13 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h
|
||||
xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h
|
||||
$(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F)
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
shgemv_thread_n.$(SUFFIX) shgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)
|
||||
shgemv_thread_t.$(SUFFIX) shgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_BFLOAT16),1)
|
||||
bgemv_thread_n.$(SUFFIX) bgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
|
||||
$(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)
|
||||
|
||||
@@ -80,7 +80,7 @@ blasobjsz="
|
||||
|
||||
blasobjs="lsame xerbla"
|
||||
bfblasobjs="bgemm bgemv sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod"
|
||||
hfblasobjs="shgemm"
|
||||
hfblasobjs="shgemm shgemv"
|
||||
cblasobjsc="
|
||||
cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv
|
||||
cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
|
||||
@blasobjs = (lsame, xerbla);
|
||||
@bfblasobjs = (bgemm, bgemv, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
|
||||
@hfblasobjs = (shgemm);
|
||||
@hfblasobjs = (shgemm, shgemv);
|
||||
@cblasobjsc = (
|
||||
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
|
||||
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
|
||||
|
||||
@@ -166,6 +166,7 @@ if (BUILD_BFLOAT16)
|
||||
endif ()
|
||||
if (BUILD_HFLOAT16)
|
||||
GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16")
|
||||
GenerateNamedObjects("sbgemv.c" "" "shgemv" ${CBLAS_FLAG} "" "" true "HFLOAT16")
|
||||
endif ()
|
||||
|
||||
# complex-specific sources
|
||||
|
||||
@@ -87,6 +87,7 @@ endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
SHBLAS3OBJS = shgemm.$(SUFFIX)
|
||||
SHBLAS2OBJS = shgemv.$(SUFFIX)
|
||||
endif
|
||||
|
||||
DBLAS1OBJS = \
|
||||
@@ -338,6 +339,7 @@ endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
|
||||
CSHBLAS2OBJS = cblas_shgemv.$(SUFFIX)
|
||||
endif
|
||||
|
||||
CDBLAS1OBJS = \
|
||||
@@ -441,6 +443,7 @@ SBBLAS1OBJS += $(CSBBLAS1OBJS)
|
||||
SBBLAS2OBJS += $(CSBBLAS2OBJS)
|
||||
SBBLAS3OBJS += $(CSBBLAS3OBJS)
|
||||
SHBLAS3OBJS += $(CSHBLAS3OBJS)
|
||||
SHBLAS2OBJS += $(CSHBLAS2OBJS)
|
||||
DBLAS1OBJS += $(CDBLAS1OBJS)
|
||||
DBLAS2OBJS += $(CDBLAS2OBJS)
|
||||
DBLAS3OBJS += $(CDBLAS3OBJS)
|
||||
@@ -459,7 +462,7 @@ endif
|
||||
BBLASOBJS = $(BBLAS3OBJS) $(BBLAS2OBJS) $(BBLAS1OBJS)
|
||||
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
|
||||
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
|
||||
SHBLASOBJS = $(SHBLAS3OBJS)
|
||||
SHBLASOBJS = $(SHBLAS3OBJS) $(SHBLAS2OBJS)
|
||||
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
|
||||
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
|
||||
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
|
||||
@@ -602,7 +605,7 @@ clean ::
|
||||
level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
|
||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
|
||||
|
||||
level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
|
||||
level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(SHBLAS2OBJS)
|
||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
|
||||
|
||||
level3 : $(SBBLAS3OBJS) $(BBLAS3OBJ) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS)
|
||||
@@ -1002,6 +1005,11 @@ sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c
|
||||
$(CC) $(CFLAGS) -c $< -o $(@F)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
shgemv.$(SUFFIX) shgemv.$(PSUFFIX) : sbgemv.c
|
||||
$(CC) $(CFLAGS) -c $< -o $(@F)
|
||||
endif
|
||||
|
||||
ifndef USE_NETLIB_GEMV
|
||||
sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c
|
||||
$(CC) -c $(CFLAGS) -o $(@F) $<
|
||||
@@ -1832,6 +1840,11 @@ cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c
|
||||
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
cblas_shgemv.$(SUFFIX) cblas_shgemv.$(PSUFFIX) : sbgemv.c
|
||||
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
|
||||
endif
|
||||
|
||||
cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c
|
||||
$(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $<
|
||||
|
||||
|
||||
@@ -587,7 +587,10 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
||||
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
|
||||
#endif
|
||||
|
||||
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD)))
|
||||
/*
 * Gate macros for the GEMM -> GEMV forwarding shortcut.
 *
 * Each macro is 1 when forwarding is permitted for the corresponding
 * 16-bit float family and 0 otherwise:
 *   - the family is not being compiled at all, or
 *   - the mixed-precision (SB / SH) variant is being compiled and its
 *     *_GEMV_FORWARD switch is defined, or
 *   - the native (BGEMM / HGEMM) variant is being compiled and its
 *     *_GEMV_FORWARD switch is defined.
 *
 * The conditions are evaluated here, with literal `defined` operators,
 * and the result is stored as a plain integer constant.  Expanding a
 * macro whose replacement list contains `defined` inside an #if is
 * undefined behaviour in ISO C (C11 6.10.1p4) and is handled differently
 * by GCC/Clang (-Wexpansion-to-defined) and MSVC.
 */
#if !defined(BFLOAT16) \
    || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) \
    || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD))
#define BFLOAT16_GEMM_GEMV_FORWARD 1
#else
#define BFLOAT16_GEMM_GEMV_FORWARD 0
#endif

#if !defined(HFLOAT16) \
    || (!defined(HGEMM) && defined(SHGEMM_GEMV_FORWARD)) \
    || (defined(HGEMM) && defined(HGEMM_GEMV_FORWARD))
#define HFLOAT16_GEMM_GEMV_FORWARD 1
#else
#define HFLOAT16_GEMM_GEMV_FORWARD 0
#endif
|
||||
|
||||
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && HFLOAT16_GEMM_GEMV_FORWARD && BFLOAT16_GEMM_GEMV_FORWARD
|
||||
#if defined(ARCH_ARM64)
|
||||
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
|
||||
// perform poorly in certain circumstances. We use the following boolean
|
||||
|
||||
@@ -48,6 +48,10 @@
|
||||
#define GEMV_THREAD_N bgemv_thread_n
|
||||
#define GEMV_THREAD_T bgemv_thread_t
|
||||
#define ERROR_NAME "BGEMV "
|
||||
#elif defined(HFLOAT16)
|
||||
#define GEMV_THREAD_N shgemv_thread_n
|
||||
#define GEMV_THREAD_T shgemv_thread_t
|
||||
#define ERROR_NAME "SHGEMV "
|
||||
#else
|
||||
#define GEMV_THREAD_N sbgemv_thread_n
|
||||
#define GEMV_THREAD_T sbgemv_thread_t
|
||||
|
||||
@@ -228,6 +228,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SBGEMVNKERNEL}" "" "gemv_n" false "" "" false "BFLOAT16")
|
||||
GenerateNamedObjects("${KERNELDIR}/${SBGEMVTKERNEL}" "" "gemv_t" false "" "" false "BFLOAT16")
|
||||
endif ()
|
||||
if (BUILD_HFLOAT16)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SHGEMVNKERNEL}" "" "gemv_n" false "" "" false "HFLOAT16")
|
||||
GenerateNamedObjects("${KERNELDIR}/${SHGEMVTKERNEL}" "" "gemv_t" false "" "" false "HFLOAT16")
|
||||
endif ()
|
||||
# Makefile.L3
|
||||
set(USE_TRMM false)
|
||||
string(TOUPPER ${TARGET_CORE} UC_TARGET_CORE)
|
||||
|
||||
@@ -101,6 +101,16 @@ SBGEMVTKERNEL = ../x86_64/sbgemv_t.c
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
ifndef SHGEMVNKERNEL
|
||||
SHGEMVNKERNEL = ../generic/gemv_n.c
|
||||
endif
|
||||
|
||||
ifndef SHGEMVTKERNEL
|
||||
SHGEMVTKERNEL = ../generic/gemv_t.c
|
||||
endif
|
||||
endif
|
||||
|
||||
### GER ###
|
||||
|
||||
ifndef SGERKERNEL
|
||||
@@ -299,6 +309,12 @@ SBBLASOBJS += \
|
||||
sbgemv_t$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
SHBLASOBJS += \
|
||||
shgemv_n$(TSUFFIX).$(SUFFIX) \
|
||||
shgemv_t$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
|
||||
ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" ""
|
||||
$(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP)
|
||||
$(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@
|
||||
@@ -558,3 +574,10 @@ $(KDIR)bgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)bgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERN
|
||||
$(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
$(KDIR)shgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVNKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
|
||||
$(KDIR)shgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
* *****************************************************************************/
|
||||
|
||||
#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
|
||||
|
||||
static float
|
||||
bfloat16tof32 (bfloat16 value)
|
||||
{
|
||||
@@ -48,17 +49,34 @@ static bfloat16 f32tobfloat16(float value) {
|
||||
#ifdef BGEMM
|
||||
#define ALPHA bfloat16tof32(alpha)
|
||||
#define BETA bfloat16tof32(beta)
|
||||
#define BF16TOF32(x) (bfloat16tof32(x))
|
||||
#define F32TOBF16(x) (f32tobfloat16(x))
|
||||
#define TO_F32(x) (bfloat16tof32(x))
|
||||
#define TO_OUTPUT(x) (f32tobfloat16(x))
|
||||
#else
|
||||
#define ALPHA alpha
|
||||
#define BETA beta
|
||||
#define BF16TOF32(x) (bfloat16tof32(x))
|
||||
#define F32TOBF16(x) x
|
||||
#define TO_F32(x) (bfloat16tof32(x))
|
||||
#define TO_OUTPUT(x) x
|
||||
#endif
|
||||
|
||||
#elif defined(HFLOAT16)
|
||||
|
||||
#ifdef HGEMM
|
||||
#define ALPHA (float)(alpha)
|
||||
#define BETA (float)(beta)
|
||||
#define TO_F32(x) ((float)(x))
|
||||
#define TO_OUTPUT(x) ((_Float16)(x))
|
||||
#else
|
||||
#define ALPHA alpha
|
||||
#define BETA beta
|
||||
#define BF16TOF32(x) x
|
||||
#define F32TOBF16(x) x
|
||||
#define TO_F32(x) ((float)(x))
|
||||
#define TO_OUTPUT(x) x
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
#define ALPHA alpha
|
||||
#define BETA beta
|
||||
#define TO_F32(x) x
|
||||
#define TO_OUTPUT(x) x
|
||||
|
||||
#endif
|
||||
@@ -27,7 +27,8 @@
|
||||
* *****************************************************************************/
|
||||
|
||||
#include "common.h"
|
||||
#include "bf16_macros.h"
|
||||
|
||||
#include "conversion_macros.h"
|
||||
|
||||
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
|
||||
#ifdef TRMMKERNEL
|
||||
@@ -60,36 +61,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
res1 = res1+TO_F32(load2)*TO_F32(load1);
|
||||
load3 = ptrbb[2*0+1];
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
res2 = res2+TO_F32(load0)*TO_F32(load3);
|
||||
res3 = res3+TO_F32(load2)*TO_F32(load3);
|
||||
load4 = ptrba[2*1+0];
|
||||
load5 = ptrbb[2*1+0];
|
||||
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
|
||||
res0 = res0+TO_F32(load4)*TO_F32(load5);
|
||||
load6 = ptrba[2*1+1];
|
||||
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
|
||||
res1 = res1+TO_F32(load6)*TO_F32(load5);
|
||||
load7 = ptrbb[2*1+1];
|
||||
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
|
||||
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
|
||||
res2 = res2+TO_F32(load4)*TO_F32(load7);
|
||||
res3 = res3+TO_F32(load6)*TO_F32(load7);
|
||||
load0 = ptrba[2*2+0];
|
||||
load1 = ptrbb[2*2+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
load2 = ptrba[2*2+1];
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
res1 = res1+TO_F32(load2)*TO_F32(load1);
|
||||
load3 = ptrbb[2*2+1];
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
res2 = res2+TO_F32(load0)*TO_F32(load3);
|
||||
res3 = res3+TO_F32(load2)*TO_F32(load3);
|
||||
load4 = ptrba[2*3+0];
|
||||
load5 = ptrbb[2*3+0];
|
||||
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
|
||||
res0 = res0+TO_F32(load4)*TO_F32(load5);
|
||||
load6 = ptrba[2*3+1];
|
||||
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
|
||||
res1 = res1+TO_F32(load6)*TO_F32(load5);
|
||||
load7 = ptrbb[2*3+1];
|
||||
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
|
||||
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
|
||||
res2 = res2+TO_F32(load4)*TO_F32(load7);
|
||||
res3 = res3+TO_F32(load6)*TO_F32(load7);
|
||||
ptrba = ptrba+8;
|
||||
ptrbb = ptrbb+8;
|
||||
}
|
||||
@@ -97,23 +98,23 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
res1 = res1+TO_F32(load2)*TO_F32(load1);
|
||||
load3 = ptrbb[2*0+1];
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
res2 = res2+TO_F32(load0)*TO_F32(load3);
|
||||
res3 = res3+TO_F32(load2)*TO_F32(load3);
|
||||
ptrba = ptrba+2;
|
||||
ptrbb = ptrbb+2;
|
||||
}
|
||||
res0 = res0*ALPHA;
|
||||
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
|
||||
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
|
||||
res1 = res1*ALPHA;
|
||||
C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
|
||||
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
|
||||
res2 = res2*ALPHA;
|
||||
C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2);
|
||||
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2);
|
||||
res3 = res3*ALPHA;
|
||||
C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3);
|
||||
C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3);
|
||||
C0 = C0+2;
|
||||
C1 = C1+2;
|
||||
}
|
||||
@@ -126,16 +127,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
|
||||
{
|
||||
load0 = ptrba[0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
load2 = ptrbb[2*0+1];
|
||||
res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
|
||||
res1 = res1+TO_F32(load0)*TO_F32(load2);
|
||||
ptrba = ptrba+1;
|
||||
ptrbb = ptrbb+2;
|
||||
}
|
||||
res0 = res0*ALPHA;
|
||||
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
|
||||
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
|
||||
res1 = res1*ALPHA;
|
||||
C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1);
|
||||
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1);
|
||||
C0 = C0+1;
|
||||
C1 = C1+1;
|
||||
}
|
||||
@@ -157,16 +158,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[0+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
res1 = res1+TO_F32(load2)*TO_F32(load1);
|
||||
ptrba = ptrba+2;
|
||||
ptrbb = ptrbb+1;
|
||||
}
|
||||
res0 = res0*ALPHA;
|
||||
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
|
||||
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
|
||||
res1 = res1*ALPHA;
|
||||
C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
|
||||
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
|
||||
C0 = C0+2;
|
||||
}
|
||||
for (i=0; i<(bm&1); i+=1)
|
||||
@@ -177,12 +178,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
|
||||
{
|
||||
load0 = ptrba[0+0];
|
||||
load1 = ptrbb[0+0];
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
res0 = res0+TO_F32(load0)*TO_F32(load1);
|
||||
ptrba = ptrba+1;
|
||||
ptrbb = ptrbb+1;
|
||||
}
|
||||
res0 = res0*ALPHA;
|
||||
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
|
||||
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
|
||||
C0 = C0+1;
|
||||
}
|
||||
k = (bk<<0);
|
||||
|
||||
@@ -26,15 +26,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "common.h"
|
||||
#include "bf16_macros.h"
|
||||
|
||||
#include "conversion_macros.h"
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
||||
{
|
||||
BLASLONG i;
|
||||
BLASLONG ix, iy;
|
||||
BLASLONG j;
|
||||
FLOAT *a_ptr;
|
||||
#ifdef BGEMM
|
||||
IFLOAT *a_ptr;
|
||||
#if defined(BGEMM) || defined(HGEMM)
|
||||
float temp;
|
||||
#else
|
||||
FLOAT temp;
|
||||
@@ -49,18 +48,18 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||
a_ptr = a;
|
||||
for (BLASLONG j = 0; j < n; j++)
|
||||
{
|
||||
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
|
||||
temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]);
|
||||
ix += inc_x;
|
||||
a_ptr += lda;
|
||||
}
|
||||
|
||||
if (BETA == ZERO)
|
||||
{
|
||||
y[iy] = F32TOBF16(ALPHA * temp);
|
||||
y[iy] = TO_OUTPUT(ALPHA * temp);
|
||||
}
|
||||
else
|
||||
{
|
||||
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
|
||||
y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy]));
|
||||
}
|
||||
|
||||
iy += inc_y;
|
||||
|
||||
@@ -26,15 +26,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "common.h"
|
||||
#include "bf16_macros.h"
|
||||
|
||||
#include "conversion_macros.h"
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
||||
{
|
||||
BLASLONG i;
|
||||
BLASLONG ix, iy;
|
||||
BLASLONG j;
|
||||
FLOAT *a_ptr;
|
||||
#ifdef BGEMM
|
||||
IFLOAT *a_ptr;
|
||||
#if defined(BGEMM) || defined(HGEMM)
|
||||
float temp;
|
||||
#else
|
||||
FLOAT temp;
|
||||
@@ -49,16 +50,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||
ix = 0;
|
||||
for (i = 0; i < m; i++)
|
||||
{
|
||||
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
|
||||
temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]);
|
||||
ix += inc_x;
|
||||
}
|
||||
if (BETA == ZERO)
|
||||
{
|
||||
y[iy] = F32TOBF16(ALPHA * temp);
|
||||
y[iy] = TO_OUTPUT(ALPHA * temp);
|
||||
}
|
||||
else
|
||||
{
|
||||
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
|
||||
y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy]));
|
||||
}
|
||||
iy += inc_y;
|
||||
a_ptr += lda;
|
||||
|
||||
@@ -56,6 +56,24 @@ gotoblas_t TABLE_NAME = {
|
||||
|
||||
GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
|
||||
|
||||
#ifdef BUILD_HFLOAT16
|
||||
0, 0, 0,
|
||||
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
|
||||
#ifdef SHGEMM_DEFAULT_UNROLL_MN
|
||||
SHGEMM_DEFAULT_UNROLL_MN,
|
||||
#else
|
||||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
|
||||
#endif
|
||||
shgemm_kernelTS, shgemm_betaTS,
|
||||
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
|
||||
shgemm_incopyTS, shgemm_itcopyTS,
|
||||
#else
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
#endif
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
shgemv_nTS, shgemv_tTS,
|
||||
#endif
|
||||
|
||||
#ifdef BUILD_BFLOAT16
|
||||
0, 0, 0,
|
||||
BGEMM_DEFAULT_UNROLL_M, BGEMM_DEFAULT_UNROLL_N,
|
||||
@@ -142,23 +160,6 @@ gotoblas_t TABLE_NAME = {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef BUILD_HFLOAT16
|
||||
0, 0, 0,
|
||||
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
|
||||
#ifdef SHGEMM_DEFAULT_UNROLL_MN
|
||||
SHGEMM_DEFAULT_UNROLL_MN,
|
||||
#else
|
||||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
|
||||
#endif
|
||||
shgemm_kernelTS, shgemm_betaTS,
|
||||
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
|
||||
shgemm_incopyTS, shgemm_itcopyTS,
|
||||
#else
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
#endif
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
#endif
|
||||
|
||||
#if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1)
|
||||
0, 0, 0,
|
||||
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
|
||||
|
||||
@@ -119,6 +119,9 @@ endif
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16), 1)
|
||||
SH2 = test_shgemv
|
||||
endif
|
||||
ifeq ($(BUILD_BFLOAT16), 1)
|
||||
BB2 = test_bgemv
|
||||
B2 = test_sbgemv
|
||||
@@ -136,7 +139,7 @@ ifeq ($(BUILD_COMPLEX16),1)
|
||||
Z2=zblat2
|
||||
endif
|
||||
|
||||
level2: $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2)
|
||||
level2: $(SH2) $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2)
|
||||
|
||||
|
||||
ifneq ($(CROSS), 1)
|
||||
@@ -147,6 +150,10 @@ ifeq ($(BUILD_BFLOAT16),1)
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemv > SHBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_SINGLE),1)
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat2 < ./sblat2.dat
|
||||
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
|
||||
@@ -172,6 +179,10 @@ ifeq ($(BUILD_BFLOAT16),1)
|
||||
OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_SINGLE),1)
|
||||
OMP_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
|
||||
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
|
||||
@@ -195,6 +206,10 @@ ifeq ($(BUILD_BFLOAT16),1)
|
||||
OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
|
||||
endif
|
||||
ifeq ($(BUILD_SINGLE),1)
|
||||
OPENBLAS_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
|
||||
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
|
||||
@@ -438,6 +453,12 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -DIBFLOAT16 -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
endif
|
||||
|
||||
|
||||
ifeq ($(BUILD_COMPLEX),1)
|
||||
cblat3_3m : cblat3_3m.$(SUFFIX) ../$(LIBNAME)
|
||||
$(FC) $(FLDFLAGS) -o cblat3_3m cblat3_3m.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
@@ -454,7 +475,7 @@ clean:
|
||||
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
|
||||
sblat1 dblat1 cblat1 zblat1 \
|
||||
sblat2 dblat2 cblat2 zblat2 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
sblat1p dblat1p cblat1p zblat1p \
|
||||
sblat2p dblat2p cblat2p zblat2p \
|
||||
sblat3p dblat3p cblat3p zblat3p \
|
||||
|
||||
130
test/compare_sgemv_shgemv.c
Normal file
130
test/compare_sgemv_shgemv.c
Normal file
@@ -0,0 +1,130 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2020,2025 The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include "../common.h"
|
||||
|
||||
#include "test_helpers.h"
|
||||
|
||||
/* Single-precision GEMV entry point; main() uses its result as the baseline.
 * BLASFUNC is defined outside this file (presumably it applies the library's
 * symbol-naming convention -- confirm against common.h). */
#define SGEMV BLASFUNC(sgemv)
|
||||
/* GEMV entry point under test: main() passes it hfloat16 matrix/vector
 * inputs and a float output vector. */
#define SHGEMV BLASFUNC(shgemv)
|
||||
/* Largest square problem size tried by main(): the size loop runs x from 1
 * up to and including this value. */
#define SHGEMV_LARGEST 256
|
||||
|
||||
int
|
||||
main (int argc, char *argv[])
|
||||
{
|
||||
blasint k;
|
||||
int i, j, l;
|
||||
blasint x, y;
|
||||
int ret = 0;
|
||||
int loop = SHGEMV_LARGEST;
|
||||
char transA = 'N';
|
||||
float alpha = 1.0, beta = 0.0;
|
||||
|
||||
for (beta = 0; beta < 3; beta += 1) {
|
||||
for (alpha = 0; alpha < 3; alpha += 1) {
|
||||
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
|
||||
for (x = 1; x <= loop; x++)
|
||||
{
|
||||
k = (x == 0) ? 0 : l + 1;
|
||||
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
hfloat16 *AA = (hfloat16 *)malloc_safe(x * x * sizeof(hfloat16));
|
||||
hfloat16 *BB = (hfloat16 *)malloc_safe(x * sizeof(hfloat16) << l);
|
||||
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
||||
for (j = 0; j < x; j++)
|
||||
{
|
||||
for (i = 0; i < x; i++)
|
||||
{
|
||||
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
AA[j * x + i] = (_Float16)A[j * x + i];
|
||||
}
|
||||
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
BB[j << l]= (_Float16)B[j << l];
|
||||
|
||||
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
}
|
||||
|
||||
for (y = 0; y < 2; y++)
|
||||
{
|
||||
if (y == 0) {
|
||||
transA = 'N';
|
||||
} else {
|
||||
transA = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, x * sizeof(FLOAT) << l);
|
||||
memset(DD, 0, x * sizeof(FLOAT));
|
||||
memset(C, 0, x * sizeof(FLOAT) << l);
|
||||
|
||||
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
|
||||
SHGEMV (&transA, &x, &x, &alpha, (hfloat16*) AA, &x, (hfloat16*) BB, &k, &beta, CC, &k);
|
||||
|
||||
for (int i = 0; i < x; i ++) DD[i] *= beta;
|
||||
|
||||
for (j = 0; j < x; j++)
|
||||
for (i = 0; i < x; i++)
|
||||
if (transA == 'N') {
|
||||
DD[i] += alpha * (float)(AA[j * x + i]) * (float)(BB[j << l]);
|
||||
} else if (transA == 'T') {
|
||||
DD[j] += alpha * (float)(AA[j * x + i]) * (float)(BB[i << l]);
|
||||
}
|
||||
|
||||
for (j = 0; j < x; j++) {
|
||||
if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) {
|
||||
ret++;
|
||||
}
|
||||
if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) {
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(AA);
|
||||
free(BB);
|
||||
free(DD);
|
||||
free(CC);
|
||||
} // x
|
||||
} // l
|
||||
} // alpha
|
||||
} // beta
|
||||
|
||||
if (ret != 0) {
|
||||
fprintf (stderr, "SHGEMV FAILURES: %d\n", ret);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
Reference in New Issue
Block a user