Add Infrastructure for SHGEMV

This adds the infrastructure needed for a `shgemv` path and prepares for
future `hgemm`/`hgemv` paths, following the same model as the existing
`sb` and `b` interfaces.

I've also fixed several `shgemm` issues that caused build failures in
some configurations.
This commit is contained in:
Chris Sidebottom
2025-10-07 13:24:36 +00:00
parent 8918247207
commit 37fc3bbca0
24 changed files with 383 additions and 84 deletions

2
.gitignore vendored
View File

@@ -80,6 +80,7 @@ test/SBLAT3_3M.SUMM
test/ZBLAT2.SUMM
test/ZBLAT3.SUMM
test/ZBLAT3_3M.SUMM
test/SHBLAT2.SUMM
test/SHBLAT3.SUMM
test/SBBLAT2.SUMM
test/SBBLAT3.SUMM
@@ -98,6 +99,7 @@ test/sblat2
test/sblat3
test/sblat3_3m
test/test_shgemm
test/test_shgemv
test/test_sbgemm
test/test_sbgemv
test/test_bgemm

View File

@@ -175,6 +175,10 @@ if (BUILD_BFLOAT16)
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
endif ()
if (BUILD_HFLOAT16)
SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c)
SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c)
endif ()
endmacro ()
macro(SetDefaultL2)
@@ -226,6 +230,8 @@ macro(SetDefaultL2)
if (BUILD_BFLOAT16)
SetFallback(BGEMVNKERNEL ../generic/gemv_n.c)
SetFallback(BGEMVTKERNEL ../generic/gemv_t.c)
SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c)
SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c)
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
SetFallback(SHGERKERNEL ../generic/ger.c)
@@ -260,5 +266,16 @@ if (BUILD_BFLOAT16)
SetFallback(SBGEMMONCOPYOBJ sbgemm_oncopy.o)
SetFallback(SBGEMMOTCOPYOBJ sbgemm_otcopy.o)
endif ()
if (BUILD_HFLOAT16)
SetFallback(SHGEMMKERNEL ../generic/gemmkernel_2x2.c)
SetFallback(SHGEMM_BETA ../generic/gemm_beta.c)
SetFallback(SHGEMMINCOPY ../generic/gemm_ncopy_2.c)
SetFallback(SHGEMMITCOPY ../generic/gemm_tcopy_2.c)
SetFallback(SHGEMMONCOPY ../generic/gemm_ncopy_2.c)
SetFallback(SHGEMMOTCOPY ../generic/gemm_tcopy_2.c)
SetFallback(SHGEMMINCOPYOBJ shgemm_incopy.o)
SetFallback(SHGEMMITCOPYOBJ shgemm_itcopy.o)
SetFallback(SHGEMMONCOPYOBJ shgemm_oncopy.o)
SetFallback(SHGEMMOTCOPYOBJ shgemm_otcopy.o)
endif ()
endmacro ()

View File

@@ -375,9 +375,12 @@ function(GenerateNamedObjects sources_in)
if (NOT no_float_type)
string(SUBSTRING ${float_type} 0 1 float_char)
string(TOLOWER ${float_char} float_char)
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM")
set (float_char "sb")
endif ()
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM")
set (float_char "sb")
endif ()
if (${float_type} STREQUAL "HFLOAT16" AND NOT "${defines_in}" MATCHES "HGEM")
set (float_char "sh")
endif ()
endif ()
if (NOT name_in)

View File

@@ -261,6 +261,8 @@ void BLASFUNC(bgemv)(char *, blasint *, blasint *, bfloat16 *, bfloat16 *, blas
bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *,
bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(shgemv)(char *, blasint *, blasint *, float *, hfloat16 *, blasint *,
hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *,
float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *,

View File

@@ -54,6 +54,10 @@ int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLO
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int shgemv_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
int shgemv_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
int shgemv_thread_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int shgemv_thread_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *);
int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *);

View File

