mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-06-15 07:51:43 +08:00
RFC : Add half precision gemm for bfloat16 in OpenBLAS
This patch adds support for bfloat16 data type matrix multiplication kernel. For architectures that don't support bfloat16, it is defined as unsigned short (2 bytes). Default unroll sizes can be changed as per architecture as done for SGEMM and for now 8 and 4 are used for M and N. Size of ncopy/tcopy can be changed as per architecture requirement and for now, size 2 is used. Added shgemm in kernel/power/KERNEL.POWER9 and tested in powerpc64le and powerpc64. For reference, added a small test compare_sgemm_shgemm.c to compare sgemm and shgemm output. This patch does not cover OpenBLAS test, benchmark and lapack tests for shgemm. Complex type implementation can be discussed and added once this is approved.
This commit is contained in:
@@ -59,6 +59,10 @@ ifeq ($(CORE), Z14)
|
||||
USE_TRMM = 1
|
||||
endif
|
||||
|
||||
SHKERNELOBJS += \
|
||||
shgemm_kernel$(TSUFFIX).$(SUFFIX) \
|
||||
$(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
|
||||
$(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
|
||||
|
||||
SKERNELOBJS += \
|
||||
sgemm_kernel$(TSUFFIX).$(SUFFIX) \
|
||||
@@ -93,6 +97,7 @@ XKERNELOBJS += \
|
||||
$(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \
|
||||
$(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ)
|
||||
|
||||
SHBLASOBJS += $(SHKERNELOBJS)
|
||||
SBLASOBJS += $(SKERNELOBJS)
|
||||
DBLASOBJS += $(DKERNELOBJS)
|
||||
QBLASOBJS += $(QKERNELOBJS)
|
||||
@@ -100,6 +105,7 @@ CBLASOBJS += $(CKERNELOBJS)
|
||||
ZBLASOBJS += $(ZKERNELOBJS)
|
||||
XBLASOBJS += $(XKERNELOBJS)
|
||||
|
||||
SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
|
||||
SBLASOBJS += \
|
||||
sgemm_beta$(TSUFFIX).$(SUFFIX) \
|
||||
strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \
|
||||
@@ -390,6 +396,10 @@ ZBLASOBJS += \
|
||||
zgeadd_k$(TSUFFIX).$(SUFFIX)
|
||||
|
||||
|
||||
SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
@@ -415,6 +425,9 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
@@ -433,6 +446,36 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
|
||||
$(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
|
||||
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
|
||||
ifeq ($(OS), AIX)
|
||||
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s
|
||||
m4 shgemmotcopy.s > shgemmotcopy_nomacros.s
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@
|
||||
rm shgemmotcopy.s shgemmotcopy_nomacros.s
|
||||
else
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
|
||||
|
||||
$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
|
||||
ifeq ($(OS), AIX)
|
||||
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s
|
||||
m4 shgemmitcopy.s > shgemmitcopy_nomacros.s
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@
|
||||
rm shgemmitcopy.s shgemmitcopy_nomacros.s
|
||||
else
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
endif
|
||||
|
||||
$(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
@@ -590,6 +633,16 @@ else
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
|
||||
ifeq ($(OS), AIX)
|
||||
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemm_kernel$(TSUFFIX).s
|
||||
m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@
|
||||
rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s
|
||||
else
|
||||
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
$(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
|
||||
ifeq ($(OS), AIX)
|
||||
$(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s
|
||||
@@ -2206,6 +2259,9 @@ $(KDIR)xtrsm_oltncopy$(TSUFFIX).$(SUFFIX) : generic/ztrsm_ltcopy_$(XGEMM_UNROLL_
|
||||
$(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
|
||||
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
|
||||
$(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
@@ -2221,6 +2277,20 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
|
||||
$(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
|
||||
$(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
|
||||
|
||||
$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
|
||||
$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
endif
|
||||
$(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
|
||||
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
@@ -2325,6 +2395,9 @@ endif
|
||||
|
||||
endif
|
||||
|
||||
$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
|
||||
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
$(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
|
||||
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
#include "common.h"
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||
FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5,
|
||||
IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5,
|
||||
FLOAT *c, BLASLONG ldc){
|
||||
|
||||
|
||||
|
||||
@@ -39,10 +39,10 @@
|
||||
#include <stdio.h>
|
||||
#include "common.h"
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
|
||||
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
|
||||
BLASLONG i, j;
|
||||
FLOAT *a_offset, *a_offset1, *a_offset2;
|
||||
FLOAT *b_offset;
|
||||
IFLOAT *a_offset, *a_offset1, *a_offset2;
|
||||
IFLOAT *b_offset;
|
||||
|
||||
a_offset = a;
|
||||
b_offset = b;
|
||||
|
||||
@@ -39,11 +39,11 @@
|
||||
#include <stdio.h>
|
||||
#include "common.h"
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
|
||||
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
|
||||
BLASLONG i, j;
|
||||
|
||||
FLOAT *a_offset, *a_offset1, *a_offset2;
|
||||
FLOAT *b_offset, *b_offset1, *b_offset2;
|
||||
IFLOAT *a_offset, *a_offset1, *a_offset2;
|
||||
IFLOAT *b_offset, *b_offset1, *b_offset2;
|
||||
|
||||
a_offset = a;
|
||||
b_offset = b;
|
||||
|
||||
@@ -1,13 +1,32 @@
|
||||
#include "common.h"
|
||||
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc
|
||||
#if defined(HALF) && defined(HALFCONVERSION)
|
||||
float
|
||||
bfloat16tof32 (bfloat16 f16)
|
||||
{
|
||||
float result = 0;
|
||||
unsigned short* q = (unsigned short*)(&result);
|
||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
q[0] = f16;
|
||||
#else
|
||||
q[1] = f16;
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
#define BF16TOF32(x) (bfloat16tof32(x))
|
||||
#else
|
||||
#define BF16TOF32(x) x
|
||||
#endif
|
||||
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
|
||||
#ifdef TRMMKERNEL
|
||||
,BLASLONG offset
|
||||
#endif
|
||||
)
|
||||
{
|
||||
BLASLONG i,j,k;
|
||||
FLOAT *C0,*C1,*ptrba,*ptrbb;
|
||||
FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7;
|
||||
FLOAT *C0,*C1;
|
||||
IFLOAT *ptrba,*ptrbb;
|
||||
FLOAT res0,res1,res2,res3;
|
||||
IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
|
||||
for (j=0; j<bn/2; j+=1)
|
||||
{
|
||||
C0 = C;
|
||||
@@ -24,36 +43,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+load2*load1;
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
load3 = ptrbb[2*0+1];
|
||||
res2 = res2+load0*load3;
|
||||
res3 = res3+load2*load3;
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
load4 = ptrba[2*1+0];
|
||||
load5 = ptrbb[2*1+0];
|
||||
res0 = res0+load4*load5;
|
||||
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
|
||||
load6 = ptrba[2*1+1];
|
||||
res1 = res1+load6*load5;
|
||||
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
|
||||
load7 = ptrbb[2*1+1];
|
||||
res2 = res2+load4*load7;
|
||||
res3 = res3+load6*load7;
|
||||
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
|
||||
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
|
||||
load0 = ptrba[2*2+0];
|
||||
load1 = ptrbb[2*2+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
load2 = ptrba[2*2+1];
|
||||
res1 = res1+load2*load1;
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
load3 = ptrbb[2*2+1];
|
||||
res2 = res2+load0*load3;
|
||||
res3 = res3+load2*load3;
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
load4 = ptrba[2*3+0];
|
||||
load5 = ptrbb[2*3+0];
|
||||
res0 = res0+load4*load5;
|
||||
res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
|
||||
load6 = ptrba[2*3+1];
|
||||
res1 = res1+load6*load5;
|
||||
res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
|
||||
load7 = ptrbb[2*3+1];
|
||||
res2 = res2+load4*load7;
|
||||
res3 = res3+load6*load7;
|
||||
res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
|
||||
res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
|
||||
ptrba = ptrba+8;
|
||||
ptrbb = ptrbb+8;
|
||||
}
|
||||
@@ -61,12 +80,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+load2*load1;
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
load3 = ptrbb[2*0+1];
|
||||
res2 = res2+load0*load3;
|
||||
res3 = res3+load2*load3;
|
||||
res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
|
||||
res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
|
||||
ptrba = ptrba+2;
|
||||
ptrbb = ptrbb+2;
|
||||
}
|
||||
@@ -90,9 +109,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
||||
{
|
||||
load0 = ptrba[0+0];
|
||||
load1 = ptrbb[2*0+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
load2 = ptrbb[2*0+1];
|
||||
res1 = res1+load0*load2;
|
||||
res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
|
||||
ptrba = ptrba+1;
|
||||
ptrbb = ptrbb+2;
|
||||
}
|
||||
@@ -121,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
||||
{
|
||||
load0 = ptrba[2*0+0];
|
||||
load1 = ptrbb[0+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
load2 = ptrba[2*0+1];
|
||||
res1 = res1+load2*load1;
|
||||
res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
|
||||
ptrba = ptrba+2;
|
||||
ptrbb = ptrbb+1;
|
||||
}
|
||||
@@ -141,7 +160,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
||||
{
|
||||
load0 = ptrba[0+0];
|
||||
load1 = ptrbb[0+0];
|
||||
res0 = res0+load0*load1;
|
||||
res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
|
||||
ptrba = ptrba+1;
|
||||
ptrbb = ptrbb+1;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,17 @@ DTRMMKERNEL = dgemm_kernel_power9.S
|
||||
CTRMMKERNEL = cgemm_kernel_power9.S
|
||||
ZTRMMKERNEL = zgemm_kernel_power9.S
|
||||
|
||||
SHGEMM_BETA = ../generic/gemm_beta.c
|
||||
SHGEMMKERNEL = ../generic/gemmkernel_2x2.c
|
||||
SHGEMMINCOPY = ../generic/gemm_ncopy_2.c
|
||||
SHGEMMITCOPY = ../generic/gemm_tcopy_2.c
|
||||
SHGEMMONCOPY = ../generic/gemm_ncopy_2.c
|
||||
SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c
|
||||
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
|
||||
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
|
||||
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
|
||||
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
|
||||
|
||||
SGEMMKERNEL = sgemm_kernel_power9.S
|
||||
SGEMMINCOPY = ../generic/gemm_ncopy_16.c
|
||||
SGEMMITCOPY = sgemm_tcopy_16_power8.S
|
||||
|
||||
@@ -54,6 +54,20 @@ gotoblas_t TABLE_NAME = {
|
||||
GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
|
||||
|
||||
0, 0, 0,
|
||||
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
|
||||
#ifdef SHGEMM_DEFAULT_UNROLL_MN
|
||||
SHGEMM_DEFAULT_UNROLL_MN,
|
||||
#else
|
||||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
|
||||
#endif
|
||||
shgemm_kernelTS, shgemm_betaTS,
|
||||
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
|
||||
shgemm_incopyTS, shgemm_itcopyTS,
|
||||
#else
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
#endif
|
||||
shgemm_oncopyTS, shgemm_otcopyTS,
|
||||
sgemm_kernelTS, sgemm_betaTS,
|
||||
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
|
||||
#ifdef SGEMM_DEFAULT_UNROLL_MN
|
||||
SGEMM_DEFAULT_UNROLL_MN,
|
||||
@@ -648,16 +662,19 @@ gotoblas_t TABLE_NAME = {
|
||||
|
||||
#if defined(ARCH_ARM64)
|
||||
static void init_parameter(void) {
|
||||
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
|
||||
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
|
||||
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
|
||||
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
|
||||
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
|
||||
|
||||
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.zgemm_q = ZGEMM_DEFAULT_Q;
|
||||
|
||||
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
|
||||
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
|
||||
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
|
||||
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
|
||||
@@ -721,17 +738,20 @@ static void init_parameter(void) {
|
||||
#if defined(ARCH_POWER)
|
||||
static void init_parameter(void) {
|
||||
|
||||
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
|
||||
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
|
||||
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
|
||||
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
|
||||
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
|
||||
|
||||
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
|
||||
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
|
||||
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
|
||||
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
|
||||
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
|
||||
|
||||
|
||||
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
|
||||
@@ -741,17 +761,20 @@ static void init_parameter(void) {
|
||||
|
||||
#if defined(ARCH_ZARCH)
|
||||
static void init_parameter(void) {
|
||||
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
|
||||
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
|
||||
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
|
||||
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
|
||||
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
|
||||
|
||||
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
|
||||
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
|
||||
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
|
||||
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
|
||||
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
|
||||
|
||||
|
||||
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
|
||||
@@ -891,6 +914,7 @@ static void init_parameter(void) {
|
||||
(void) l2; /* dirty trick to suppress unused variable warning for targets */
|
||||
/* where the GEMM unrolling parameters do not depend on l2 */
|
||||
|
||||
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
|
||||
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
|
||||
@@ -1261,6 +1285,7 @@ static void init_parameter(void) {
|
||||
|
||||
|
||||
|
||||
TABLE_NAME.shgemm_p = ((TABLE_NAME.shgemm_p + SHGEMM_DEFAULT_UNROLL_M - 1)/SHGEMM_DEFAULT_UNROLL_M) * SHGEMM_DEFAULT_UNROLL_M;
|
||||
TABLE_NAME.sgemm_p = ((TABLE_NAME.sgemm_p + SGEMM_DEFAULT_UNROLL_M - 1)/SGEMM_DEFAULT_UNROLL_M) * SGEMM_DEFAULT_UNROLL_M;
|
||||
TABLE_NAME.dgemm_p = ((TABLE_NAME.dgemm_p + DGEMM_DEFAULT_UNROLL_M - 1)/DGEMM_DEFAULT_UNROLL_M) * DGEMM_DEFAULT_UNROLL_M;
|
||||
TABLE_NAME.cgemm_p = ((TABLE_NAME.cgemm_p + CGEMM_DEFAULT_UNROLL_M - 1)/CGEMM_DEFAULT_UNROLL_M) * CGEMM_DEFAULT_UNROLL_M;
|
||||
@@ -1288,6 +1313,11 @@ static void init_parameter(void) {
|
||||
fprintf(stderr, "L2 = %8d DGEMM_P .. %d\n", l2, TABLE_NAME.dgemm_p);
|
||||
#endif
|
||||
|
||||
TABLE_NAME.shgemm_r = (((BUFFER_SIZE -
|
||||
((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA
|
||||
+ TABLE_NAME.align) & ~TABLE_NAME.align)
|
||||
) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15);
|
||||
|
||||
TABLE_NAME.sgemm_r = (((BUFFER_SIZE -
|
||||
((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA
|
||||
+ TABLE_NAME.align) & ~TABLE_NAME.align)
|
||||
|
||||
Reference in New Issue
Block a user