mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-06-08 01:15:39 +08:00
Add sbgemv_t_bfdot kernel for ARM64
This improves performance for sbgemv_t by up to 100x on NEOVERSEV1. The geometric mean speedup is ~61x for M=N=[2,512].
This commit is contained in:
@@ -236,6 +236,7 @@ In chronological order:
|
||||
* Annop Wongwathanarat <annop.wongwathanarat@arm.com>
|
||||
* [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
|
||||
* [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
|
||||
* [2025-02-26] Add sbgemv_t_bfdot kernel
|
||||
|
||||
* Marek Michalowski <marek.michalowski@arm.com>
|
||||
* [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`
|
||||
|
||||
@@ -198,3 +198,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMVTKERNEL = sbgemv_t_bfdot.c
|
||||
@@ -15,4 +15,5 @@ SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversev1.c
|
||||
SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_N)_neoversev1.c
|
||||
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
|
||||
SBGEMVTKERNEL = sbgemv_t_bfdot.c
|
||||
endif
|
||||
@@ -1 +1,5 @@
|
||||
include $(KERNELDIR)/KERNEL.ARMV8SVE
|
||||
|
||||
ifeq ($(BUILD_BFLOAT16), 1)
|
||||
SBGEMVTKERNEL = sbgemv_t_bfdot.c
|
||||
endif
|
||||
207
kernel/arm64/sbgemv_t_bfdot.c
Normal file
207
kernel/arm64/sbgemv_t_bfdot.c
Normal file
@@ -0,0 +1,207 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2025, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "common.h"
|
||||
|
||||
static inline float bf16_to_fp32(bfloat16 bf16) {
|
||||
uint32_t fp32 = (uint32_t)bf16 << 16;
|
||||
return *((float*)&fp32);
|
||||
}
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
|
||||
{
|
||||
if (m < 1 || n < 1) return(0);
|
||||
BLASLONG i;
|
||||
BLASLONG ix,iy;
|
||||
BLASLONG j;
|
||||
bfloat16_t *a_ptr;
|
||||
bfloat16_t *x_ptr;
|
||||
float *y_ptr;
|
||||
float temp;
|
||||
|
||||
iy = 0;
|
||||
a_ptr = (bfloat16_t*)(a);
|
||||
x_ptr = (bfloat16_t*)(x);
|
||||
|
||||
if (incx == 1) {
|
||||
BLASLONG width = n / 4;
|
||||
|
||||
bfloat16_t *a0_ptr = a_ptr + lda * width * 0;
|
||||
bfloat16_t *a1_ptr = a_ptr + lda * width * 1;
|
||||
bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
|
||||
bfloat16_t *a3_ptr = a_ptr + lda * width * 3;
|
||||
|
||||
float *y0_ptr = y + incy * width * 0;
|
||||
float *y1_ptr = y + incy * width * 1;
|
||||
float *y2_ptr = y + incy * width * 2;
|
||||
float *y3_ptr = y + incy * width * 3;
|
||||
|
||||
for (j = 0; j < width; j++) {
|
||||
float32x4_t temp0_vec = vdupq_n_f32(0.0f);
|
||||
float32x4_t temp1_vec = vdupq_n_f32(0.0f);
|
||||
float32x4_t temp2_vec = vdupq_n_f32(0.0f);
|
||||
float32x4_t temp3_vec = vdupq_n_f32(0.0f);
|
||||
|
||||
i = 0;
|
||||
while (i + 7 < m) {
|
||||
bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
|
||||
|
||||
bfloat16x8_t a0_vec = vld1q_bf16(a0_ptr + i);
|
||||
bfloat16x8_t a1_vec = vld1q_bf16(a1_ptr + i);
|
||||
bfloat16x8_t a2_vec = vld1q_bf16(a2_ptr + i);
|
||||
bfloat16x8_t a3_vec = vld1q_bf16(a3_ptr + i);
|
||||
|
||||
temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
|
||||
temp1_vec = vbfdotq_f32(temp1_vec, a1_vec, x_vec);
|
||||
temp2_vec = vbfdotq_f32(temp2_vec, a2_vec, x_vec);
|
||||
temp3_vec = vbfdotq_f32(temp3_vec, a3_vec, x_vec);
|
||||
|
||||
i += 8;
|
||||
}
|
||||
if (i + 3 < m) {
|
||||
float32x2_t t0 = vdup_n_f32(0.0f);
|
||||
float32x2_t t1 = vdup_n_f32(0.0f);
|
||||
float32x2_t t2 = vdup_n_f32(0.0f);
|
||||
float32x2_t t3 = vdup_n_f32(0.0f);
|
||||
|
||||
bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
|
||||
|
||||
bfloat16x4_t a0_vec = vld1_bf16(a0_ptr + i);
|
||||
bfloat16x4_t a1_vec = vld1_bf16(a1_ptr + i);
|
||||
bfloat16x4_t a2_vec = vld1_bf16(a2_ptr + i);
|
||||
bfloat16x4_t a3_vec = vld1_bf16(a3_ptr + i);
|
||||
|
||||
t0 = vbfdot_f32(t0, a0_vec, x_vec);
|
||||
t1 = vbfdot_f32(t1, a1_vec, x_vec);
|
||||
t2 = vbfdot_f32(t2, a2_vec, x_vec);
|
||||
t3 = vbfdot_f32(t3, a3_vec, x_vec);
|
||||
|
||||
float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
|
||||
float32x2_t temp1_vec_low = vget_low_f32(temp1_vec);
|
||||
float32x2_t temp2_vec_low = vget_low_f32(temp2_vec);
|
||||
float32x2_t temp3_vec_low = vget_low_f32(temp3_vec);
|
||||
|
||||
temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
|
||||
temp1_vec = vcombine_f32(vadd_f32(t1, temp1_vec_low), vget_high_f32(temp1_vec));
|
||||
temp2_vec = vcombine_f32(vadd_f32(t2, temp2_vec_low), vget_high_f32(temp2_vec));
|
||||
temp3_vec = vcombine_f32(vadd_f32(t3, temp3_vec_low), vget_high_f32(temp3_vec));
|
||||
|
||||
i += 4;
|
||||
}
|
||||
if (beta == 0.0f) {
|
||||
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
|
||||
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
|
||||
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
|
||||
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
|
||||
}
|
||||
else {
|
||||
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
|
||||
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
|
||||
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
|
||||
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
|
||||
}
|
||||
|
||||
for (; i < m; ++i) {
|
||||
y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i];
|
||||
y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i];
|
||||
y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i];
|
||||
y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i];
|
||||
}
|
||||
|
||||
iy += incy;
|
||||
|
||||
a0_ptr += lda;
|
||||
a1_ptr += lda;
|
||||
a2_ptr += lda;
|
||||
a3_ptr += lda;
|
||||
}
|
||||
|
||||
a_ptr = a3_ptr;
|
||||
y_ptr = y3_ptr;
|
||||
for (j = width * 4; j < n; j++) {
|
||||
float32x4_t temp0_vec = vdupq_n_f32(0.0f);
|
||||
i = 0;
|
||||
while (i + 7 < m) {
|
||||
bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
|
||||
bfloat16x8_t a0_vec = vld1q_bf16(a_ptr + i);
|
||||
temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
|
||||
|
||||
i += 8;
|
||||
}
|
||||
if (i + 3 < m) {
|
||||
float32x2_t t0 = vdup_n_f32(0.0f);
|
||||
bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
|
||||
bfloat16x4_t a0_vec = vld1_bf16(a_ptr + i);
|
||||
|
||||
t0 = vbfdot_f32(t0, a0_vec, x_vec);
|
||||
float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
|
||||
temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
|
||||
|
||||
i += 4;
|
||||
}
|
||||
if (beta == 0.0f) {
|
||||
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
|
||||
}
|
||||
else {
|
||||
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
|
||||
}
|
||||
|
||||
for (; i < m; ++i) {
|
||||
y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i];
|
||||
}
|
||||
|
||||
iy += incy;
|
||||
|
||||
a_ptr += lda;
|
||||
}
|
||||
return(0);
|
||||
}
|
||||
|
||||
for (j = 0; j < n; j++) {
|
||||
temp = 0.0;
|
||||
ix = 0;
|
||||
for (i = 0; i < m; i++) {
|
||||
temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]);
|
||||
ix += incx;
|
||||
}
|
||||
if (beta == 0.0f) {
|
||||
y[iy] = alpha * temp;
|
||||
}
|
||||
else {
|
||||
y[iy] = alpha * temp + beta * y[iy];
|
||||
}
|
||||
iy += incy;
|
||||
a += lda;
|
||||
}
|
||||
return (0);
|
||||
}
|
||||
Reference in New Issue
Block a user