@@ -703,6 +703,9 @@
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
#define SCAL_K SSCAL_K
#define GEMV_N SHGEMV_N_K
#define GEMV_T SHGEMV_T_K
#elif defined(BFLOAT16) && defined(BGEMM)
#define SCAL_K BSCAL_K

View File

@@ -60,7 +60,8 @@ int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemv_n) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG);
#endif

View File

@@ -1,3 +1,31 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/
#ifndef COMMON_SH_H
#define COMMON_SH_H
@@ -17,6 +45,9 @@
#define SHGEMM_BETA shgemm_beta
#define SHGEMM_KERNEL shgemm_kernel
#define SHGEMV_N_K shgemv_n
#define SHGEMV_T_K shgemv_t
#else // #DYNAMIC_ARCH
@@ -32,6 +63,10 @@
#define SHGEMM_BETA gotoblas -> shgemm_beta
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
#define SHGEMV_N_K gotoblas->shgemv_n
#define SHGEMV_T_K gotoblas->shgemv_t
#endif // #DYNAMIC_ARCH
#define SHGEMM_NN shgemm_nn

View File

@@ -450,6 +450,12 @@ XBLASOBJS += \
xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX)
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += \
shgemv_thread_n$(TSUFFIX).$(SUFFIX) \
shgemv_thread_t$(TSUFFIX).$(SUFFIX)
endif
ifeq ($(BUILD_BFLOAT16),1)
BBLASOBJS += \
bgemv_thread_n$(TSUFFIX).$(SUFFIX) \
@@ -3737,6 +3743,13 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h
xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h
$(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F)
ifeq ($(BUILD_HFLOAT16),1)
shgemv_thread_n.$(SUFFIX) shgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)
shgemv_thread_t.$(SUFFIX) shgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F)
endif
ifeq ($(BUILD_BFLOAT16),1)
bgemv_thread_n.$(SUFFIX) bgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)

View File

@@ -80,7 +80,7 @@ blasobjsz="
blasobjs="lsame xerbla"
bfblasobjs="bgemm bgemv sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod"
hfblasobjs="shgemm"
hfblasobjs="shgemm shgemv"
cblasobjsc="
cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv
cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k

View File

