mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-06-08 01:15:39 +08:00
Add test for SHGEMM
This commit is contained in:
@@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1)
|
||||
BF3= test_bgemm
|
||||
B3 = test_sbgemm
|
||||
endif
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
H3 = test_shgemm
|
||||
endif
|
||||
ifeq ($(BUILD_SINGLE),1)
|
||||
S3=sblat3
|
||||
endif
|
||||
@@ -257,9 +260,9 @@ endif
|
||||
|
||||
|
||||
ifeq ($(SUPPORT_GEMM3M),1)
|
||||
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m
|
||||
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m
|
||||
else
|
||||
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3)
|
||||
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3)
|
||||
endif
|
||||
|
||||
ifneq ($(CROSS), 1)
|
||||
@@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
|
||||
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
endif
|
||||
@@ -475,7 +481,7 @@ clean:
|
||||
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
|
||||
sblat1 dblat1 cblat1 zblat1 \
|
||||
sblat2 dblat2 cblat2 zblat2 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
sblat1p dblat1p cblat1p zblat1p \
|
||||
sblat2p dblat2p cblat2p zblat2p \
|
||||
sblat3p dblat3p cblat3p zblat3p \
|
||||
|
||||
148
test/compare_sgemm_shgemm.c
Normal file
148
test/compare_sgemm_shgemm.c
Normal file
@@ -0,0 +1,148 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2020,2025 The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#define __USE_POSIX199309
|
||||
#include "../common.h"
|
||||
|
||||
#include "test_helpers.h"
|
||||
|
||||
#define SGEMM BLASFUNC(sgemm)
|
||||
#define SHGEMM BLASFUNC(shgemm)
|
||||
#define SGEMV BLASFUNC(sgemv)
|
||||
#define SHGEMV BLASFUNC(shgemv)
|
||||
/* Largest square problem size exercised by main(): every size from 0 to 100
   is run, and then this size is run once as the final case. */
#define SHGEMM_LARGEST 256
|
||||
|
||||
int
|
||||
main (int argc, char *argv[])
|
||||
{
|
||||
blasint m, n, k;
|
||||
int i, j, l;
|
||||
blasint x, y;
|
||||
int ret = 0;
|
||||
int loop = SHGEMM_LARGEST;
|
||||
char transA = 'N', transB = 'N';
|
||||
float alpha = 1.0, beta = 0.0;
|
||||
|
||||
for (x = 0; x <= loop; x++)
|
||||
{
|
||||
if ((x > 100) && (x != SHGEMM_LARGEST)) continue;
|
||||
m = k = n = x;
|
||||
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
|
||||
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(bfloat16));
|
||||
hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(bfloat16));
|
||||
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
AA[j * k + i] = (hfloat16) A[j * k + i];
|
||||
}
|
||||
}
|
||||
for (j = 0; j < n; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
BB[j * k + i] = (hfloat16) A[j * k + i];
|
||||
}
|
||||
}
|
||||
for (y = 0; y < 4; y++)
|
||||
{
|
||||
if ((y == 0) || (y == 2)) {
|
||||
transA = 'N';
|
||||
} else {
|
||||
transA = 'T';
|
||||
}
|
||||
if ((y == 0) || (y == 1)) {
|
||||
transB = 'N';
|
||||
} else {
|
||||
transB = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, m * n * sizeof(FLOAT));
|
||||
memset(DD, 0, m * n * sizeof(FLOAT));
|
||||
memset(C, 0, m * n * sizeof(FLOAT));
|
||||
|
||||
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
|
||||
&m, B, &k, &beta, C, &m);
|
||||
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
|
||||
&m, BB, &k, &beta, CC, &m);
|
||||
|
||||
for (i = 0; i < n; i++)
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
for (l = 0; l < k; l++)
|
||||
if (transA == 'N' && transB == 'N')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float) AA[l * m + j] * (float)BB[l + k * i];
|
||||
} else if (transA == 'T' && transB == 'N')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[k * j + l] * (float)BB[l + k * i];
|
||||
} else if (transA == 'N' && transB == 'T')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[l * m + j] * (float)BB[i + l * n];
|
||||
} else if (transA == 'T' && transB == 'T')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[k * j + l] * (float)BB[i + l * n];
|
||||
}
|
||||
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
|
||||
ret++;
|
||||
}
|
||||
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(AA);
|
||||
free(BB);
|
||||
free(DD);
|
||||
free(CC);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "SHGEMM FAILURES: %d\n", ret);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
Reference in New Issue
Block a user