mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-06-15 07:51:43 +08:00
fix bugs in aarch64 sbgemv_n kernel
This commit is contained in:
@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
|
||||
x += 4;
|
||||
}
|
||||
|
||||
if (rest_n & 3) {
|
||||
x[0] *= beta;
|
||||
if ((rest_n & 3) > 1)
|
||||
x[1] *= beta;
|
||||
if ((rest_n & 3) > 2)
|
||||
x[2] *= beta;
|
||||
for (BLASLONG i = 0; i < (rest_n & 3); i ++) {
|
||||
x[i] *= beta;
|
||||
}
|
||||
}
|
||||
return;
|
||||
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
|
||||
bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;
|
||||
|
||||
bfloat16x8_t x_vec;
|
||||
bfloat16x4_t x_vecx4;
|
||||
|
||||
float32x4_t y1_vec, y2_vec;
|
||||
float32x4_t fp32_low, fp32_high;
|
||||
|
||||
@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
if (incx == 1 && incy == 1) {
|
||||
if (beta != 1) {
|
||||
beta_op(y, n, beta);
|
||||
beta_op(y, m, beta);
|
||||
}
|
||||
|
||||
for (i = 0; i < n / 8; i++) {
|
||||
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
a_ptr += 4 * lda;
|
||||
|
||||
bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
|
||||
x_vecx4 = vld1_bf16(x_ptr);
|
||||
if (alpha != 1) {
|
||||
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
|
||||
fp32_low = vreinterpretq_f32_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
|
||||
vreinterpretq_u16_bf16(x_vec)));
|
||||
fp32_low = vcvt_f32_bf16(x_vecx4);
|
||||
fp32_low = vmulq_n_f32(fp32_low, alpha);
|
||||
x_vecx4 = vcvt_bf16_f32(fp32_low);
|
||||
}
|
||||
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
y1_vec = vld1q_f32(y_ptr);
|
||||
|
||||
a0 = vcombine_bf16(a0x4, bf16_zero);
|
||||
a1 = vcombine_bf16(a1x4, bf16_zero);
|
||||
a2 = vcombine_bf16(a2x4, bf16_zero);
|
||||
a3 = vcombine_bf16(a3x4, bf16_zero);
|
||||
a0 = vcombine_bf16(a0x4, a2x4);
|
||||
a1 = vcombine_bf16(a1x4, a3x4);
|
||||
|
||||
t0 = vreinterpretq_bf16_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
t1 = vreinterpretq_bf16_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
|
||||
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
|
||||
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
|
||||
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
|
||||
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
}
|
||||
|
||||
if (rest_m) {
|
||||
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
|
||||
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
|
||||
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
|
||||
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
|
||||
fp32_low = vcvt_f32_bf16(x_vecx4);
|
||||
|
||||
x0 = vgetq_lane_f32(fp32_low, 0);
|
||||
x1 = vgetq_lane_f32(fp32_low, 1);
|
||||
x2 = vgetq_lane_f32(fp32_low, 2);
|
||||
x3 = vgetq_lane_f32(fp32_low, 3);
|
||||
|
||||
for (BLASLONG j = 0; j < rest_m; j++) {
|
||||
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
|
||||
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
a_ptr += 2 * lda;
|
||||
|
||||
bfloat16_t tmp_buffer[4];
|
||||
memset((void*)tmp_buffer, 0, sizeof(bfloat16_t));
|
||||
|
||||
tmp_buffer[0] = x_ptr[0];
|
||||
tmp_buffer[1] = x_ptr[1];
|
||||
x_vecx4 = vreinterpret_bf16_u16(vzip1_u16(
|
||||
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])),
|
||||
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1]))
|
||||
));
|
||||
|
||||
bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer);
|
||||
if (alpha != 1) {
|
||||
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
|
||||
fp32_low = vreinterpretq_f32_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
|
||||
vreinterpretq_u16_bf16(x_vec)));
|
||||
fp32_low = vcvt_f32_bf16(x_vecx4);
|
||||
fp32_low = vmulq_n_f32(fp32_low, alpha);
|
||||
x_vecx4 = vcvt_bf16_f32(fp32_low);
|
||||
}
|
||||
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
|
||||
t0 = vreinterpretq_bf16_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
t4 = vreinterpretq_bf16_u16(
|
||||
t1 = vreinterpretq_bf16_u16(
|
||||
vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
|
||||
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
|
||||
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
|
||||
|
||||
y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
|
||||
y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
|
||||
y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0);
|
||||
y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1);
|
||||
|
||||
vst1q_f32(y_ptr, y1_vec);
|
||||
vst1q_f32(y_ptr + 4, y2_vec);
|
||||
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
a0 = vcombine_bf16(a0x4, bf16_zero);
|
||||
a1 = vcombine_bf16(a1x4, bf16_zero);
|
||||
|
||||
t0 = vreinterpretq_bf16_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
t1 = vreinterpretq_bf16_u16(
|
||||
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
|
||||
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
|
||||
|
||||
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
|
||||
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
|
||||
y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
|
||||
y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);
|
||||
|
||||
vst1q_f32(y_ptr, y1_vec);
|
||||
|
||||
a_ptr0 += 4;
|
||||
a_ptr1 += 4;
|
||||
a_ptr2 += 4;
|
||||
a_ptr3 += 4;
|
||||
|
||||
y_ptr += 4;
|
||||
}
|
||||
|
||||
if (m & 2) {
|
||||
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
|
||||
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));
|
||||
fp32_low = vcvt_f32_bf16(x_vecx4);
|
||||
x0 = vgetq_lane_f32(fp32_low, 0);
|
||||
x1 = vgetq_lane_f32(fp32_low, 1);
|
||||
|
||||
|
||||
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
|
||||
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
|
||||
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
|
||||
}
|
||||
|
||||
if (m & 1) {
|
||||
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
|
||||
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
|
||||
fp32_low = vcvt_f32_bf16(x_vecx4);
|
||||
x0 = vgetq_lane_f32(fp32_low, 0);
|
||||
x1 = vgetq_lane_f32(fp32_low, 1);
|
||||
|
||||
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
|
||||
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
|
||||
|
||||
Reference in New Issue
Block a user