@@ -80,7 +80,7 @@
@blasobjs = (lsame, xerbla);
@bfblasobjs = (bgemm, bgemv, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@hfblasobjs = (shgemm);
@hfblasobjs = (shgemm, shgemv);
@cblasobjsc = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,

View File

@@ -166,6 +166,7 @@ if (BUILD_BFLOAT16)
endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16")
GenerateNamedObjects("sbgemv.c" "" "shgemv" ${CBLAS_FLAG} "" "" true "HFLOAT16")
endif ()
# complex-specific sources

View File

@@ -87,6 +87,7 @@ endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLAS3OBJS = shgemm.$(SUFFIX)
SHBLAS2OBJS = shgemv.$(SUFFIX)
endif
DBLAS1OBJS = \
@@ -338,6 +339,7 @@ endif
ifeq ($(BUILD_HFLOAT16),1)
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
CSHBLAS2OBJS = cblas_shgemv.$(SUFFIX)
endif
CDBLAS1OBJS = \
@@ -441,6 +443,7 @@ SBBLAS1OBJS += $(CSBBLAS1OBJS)
SBBLAS2OBJS += $(CSBBLAS2OBJS)
SBBLAS3OBJS += $(CSBBLAS3OBJS)
SHBLAS3OBJS += $(CSHBLAS3OBJS)
SHBLAS2OBJS += $(CSHBLAS2OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
DBLAS3OBJS += $(CDBLAS3OBJS)
@@ -459,7 +462,7 @@ endif
BBLASOBJS = $(BBLAS3OBJS) $(BBLAS2OBJS) $(BBLAS1OBJS)
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
SHBLASOBJS = $(SHBLAS3OBJS)
SHBLASOBJS = $(SHBLAS3OBJS) $(SHBLAS2OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -602,7 +605,7 @@ clean ::
level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(SHBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
# Archive every level-3 object into the library. The bfloat16 list is
# BBLAS3OBJS (with the trailing S, as used when BBLASOBJS is assembled); the
# misspelt BBLAS3OBJ is never defined and silently expands to nothing.
level3 : $(SBBLAS3OBJS) $(BBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS)
@@ -1002,6 +1005,11 @@ sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) $(CFLAGS) -c $< -o $(@F)
endif
ifeq ($(BUILD_HFLOAT16),1)
shgemv.$(SUFFIX) shgemv.$(PSUFFIX) : sbgemv.c
$(CC) $(CFLAGS) -c $< -o $(@F)
endif
ifndef USE_NETLIB_GEMV
sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c
$(CC) -c $(CFLAGS) -o $(@F) $<
@@ -1832,6 +1840,11 @@ cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif
ifeq ($(BUILD_HFLOAT16),1)
cblas_shgemv.$(SUFFIX) cblas_shgemv.$(PSUFFIX) : sbgemv.c
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif
cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c
$(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $<

View File

@@ -587,7 +587,10 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD)))
/*
 * Decide whether the GEMM -> GEMV forwarding shortcut may be used for the
 * 16-bit float builds. Each condition is evaluated directly in an #if and
 * stored as a plain 0/1 macro: a macro that expands to `defined(...)` inside
 * an #if expression has undefined behaviour in ISO C (and triggers
 * -Wexpansion-to-defined), so the `defined` operators must not be hidden
 * behind a macro.
 *
 * For each type the shortcut is allowed when the type is not in use, or when
 * the matching *_GEMV_FORWARD switch for the active variant is defined.
 */
#if !defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD))
#define BFLOAT16_GEMM_GEMV_FORWARD 1
#else
#define BFLOAT16_GEMM_GEMV_FORWARD 0
#endif
#if !defined(HFLOAT16) || (!defined(HGEMM) && defined(SHGEMM_GEMV_FORWARD)) || (defined(HGEMM) && defined(HGEMM_GEMV_FORWARD))
#define HFLOAT16_GEMM_GEMV_FORWARD 1
#else
#define HFLOAT16_GEMM_GEMV_FORWARD 0
#endif
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && HFLOAT16_GEMM_GEMV_FORWARD && BFLOAT16_GEMM_GEMV_FORWARD
#if defined(ARCH_ARM64)
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
// perform poorly in certain circumstances. We use the following boolean

View File

@@ -48,6 +48,10 @@
#define GEMV_THREAD_N bgemv_thread_n
#define GEMV_THREAD_T bgemv_thread_t
#define ERROR_NAME "BGEMV "
#elif defined(HFLOAT16)
#define GEMV_THREAD_N shgemv_thread_n
#define GEMV_THREAD_T shgemv_thread_t
#define ERROR_NAME "SHGEMV "
#else
#define GEMV_THREAD_N sbgemv_thread_n
#define GEMV_THREAD_T sbgemv_thread_t

View File

@@ -228,6 +228,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${SBGEMVNKERNEL}" "" "gemv_n" false "" "" false "BFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SBGEMVTKERNEL}" "" "gemv_t" false "" "" false "BFLOAT16")
endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("${KERNELDIR}/${SHGEMVNKERNEL}" "" "gemv_n" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMVTKERNEL}" "" "gemv_t" false "" "" false "HFLOAT16")
endif ()
# Makefile.L3
set(USE_TRMM false)
string(TOUPPER ${TARGET_CORE} UC_TARGET_CORE)

View File

