mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-05-31 00:45:48 +08:00
Merge pull request #5512 from quic/topic/ssyrk_direct_sme1
Some checks failed
apple m / build (cmake, gfortran, 0, 0) (push) Has been cancelled
apple m / build (cmake, gfortran, 0, 1) (push) Has been cancelled
apple m / build (cmake, gfortran, 1, 0) (push) Has been cancelled
apple m / build (cmake, gfortran, 1, 1) (push) Has been cancelled
apple m / build (make, gfortran, 0, 0) (push) Has been cancelled
apple m / build (make, gfortran, 0, 1) (push) Has been cancelled
apple m / build (make, gfortran, 1, 0) (push) Has been cancelled
apple m / build (make, gfortran, 1, 1) (push) Has been cancelled
arm64 graviton cirun / build (cmake, gfortran) (push) Has been cancelled
arm64 graviton cirun / build (make, gfortran) (push) Has been cancelled
c910v qemu test / TEST (riscv64-linux-gnu, NO_SHARED=1 TARGET=C910V, C910V, riscv64-unknown-linux-gnu) (push) Has been cancelled
c910v qemu test / TEST (riscv64-linux-gnu, NO_SHARED=1 TARGET=RISCV64_GENERIC, RISCV64_GENERIC, riscv64-linux-gnu) (push) Has been cancelled
Run codspeed benchmarks / benchmarks (make, gfortran, ubuntu-22.04, 3.12) (push) Has been cancelled
Publish docs via GitHub Pages / Deploy docs (push) Has been cancelled
continuous build / build (cmake, flang, ubuntu-latest) (push) Has been cancelled
continuous build / build (cmake, gfortran, macos-latest) (push) Has been cancelled
continuous build / build (cmake, gfortran, ubuntu-latest) (push) Has been cancelled
continuous build / build (make, flang, ubuntu-latest) (push) Has been cancelled
continuous build / build (make, gfortran, macos-latest) (push) Has been cancelled
continuous build / build (make, gfortran, ubuntu-latest) (push) Has been cancelled
continuous build / msys2 (None, fc, int32, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, CLANG64, mingw-w64-clang-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, MINGW32, mingw-w64-i686) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int64, -DBINARY=64 -DINTERFACE64=1, CLANG64, mingw-w64-clang-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int64, -DBINARY=64 -DINTERFACE64=1, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / cross_build (DYNAMIC_ARCH=1 TARGET=GENERIC, mips64el, mips64el-linux-gnuabi64) (push) Has been cancelled
continuous build / cross_build (TARGET=EV4, alpha, alpha-linux-gnu) (push) Has been cancelled
continuous build / cross_build (TARGET=MIPS1004K, mipsel, mipsel-linux-gnu) (push) Has been cancelled
continuous build / cross_build (TARGET=RISCV64_GENERIC, riscv64, riscv64-linux-gnu) (push) Has been cancelled
continuous build / neoverse_build (push) Has been cancelled
harmonyos / build (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=GENERIC, DYNAMIC_ARCH, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA264, LA264, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA464, LA464, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA64_GENERIC, LA64_GENERIC, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON2K1000, LOONGSON2K1000, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON3R5, LOONGSON3R5, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSONGENERIC, LOONGSONGENERIC, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=GENERIC, DYNAMIC_ARCH) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA264, LA264) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA464, LA464) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA64_GENERIC, LA64_GENERIC) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON2K1000, LOONGSON2K1000) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON3R5, LOONGSON3R5) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSONGENERIC, LOONGSONGENERIC) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=I6400, I6400, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=I6500, I6500, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=MIPS64_GENERIC, MIPS64_GENERIC, mips64el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=P6600, P6600, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=SICORTEX, SICORTEX, mips64el-linux-gnuabi64) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_GENERIC BINARY=64 ARCH=riscv64 DYNAMIC_ARCH=1, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64, DYNAMIC_ARCH=1) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_ZVL128B BINARY=64 ARCH=riscv64, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=128,elen=64, RISCV64_ZVL128B) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64 BUILD_BFLOAT16=1 BUILD_HFLOAT16=1, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64,zfh=true,zvfh=true,zvfbfwma=true, RISCV64_ZVL256B) (push) Has been cancelled
Windows ARM64 CI / build (push) Has been cancelled
Nightly-Homebrew-Build / build-OpenBLAS-with-Homebrew (push) Has been cancelled
Some checks failed
apple m / build (cmake, gfortran, 0, 0) (push) Has been cancelled
apple m / build (cmake, gfortran, 0, 1) (push) Has been cancelled
apple m / build (cmake, gfortran, 1, 0) (push) Has been cancelled
apple m / build (cmake, gfortran, 1, 1) (push) Has been cancelled
apple m / build (make, gfortran, 0, 0) (push) Has been cancelled
apple m / build (make, gfortran, 0, 1) (push) Has been cancelled
apple m / build (make, gfortran, 1, 0) (push) Has been cancelled
apple m / build (make, gfortran, 1, 1) (push) Has been cancelled
arm64 graviton cirun / build (cmake, gfortran) (push) Has been cancelled
arm64 graviton cirun / build (make, gfortran) (push) Has been cancelled
c910v qemu test / TEST (riscv64-linux-gnu, NO_SHARED=1 TARGET=C910V, C910V, riscv64-unknown-linux-gnu) (push) Has been cancelled
c910v qemu test / TEST (riscv64-linux-gnu, NO_SHARED=1 TARGET=RISCV64_GENERIC, RISCV64_GENERIC, riscv64-linux-gnu) (push) Has been cancelled
Run codspeed benchmarks / benchmarks (make, gfortran, ubuntu-22.04, 3.12) (push) Has been cancelled
Publish docs via GitHub Pages / Deploy docs (push) Has been cancelled
continuous build / build (cmake, flang, ubuntu-latest) (push) Has been cancelled
continuous build / build (cmake, gfortran, macos-latest) (push) Has been cancelled
continuous build / build (cmake, gfortran, ubuntu-latest) (push) Has been cancelled
continuous build / build (make, flang, ubuntu-latest) (push) Has been cancelled
continuous build / build (make, gfortran, macos-latest) (push) Has been cancelled
continuous build / build (make, gfortran, ubuntu-latest) (push) Has been cancelled
continuous build / msys2 (None, fc, int32, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, CLANG64, mingw-w64-clang-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, MINGW32, mingw-w64-i686) (push) Has been cancelled
continuous build / msys2 (Release, fc, int32, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int64, -DBINARY=64 -DINTERFACE64=1, CLANG64, mingw-w64-clang-x86_64) (push) Has been cancelled
continuous build / msys2 (Release, fc, int64, -DBINARY=64 -DINTERFACE64=1, UCRT64, mingw-w64-ucrt-x86_64) (push) Has been cancelled
continuous build / cross_build (DYNAMIC_ARCH=1 TARGET=GENERIC, mips64el, mips64el-linux-gnuabi64) (push) Has been cancelled
continuous build / cross_build (TARGET=EV4, alpha, alpha-linux-gnu) (push) Has been cancelled
continuous build / cross_build (TARGET=MIPS1004K, mipsel, mipsel-linux-gnu) (push) Has been cancelled
continuous build / cross_build (TARGET=RISCV64_GENERIC, riscv64, riscv64-linux-gnu) (push) Has been cancelled
continuous build / neoverse_build (push) Has been cancelled
harmonyos / build (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=GENERIC, DYNAMIC_ARCH, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA264, LA264, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA464, LA464, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA64_GENERIC, LA64_GENERIC, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON2K1000, LOONGSON2K1000, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON3R5, LOONGSON3R5, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSONGENERIC, LOONGSONGENERIC, loongarch64-linux-gnu) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=GENERIC, DYNAMIC_ARCH) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA264, LA264) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA464, LA464) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LA64_GENERIC, LA64_GENERIC) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON2K1000, LOONGSON2K1000) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSON3R5, LOONGSON3R5) (push) Has been cancelled
loongarch64 clang qemu test / TEST (NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=LOONGSONGENERIC, LOONGSONGENERIC) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=I6400, I6400, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=I6500, I6500, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=MIPS64_GENERIC, MIPS64_GENERIC, mips64el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=P6600, P6600, mipsisa64r6el-linux-gnuabi64) (push) Has been cancelled
mips64 qemu test / TEST (NO_SHARED=1 TARGET=SICORTEX, SICORTEX, mips64el-linux-gnuabi64) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_GENERIC BINARY=64 ARCH=riscv64 DYNAMIC_ARCH=1, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64, DYNAMIC_ARCH=1) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_ZVL128B BINARY=64 ARCH=riscv64, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=128,elen=64, RISCV64_ZVL128B) (push) Has been cancelled
riscv64 zvl256b qemu test / TEST (TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64 BUILD_BFLOAT16=1 BUILD_HFLOAT16=1, rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64,zfh=true,zvfh=true,zvfbfwma=true, RISCV64_ZVL256B) (push) Has been cancelled
Windows ARM64 CI / build (push) Has been cancelled
Nightly-Homebrew-Build / build-OpenBLAS-with-Homebrew (push) Has been cancelled
Support for SME1 based ssyrk_direct kernel for cblas_ssyrk level 3 API
This commit is contained in:
@@ -89,6 +89,27 @@ void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
|
||||
float * A, BLASLONG strideA,
|
||||
float * B, BLASLONG strideB);
|
||||
|
||||
void ssyrk_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
|
||||
float alpha,
|
||||
float * A, BLASLONG strideA,
|
||||
float beta,
|
||||
float * C, BLASLONG strideC);
|
||||
void ssyrk_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
|
||||
float alpha,
|
||||
float * A, BLASLONG strideA,
|
||||
float beta,
|
||||
float * C, BLASLONG strideC);
|
||||
void ssyrk_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
|
||||
float alpha,
|
||||
float * A, BLASLONG strideA,
|
||||
float beta,
|
||||
float * C, BLASLONG strideC);
|
||||
void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
|
||||
float alpha,
|
||||
float * A, BLASLONG strideA,
|
||||
float beta,
|
||||
float * C, BLASLONG strideC);
|
||||
|
||||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
|
||||
|
||||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
|
||||
@@ -264,6 +264,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
|
||||
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
|
||||
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
|
||||
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
|
||||
void (*ssyrk_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
|
||||
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
|
||||
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
|
||||
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@
|
||||
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
|
||||
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
|
||||
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_UN ssyrk_direct_alpha_betaUN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
|
||||
|
||||
#define SGEMM_ONCOPY sgemm_oncopy
|
||||
#define SGEMM_OTCOPY sgemm_otcopy
|
||||
@@ -232,6 +236,10 @@
|
||||
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
|
||||
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
|
||||
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_UN gotoblas -> ssyrk_direct_alpha_betaUN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
|
||||
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
|
||||
#endif
|
||||
|
||||
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
|
||||
|
||||
@@ -338,6 +338,23 @@ double NNK;
|
||||
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
|
||||
return;
|
||||
}
|
||||
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
|
||||
#if defined(ARCH_ARM64) && (defined(USE_SSYRK_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
|
||||
#if defined(DYNAMIC_ARCH)
|
||||
if (support_sme1())
|
||||
#endif
|
||||
if (args.n == 0) return;
|
||||
if (order == CblasRowMajor && n == ldc) {
|
||||
if (Trans == CblasNoTrans && k == lda) {
|
||||
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UN : SSYRK_DIRECT_ALPHA_BETA_LN)(n, k, alpha, a, lda, beta, c, ldc);
|
||||
return;
|
||||
} else if (Trans == CblasTrans && n == lda){
|
||||
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UT : SSYRK_DIRECT_ALPHA_BETA_LT)(n, k, alpha, a, lda, beta, c, ldc);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -245,6 +245,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
|
||||
if (ARM64)
|
||||
set(USE_DIRECT_STRMM true)
|
||||
endif()
|
||||
set(USE_DIRECT_SSYRK false)
|
||||
if (ARM64)
|
||||
set(USE_DIRECT_SSYRK true)
|
||||
endif()
|
||||
set(USE_DIRECT_SGEMM false)
|
||||
if (X86_64 OR ARM64)
|
||||
set(USE_DIRECT_SGEMM true)
|
||||
@@ -297,6 +301,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if (USE_DIRECT_SSYRK)
|
||||
if (ARM64)
|
||||
set (SSYRKDIRECTKERNEL_ALPHA_BETA ssyrk_direct_alpha_beta_arm64_sme1.c)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUN" false "" "" false SINGLE)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUT" false "" "" false SINGLE)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLN" false "" "" false SINGLE)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLT" false "" "" false SINGLE)
|
||||
endif ()
|
||||
endif()
|
||||
|
||||
foreach (float_type SINGLE DOUBLE)
|
||||
string(SUBSTRING ${float_type} 0 1 float_char)
|
||||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
|
||||
|
||||
@@ -54,6 +54,7 @@ USE_TRMM = 1
|
||||
USE_DIRECT_SGEMM = 1
|
||||
USE_DIRECT_SSYMM = 1
|
||||
USE_DIRECT_STRMM = 1
|
||||
USE_DIRECT_SSYRK = 1
|
||||
endif
|
||||
|
||||
ifeq ($(ARCH), riscv64)
|
||||
@@ -161,6 +162,16 @@ endif
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef USE_DIRECT_SSYRK
|
||||
ifndef SSYRKDIRECTKERNEL_ALPHA_BETA
|
||||
ifeq ($(ARCH), arm64)
|
||||
ifeq ($(TARGET_CORE), ARMV9SME)
|
||||
HAVE_SME = 1
|
||||
endif
|
||||
SSYRKDIRECTKERNEL_ALPHA_BETA = ssyrk_direct_alpha_beta_arm64_sme1.c
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_BFLOAT16), 1)
|
||||
ifndef BGEMMKERNEL
|
||||
@@ -261,6 +272,14 @@ SKERNELOBJS += \
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef USE_DIRECT_SSYRK
|
||||
ifeq ($(ARCH), arm64)
|
||||
SKERNELOBJS += \
|
||||
ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
|
||||
ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
endif
|
||||
|
||||
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
|
||||
DKERNELOBJS += \
|
||||
dgemm_beta$(TSUFFIX).$(SUFFIX) \
|
||||
@@ -1158,6 +1177,21 @@ $(KDIR)xgemm_kernel_r$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMD
|
||||
$(KDIR)xgemm_kernel_b$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMDEPEND)
|
||||
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX -DCC $< -o $@
|
||||
|
||||
ifdef USE_DIRECT_SSYRK
|
||||
ifeq ($(ARCH), arm64)
|
||||
$(KDIR)ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
|
||||
|
||||
$(KDIR)ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
|
||||
|
||||
$(KDIR)ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
|
||||
|
||||
$(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef USE_TRMM
|
||||
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
|
||||
|
||||
250
kernel/arm64/ssyrk_direct_alpha_beta_arm64_sme1.c
Normal file
250
kernel/arm64/ssyrk_direct_alpha_beta_arm64_sme1.c
Normal file
@@ -0,0 +1,250 @@
|
||||
/*
|
||||
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
||||
SPDX-License-Identifier: BSD-3-Clause-Clear
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include <stdlib.h>
|
||||
#include <inttypes.h>
|
||||
#include <math.h>
|
||||
#if defined(HAVE_SME)
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
|
||||
#include <arm_sme.h>
|
||||
#endif
|
||||
|
||||
/* Function prototypes */
|
||||
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
|
||||
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
|
||||
|
||||
/* Function Definitions */
|
||||
static uint64_t sve_cntw() {
|
||||
uint64_t cnt;
|
||||
asm volatile(
|
||||
"rdsvl %[res], #1\n"
|
||||
"lsr %[res], %[res], #2\n"
|
||||
: [res] "=r" (cnt) ::
|
||||
);
|
||||
return cnt;
|
||||
}
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
|
||||
// Outer product kernel.
|
||||
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
|
||||
__attribute__((always_inline)) inline void
|
||||
kernel_2x2(const float *A, float *B, float *C, size_t shared_dim,
|
||||
size_t ldc, size_t block_rows, size_t block_cols, float alpha,
|
||||
float beta, uint64_t row_idx, uint64_t col_idx)
|
||||
__arm_out("za") __arm_streaming {
|
||||
|
||||
const uint64_t svl = svcntw();
|
||||
size_t ldb = ldc;
|
||||
// Predicate set-up
|
||||
svbool_t pg = svptrue_b32();
|
||||
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
|
||||
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
|
||||
|
||||
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
|
||||
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
|
||||
|
||||
#define pg_c_0 pg_b_0
|
||||
#define pg_c_1 pg_b_1
|
||||
|
||||
svzero_za();
|
||||
svfloat32_t beta_vec = svdup_f32(beta);
|
||||
|
||||
// Load C to ZA
|
||||
for (size_t i = 0; i < MIN(svl, block_rows); i++) {
|
||||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
|
||||
row_c_0 = svmul_x(pg, beta_vec, row_c_0);
|
||||
svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);
|
||||
|
||||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
|
||||
row_c_1 = svmul_x(pg, beta_vec, row_c_1);
|
||||
svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
|
||||
}
|
||||
for (size_t i = svl; i < block_rows; i++) {
|
||||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
|
||||
row_c_0 = svmul_x(pg, beta_vec, row_c_0);
|
||||
svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0);
|
||||
|
||||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
|
||||
row_c_1 = svmul_x(pg, beta_vec, row_c_1);
|
||||
svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1);
|
||||
}
|
||||
|
||||
svfloat32_t alpha_vec = svdup_f32(alpha);
|
||||
// Iterate through shared dimension (K)
|
||||
for (size_t k = 0; k < shared_dim; k++) {
|
||||
#if !defined(TRANSA)
|
||||
// Load column of A
|
||||
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
|
||||
col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
|
||||
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
|
||||
col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
|
||||
|
||||
// Load row of A**T
|
||||
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * svl]);
|
||||
svfloat32_t row_b_1 = svld1(pg_b_1, &B[(k + shared_dim) * svl]);
|
||||
#else
|
||||
// Load column of A**T
|
||||
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * ldb]);
|
||||
col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
|
||||
|
||||
svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * ldb + svl]);
|
||||
col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
|
||||
|
||||
// Load row of A
|
||||
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
|
||||
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
|
||||
#endif
|
||||
// Perform outer product
|
||||
svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
|
||||
svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
|
||||
svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
|
||||
svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
|
||||
}
|
||||
|
||||
#if defined(UPPER)
|
||||
#define pg_c_0_full pg_c_0
|
||||
#define pg_c_1_full pg_c_1
|
||||
|
||||
bool need_update_pg_b = true;
|
||||
size_t last_invalid_index = col_idx - row_idx;
|
||||
// For Upper, If col_idx - row_idx >= 2*svl, we don't need to update the predicate due to all elements above the digonal
|
||||
if (col_idx - row_idx >= 2*svl) {
|
||||
need_update_pg_b = false;
|
||||
}
|
||||
// Store to C from ZA
|
||||
for (size_t i = 0; i < MIN(svl, block_rows); i++, last_invalid_index++) {
|
||||
if (need_update_pg_b) {
|
||||
pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
|
||||
pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
|
||||
}
|
||||
|
||||
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
|
||||
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
|
||||
}
|
||||
for (size_t i = svl; i < block_rows; i++,last_invalid_index++) {
|
||||
if (need_update_pg_b) {
|
||||
pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
|
||||
pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
|
||||
}
|
||||
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
|
||||
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
|
||||
}
|
||||
#else
|
||||
// Store to C from ZA
|
||||
size_t valid_index = row_idx - col_idx + 1;
|
||||
for (size_t i = 0; i < MIN(svl, block_rows); i++, valid_index++) {
|
||||
pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
|
||||
pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
|
||||
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
|
||||
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
|
||||
}
|
||||
for (size_t i = svl; i < block_rows; i++, valid_index++) {
|
||||
pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
|
||||
pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
|
||||
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
|
||||
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__arm_new("za") __arm_locally_streaming
|
||||
static void ssyrk_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
|
||||
const float *ba, const float* beta, float *restrict bc) {
|
||||
const uint64_t num_rows = n;
|
||||
const uint64_t num_cols = n;
|
||||
|
||||
const float *restrict a_ptr = ba;
|
||||
const float *restrict b_ptr = ba;
|
||||
float *restrict c_ptr = bc;
|
||||
|
||||
const uint64_t svl = svcntw();
|
||||
const uint64_t ldc = n;
|
||||
|
||||
// Block over rows of C (panels of A)
|
||||
uint64_t row_idx = 0;
|
||||
|
||||
// 2x2 loop
|
||||
uint64_t row_batch = 2*svl;
|
||||
|
||||
// Block over row dimension of C
|
||||
for (; row_idx < num_rows; row_idx += row_batch) {
|
||||
row_batch = MIN(row_batch, num_rows - row_idx);
|
||||
uint64_t col_batch = 2*svl;
|
||||
#if defined(UPPER)
|
||||
// for UPLO is upper, Start from column col_idx = rows_index to ensure we only process the upper triangle (col_idx >= rows_index)
|
||||
for (uint64_t col_idx = row_idx; col_idx < num_cols; col_idx += col_batch) {
|
||||
col_batch = MIN(col_batch, num_cols - col_idx);
|
||||
#else
|
||||
// for UPLO is lower, we only process the lower triangle part (col_idx <= row_idxx)
|
||||
for (uint64_t col_idx = 0; col_idx < num_cols && col_idx <= row_idx; col_idx += col_batch) {
|
||||
#endif
|
||||
col_batch = MIN(col_batch, num_cols - col_idx);
|
||||
#if !defined(TRANSA)
|
||||
kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx * k],
|
||||
&c_ptr[row_idx * ldc + col_idx], k,
|
||||
ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
|
||||
#else
|
||||
kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx],
|
||||
&c_ptr[row_idx * ldc + col_idx], k,
|
||||
ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#else
|
||||
static void ssyrk_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
|
||||
const float *ba, const float* beta, float *restrict bc){}
|
||||
#endif
|
||||
|
||||
void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
|
||||
BLASLONG strideA, float beta, float * __restrict C, BLASLONG strideC){
|
||||
#if !defined(TRANSA)
|
||||
uint64_t n_mod, vl_elms;
|
||||
|
||||
vl_elms = sve_cntw();
|
||||
|
||||
n_mod = ceil((double)N/(double)vl_elms) * vl_elms;
|
||||
|
||||
float *A_mod = (float *) malloc(n_mod*K*sizeof(float));
|
||||
|
||||
/* Prevent compiler optimization by reading from memory instead
|
||||
* of reading directly from vector (z) registers.
|
||||
* */
|
||||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
|
||||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
|
||||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
|
||||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
|
||||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
|
||||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
|
||||
|
||||
/* Pre-process the left matrix to make it suitable for
|
||||
matrix sum of outer-product calculation
|
||||
*/
|
||||
sgemm_direct_sme1_preprocess(N, K, A, A_mod);
|
||||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
|
||||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
|
||||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
|
||||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
|
||||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
|
||||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
|
||||
ssyrk_direct_sme1_2VLx2VL(N, K, &alpha, A_mod, &beta, C);
|
||||
free(A_mod);
|
||||
#else
|
||||
ssyrk_direct_sme1_2VLx2VL(N, K, &alpha, A, &beta, C);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
|
||||
BLASLONG strideA, float beta, float * __restrict C, BLASLONG strideC){}
|
||||
|
||||
#endif
|
||||
@@ -223,6 +223,10 @@ gotoblas_t TABLE_NAME = {
|
||||
strmm_direct_LNLNTS,
|
||||
strmm_direct_LTUNTS,
|
||||
strmm_direct_LTLNTS,
|
||||
ssyrk_direct_alpha_betaUNTS,
|
||||
ssyrk_direct_alpha_betaUTTS,
|
||||
ssyrk_direct_alpha_betaLNTS,
|
||||
ssyrk_direct_alpha_betaLTTS,
|
||||
#endif
|
||||
|
||||
sgemm_kernelTS, sgemm_betaTS,
|
||||
|
||||
1
param.h
1
param.h
@@ -3868,6 +3868,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
|
||||
#define USE_SGEMM_KERNEL_DIRECT 1
|
||||
#define USE_SSYMM_KERNEL_DIRECT 1
|
||||
#define USE_STRMM_KERNEL_DIRECT 1
|
||||
#define USE_SSYRK_KERNEL_DIRECT 1
|
||||
#endif /* ARMv9 SME */
|
||||
|
||||
#if defined(ARMV5)
|
||||
|
||||
Reference in New Issue
Block a user