@@ -101,6 +101,16 @@ SBGEMVTKERNEL = ../x86_64/sbgemv_t.c
endif
endif
ifeq ($(BUILD_HFLOAT16),1)
ifndef SHGEMVNKERNEL
SHGEMVNKERNEL = ../generic/gemv_n.c
endif
ifndef SHGEMVTKERNEL
SHGEMVTKERNEL = ../generic/gemv_t.c
endif
endif
### GER ###
ifndef SGERKERNEL
@@ -299,6 +309,12 @@ SBBLASOBJS += \
sbgemv_t$(TSUFFIX).$(SUFFIX)
endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += \
shgemv_n$(TSUFFIX).$(SUFFIX) \
shgemv_t$(TSUFFIX).$(SUFFIX)
endif
ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" ""
$(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP)
$(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@
@@ -558,3 +574,10 @@ $(KDIR)bgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)bgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERN
$(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX $< -o $@
endif
ifeq ($(BUILD_HFLOAT16),1)
$(KDIR)shgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVNKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
$(KDIR)shgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVTKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
endif

View File

@@ -27,6 +27,7 @@
* *****************************************************************************/
#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
static float
bfloat16tof32 (bfloat16 value)
{
@@ -48,17 +49,34 @@ static bfloat16 f32tobfloat16(float value) {
#ifdef BGEMM
#define ALPHA bfloat16tof32(alpha)
#define BETA bfloat16tof32(beta)
#define BF16TOF32(x) (bfloat16tof32(x))
#define F32TOBF16(x) (f32tobfloat16(x))
#define TO_F32(x) (bfloat16tof32(x))
#define TO_OUTPUT(x) (f32tobfloat16(x))
#else
#define ALPHA alpha
#define BETA beta
#define BF16TOF32(x) (bfloat16tof32(x))
#define F32TOBF16(x) x
#define TO_F32(x) (bfloat16tof32(x))
#define TO_OUTPUT(x) x
#endif
#elif defined(HFLOAT16)
#ifdef HGEMM
#define ALPHA (float)(alpha)
#define BETA (float)(beta)
#define TO_F32(x) ((float)(x))
#define TO_OUTPUT(x) ((_Float16)(x))
#else
#define ALPHA alpha
#define BETA beta
#define BF16TOF32(x) x
#define F32TOBF16(x) x
#define TO_F32(x) ((float)(x))
#define TO_OUTPUT(x) x
#endif
#else
#define ALPHA alpha
#define BETA beta
#define TO_F32(x) x
#define TO_OUTPUT(x) x
#endif

View File

@@ -27,7 +27,8 @@
* *****************************************************************************/
#include "common.h"
#include "bf16_macros.h"
#include "conversion_macros.h"
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
#ifdef TRMMKERNEL
@@ -60,36 +61,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
{
load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
load2 = ptrba[2*0+1];
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
res1 = res1+TO_F32(load2)*TO_F32(load1);
load3 = ptrbb[2*0+1];
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
res2 = res2+TO_F32(load0)*TO_F32(load3);
res3 = res3+TO_F32(load2)*TO_F32(load3);
load4 = ptrba[2*1+0];
load5 = ptrbb[2*1+0];
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
res0 = res0+TO_F32(load4)*TO_F32(load5);
load6 = ptrba[2*1+1];
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
res1 = res1+TO_F32(load6)*TO_F32(load5);
load7 = ptrbb[2*1+1];
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
res2 = res2+TO_F32(load4)*TO_F32(load7);
res3 = res3+TO_F32(load6)*TO_F32(load7);
load0 = ptrba[2*2+0];
load1 = ptrbb[2*2+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
load2 = ptrba[2*2+1];
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
res1 = res1+TO_F32(load2)*TO_F32(load1);
load3 = ptrbb[2*2+1];
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
res2 = res2+TO_F32(load0)*TO_F32(load3);
res3 = res3+TO_F32(load2)*TO_F32(load3);
load4 = ptrba[2*3+0];
load5 = ptrbb[2*3+0];
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
res0 = res0+TO_F32(load4)*TO_F32(load5);
load6 = ptrba[2*3+1];
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
res1 = res1+TO_F32(load6)*TO_F32(load5);
load7 = ptrbb[2*3+1];
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
res2 = res2+TO_F32(load4)*TO_F32(load7);
res3 = res3+TO_F32(load6)*TO_F32(load7);
ptrba = ptrba+8;
ptrbb = ptrbb+8;
}
@@ -97,23 +98,23 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
{
load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
load2 = ptrba[2*0+1];
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
res1 = res1+TO_F32(load2)*TO_F32(load1);
load3 = ptrbb[2*0+1];
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
res2 = res2+TO_F32(load0)*TO_F32(load3);
res3 = res3+TO_F32(load2)*TO_F32(load3);
ptrba = ptrba+2;
ptrbb = ptrbb+2;
}
res0 = res0*ALPHA;
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
res1 = res1*ALPHA;
C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
res2 = res2*ALPHA;
C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2);
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2);
res3 = res3*ALPHA;
C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3);
C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3);
C0 = C0+2;
C1 = C1+2;
}
@@ -126,16 +127,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
{
load0 = ptrba[0+0];
load1 = ptrbb[2*0+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
load2 = ptrbb[2*0+1];
res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
res1 = res1+TO_F32(load0)*TO_F32(load2);
ptrba = ptrba+1;
ptrbb = ptrbb+2;
}
res0 = res0*ALPHA;
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
res1 = res1*ALPHA;
C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1);
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1);
C0 = C0+1;
C1 = C1+1;
}
@@ -157,16 +158,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
{
load0 = ptrba[2*0+0];
load1 = ptrbb[0+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
load2 = ptrba[2*0+1];
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
res1 = res1+TO_F32(load2)*TO_F32(load1);
ptrba = ptrba+2;
ptrbb = ptrbb+1;
}
res0 = res0*ALPHA;
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
res1 = res1*ALPHA;
C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
C0 = C0+2;
}
for (i=0; i<(bm&1); i+=1)
@@ -177,12 +178,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
{
load0 = ptrba[0+0];
load1 = ptrbb[0+0];
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
res0 = res0+TO_F32(load0)*TO_F32(load1);
ptrba = ptrba+1;
ptrbb = ptrbb+1;
}
res0 = res0*ALPHA;
C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
C0 = C0+1;
}
k = (bk<<0);

View File

@@ -26,15 +26,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#include "bf16_macros.h"
#include "conversion_macros.h"
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
{
BLASLONG i;
BLASLONG ix, iy;
BLASLONG j;
FLOAT *a_ptr;
#ifdef BGEMM
IFLOAT *a_ptr;
#if defined(BGEMM) || defined(HGEMM)
float temp;
#else
FLOAT temp;
@@ -49,18 +48,18 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
a_ptr = a;
for (BLASLONG j = 0; j < n; j++)
{
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]);
ix += inc_x;
a_ptr += lda;
}
if (BETA == ZERO)
{
y[iy] = F32TOBF16(ALPHA * temp);
y[iy] = TO_OUTPUT(ALPHA * temp);
}
else
{
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy]));
}
iy += inc_y;

View File

@@ -26,15 +26,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#include "bf16_macros.h"
#include "conversion_macros.h"
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
{
BLASLONG i;
BLASLONG ix, iy;
BLASLONG j;
FLOAT *a_ptr;
#ifdef BGEMM
IFLOAT *a_ptr;
#if defined(BGEMM) || defined(HGEMM)
float temp;
#else
FLOAT temp;
@@ -49,16 +50,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
ix = 0;
for (i = 0; i < m; i++)
{
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]);
ix += inc_x;
}
if (BETA == ZERO)
{
y[iy] = F32TOBF16(ALPHA * temp);
y[iy] = TO_OUTPUT(ALPHA * temp);
}
else
{
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy]));
}
iy += inc_y;
a_ptr += lda;

View File

@@ -56,6 +56,24 @@ gotoblas_t TABLE_NAME = {
GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
#ifdef BUILD_HFLOAT16
0, 0, 0,
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
#ifdef SHGEMM_DEFAULT_UNROLL_MN
SHGEMM_DEFAULT_UNROLL_MN,
#else
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
shgemm_kernelTS, shgemm_betaTS,
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
shgemm_incopyTS, shgemm_itcopyTS,
#else
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
shgemm_oncopyTS, shgemm_otcopyTS,
shgemv_nTS, shgemv_tTS,
#endif
#ifdef BUILD_BFLOAT16
0, 0, 0,
BGEMM_DEFAULT_UNROLL_M, BGEMM_DEFAULT_UNROLL_N,
@@ -142,23 +160,6 @@ gotoblas_t TABLE_NAME = {
#endif
#endif
#ifdef BUILD_HFLOAT16
0, 0, 0,
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
#ifdef SHGEMM_DEFAULT_UNROLL_MN
SHGEMM_DEFAULT_UNROLL_MN,
#else
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
shgemm_kernelTS, shgemm_betaTS,
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
shgemm_incopyTS, shgemm_itcopyTS,
#else
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
#if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1)
0, 0, 0,
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,

View File

@@ -119,6 +119,9 @@ endif
endif
endif
ifeq ($(BUILD_HFLOAT16), 1)
SH2 = test_shgemv
endif
ifeq ($(BUILD_BFLOAT16), 1)
BB2 = test_bgemv
B2 = test_sbgemv
@@ -136,7 +139,7 @@ ifeq ($(BUILD_COMPLEX16),1)
Z2=zblat2
endif
level2: $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2)
level2: $(SH2) $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2)
ifneq ($(CROSS), 1)
@@ -147,6 +150,10 @@ ifeq ($(BUILD_BFLOAT16),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_HFLOAT16),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemv > SHBLAT2.SUMM
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_SINGLE),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat2 < ./sblat2.dat
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -172,6 +179,10 @@ ifeq ($(BUILD_BFLOAT16),1)
OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_HFLOAT16),1)
OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_SINGLE),1)
OMP_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -195,6 +206,10 @@ ifeq ($(BUILD_BFLOAT16),1)
OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_HFLOAT16),1)
OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM
@$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0
endif
ifeq ($(BUILD_SINGLE),1)
OPENBLAS_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -438,6 +453,12 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
$(CC) $(CLDFLAGS) -DIBFLOAT16 -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
endif
ifeq ($(BUILD_HFLOAT16),1)
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
endif
ifeq ($(BUILD_COMPLEX),1)
cblat3_3m : cblat3_3m.$(SUFFIX) ../$(LIBNAME)
$(FC) $(FLDFLAGS) -o cblat3_3m cblat3_3m.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
@@ -454,7 +475,7 @@ clean:
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
sblat1 dblat1 cblat1 zblat1 \
sblat2 dblat2 cblat2 zblat2 \
test_bgemm test_bgemv test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
sblat1p dblat1p cblat1p zblat1p \
sblat2p dblat2p cblat2p zblat2p \
sblat3p dblat3p cblat3p zblat3p \

130
test/compare_sgemv_shgemv.c Normal file
View File

@@ -0,0 +1,130 @@
/***************************************************************************
Copyright (c) 2020,2025 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <stdio.h>
#include <stdint.h>
#include "../common.h"
#include "test_helpers.h"
#define SGEMV BLASFUNC(sgemv)
#define SHGEMV BLASFUNC(shgemv)
#define SHGEMV_LARGEST 256
/*
 * Compare SHGEMV (half-precision A and x, single-precision y) against SGEMV
 * and against a plain C reference computed from the half-precision inputs.
 *
 * The sweep covers beta and alpha in {0, 1, 2}, vector strides of 1 and 2,
 * square matrices of order 1..SHGEMV_LARGEST and transA in {'N', 'T'}.
 * y starts from the same non-zero random values for all three computations,
 * so the beta * y term is genuinely exercised.
 *
 * Returns 0 when every compared element is within tolerance. Returns 1 when
 * any element mismatches (the mismatch count is printed to stderr) or when a
 * buffer could not be allocated.
 */
int
main (int argc, char *argv[])
{
  blasint k;
  int i, j, l;
  blasint x, y;
  int ret = 0;
  int loop = SHGEMV_LARGEST;
  char transA = 'N';
  float alpha = 1.0, beta = 0.0;

  /* The test takes no command-line options. */
  (void) argc;
  (void) argv;

  for (beta = 0; beta < 3; beta += 1) {
    for (alpha = 0; alpha < 3; alpha += 1) {
      for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
        for (x = 1; x <= loop; x++)
        {
          /* Stride shared by x and y: 1 when l == 0, 2 when l == 1. */
          k = l + 1;

          float *A = (float *)malloc_safe(x * x * sizeof(float));
          float *B = (float *)malloc_safe(x * sizeof(float) << l);
          float *C = (float *)malloc_safe(x * sizeof(float) << l);
          hfloat16 *AA = (hfloat16 *)malloc_safe(x * x * sizeof(hfloat16));
          hfloat16 *BB = (hfloat16 *)malloc_safe(x * sizeof(hfloat16) << l);
          float *CC = (float *)malloc_safe(x * sizeof(float) << l);
          float *DD = (float *)malloc_safe(x * sizeof(float));
          /* Initial contents of y, restored before every GEMV pass. */
          float *Y0 = (float *)malloc_safe(x * sizeof(float));

          if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
              (DD == NULL) || (CC == NULL) || (Y0 == NULL))
            {
              /* free(NULL) is a no-op, so release whatever was obtained. */
              free(A);
              free(B);
              free(C);
              free(AA);
              free(BB);
              free(DD);
              free(CC);
              free(Y0);
              return 1;
            }

          /* Random inputs in [0.5, 1.5]; AA/BB hold the half-precision copies. */
          for (j = 0; j < x; j++)
            {
              for (i = 0; i < x; i++)
                {
                  A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
                  AA[j * x + i] = (_Float16)A[j * x + i];
                }
              B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
              BB[j << l] = (_Float16)B[j << l];
              Y0[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
            }

          for (y = 0; y < 2; y++)
            {
              transA = (y == 0) ? 'N' : 'T';

              /* Clear the (possibly strided) outputs, then load the same
                 starting y into both BLAS outputs and pre-scale the
                 reference by beta. */
              memset(CC, 0, x * sizeof(float) << l);
              memset(C, 0, x * sizeof(float) << l);
              for (j = 0; j < x; j++)
                {
                  C[j << l] = CC[j << l] = Y0[j];
                  DD[j] = beta * Y0[j];
                }

              SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
              SHGEMV (&transA, &x, &x, &alpha, AA, &x, BB, &k, &beta, CC, &k);

              /* Reference result; A is column-major with leading dimension x. */
              for (j = 0; j < x; j++)
                for (i = 0; i < x; i++)
                  if (transA == 'N') {
                    DD[i] += alpha * (float)(AA[j * x + i]) * (float)(BB[j << l]);
                  } else if (transA == 'T') {
                    DD[j] += alpha * (float)(AA[j * x + i]) * (float)(BB[i << l]);
                  }

              for (j = 0; j < x; j++) {
                /* Looser bound against SGEMV, which saw unrounded inputs. */
                if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) {
                  ret++;
                }
                /* Tighter bound against the reference built from the same
                   half-precision inputs. */
                if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) {
                  ret++;
                }
              }
            }

          free(A);
          free(B);
          free(C);
          free(AA);
          free(BB);
          free(DD);
          free(CC);
          free(Y0);
        } // x
      } // l
    } // alpha
  } // beta

  if (ret != 0) {
    fprintf (stderr, "SHGEMV FAILURES: %d\n", ret);
    return 1;
  }
  return 0;
}