mirror of
https://github.com/OpenMathLib/OpenBLAS
synced 2026-06-15 07:51:43 +08:00
Merge pull request #5499 from martin-frbg/issue5497
Add test for SHGEMM
This commit is contained in:
@@ -36,6 +36,28 @@ foreach(test_bin ${OpenBLAS_Tests})
|
||||
target_link_libraries(${test_bin} ${OpenBLAS_LIBNAME})
|
||||
endforeach()
|
||||
|
||||
if (BUILD_BFLOAT16)
|
||||
add_executable(test_bgemm compare_sgemm_bgemm.c)
|
||||
target_compile_definitions(test_bgemm PUBLIC -DIBFLOAT16 -DOBFLOAT16)
|
||||
target_link_libraries(test_bgemm ${OpenBLAS_LIBNAME})
|
||||
add_executable(test_bgemv compare_sgemv_bgemv.c)
|
||||
target_compile_definitions(test_bgemv PUBLIC -DIBFLOAT16 -DOBFLOAT16)
|
||||
target_link_libraries(test_bgemv ${OpenBLAS_LIBNAME})
|
||||
add_executable(test_sbgemm compare_sgemm_sbgemm.c)
|
||||
target_compile_definitions(test_sbgemm PUBLIC -DIBFLOAT16)
|
||||
target_link_libraries(test_sbgemm ${OpenBLAS_LIBNAME})
|
||||
add_executable(test_sbgemv compare_sgemv_sbgemv.c)
|
||||
target_compile_definitions(test_sbgemv PUBLIC -DIBFLOAT16)
|
||||
target_link_libraries(test_sbgemv ${OpenBLAS_LIBNAME})
|
||||
endif()
|
||||
|
||||
if (BUILD_HFLOAT16)
|
||||
add_executable(test_shgemm compare_sgemm_shgemm.c)
|
||||
target_link_libraries(test_shgemm ${OpenBLAS_LIBNAME})
|
||||
add_executable(test_shgemv compare_sgemv_shgemv.c)
|
||||
target_link_libraries(test_shgemv ${OpenBLAS_LIBNAME})
|
||||
endif()
|
||||
|
||||
# $1 exec, $2 input, $3 output_result
|
||||
if(WIN32)
|
||||
FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_helper.ps1
|
||||
@@ -94,3 +116,21 @@ add_test(NAME "${float_type}blas3_3m"
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (BUILD_BFLOAT16)
|
||||
add_test(NAME "bgemm"
|
||||
COMMAND $<TARGET_FILE:test_bgemm>)
|
||||
add_test(NAME "bgemv"
|
||||
COMMAND $<TARGET_FILE:test_bgemv>)
|
||||
add_test(NAME "sbgemm"
|
||||
COMMAND $<TARGET_FILE:test_sbgemm>)
|
||||
add_test(NAME "sbgemv"
|
||||
COMMAND $<TARGET_FILE:test_sbgemv>)
|
||||
endif()
|
||||
|
||||
if (BUILD_HFLOAT16)
|
||||
add_test(NAME "shgemm"
|
||||
COMMAND $<TARGET_FILE:test_shgemm>)
|
||||
add_test(NAME "shgemv"
|
||||
COMMAND $<TARGET_FILE:test_shgemv>)
|
||||
endif()
|
||||
|
||||
@@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1)
|
||||
BF3= test_bgemm
|
||||
B3 = test_sbgemm
|
||||
endif
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
H3 = test_shgemm
|
||||
endif
|
||||
ifeq ($(BUILD_SINGLE),1)
|
||||
S3=sblat3
|
||||
endif
|
||||
@@ -257,9 +260,9 @@ endif
|
||||
|
||||
|
||||
ifeq ($(SUPPORT_GEMM3M),1)
|
||||
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m
|
||||
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m
|
||||
else
|
||||
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3)
|
||||
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3)
|
||||
endif
|
||||
|
||||
ifneq ($(CROSS), 1)
|
||||
@@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HFLOAT16),1)
|
||||
test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
|
||||
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
|
||||
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
endif
|
||||
@@ -475,7 +481,7 @@ clean:
|
||||
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
|
||||
sblat1 dblat1 cblat1 zblat1 \
|
||||
sblat2 dblat2 cblat2 zblat2 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \
|
||||
sblat1p dblat1p cblat1p zblat1p \
|
||||
sblat2p dblat2p cblat2p zblat2p \
|
||||
sblat3p dblat3p cblat3p zblat3p \
|
||||
|
||||
234
test/compare_sgemm_shgemm.c
Normal file
234
test/compare_sgemm_shgemm.c
Normal file
@@ -0,0 +1,234 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2020,2025 The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include "../common.h"
|
||||
|
||||
#include "test_helpers.h"
|
||||
|
||||
#define SGEMM BLASFUNC(sgemm)
|
||||
#define SHGEMM BLASFUNC(shgemm)
|
||||
#define SHGEMM_LARGEST 256
|
||||
|
||||
int
|
||||
main (int argc, char *argv[])
|
||||
{
|
||||
blasint m, n, k;
|
||||
int i, j, l;
|
||||
blasint x, y;
|
||||
int ret = 0;
|
||||
int rret = 0;
|
||||
int loop = SHGEMM_LARGEST;
|
||||
char transA = 'N', transB = 'N';
|
||||
float alpha = 1.0, beta = 0.0;
|
||||
int xvals[6]={3,24,55,71,SHGEMM_LARGEST/2,SHGEMM_LARGEST};
|
||||
|
||||
for (x = 0; x <= loop; x++)
|
||||
{
|
||||
if ((x > 100) && (x != SHGEMM_LARGEST)) continue;
|
||||
m = k = n = x;
|
||||
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
|
||||
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
|
||||
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
|
||||
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
AA[j * k + i] = (_Float16) A[j * k + i];
|
||||
}
|
||||
}
|
||||
for (j = 0; j < n; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
BB[j * k + i] = (_Float16) B[j * k + i];
|
||||
}
|
||||
}
|
||||
for (y = 0; y < 4; y++)
|
||||
{
|
||||
if ((y == 0) || (y == 2)) {
|
||||
transA = 'N';
|
||||
} else {
|
||||
transA = 'T';
|
||||
}
|
||||
if ((y == 0) || (y == 1)) {
|
||||
transB = 'N';
|
||||
} else {
|
||||
transB = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, m * n * sizeof(FLOAT));
|
||||
memset(DD, 0, m * n * sizeof(FLOAT));
|
||||
memset(C, 0, m * n * sizeof(FLOAT));
|
||||
|
||||
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
|
||||
&m, B, &k, &beta, C, &m);
|
||||
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
|
||||
&m, (_Float16*)BB, &k, &beta, CC, &m);
|
||||
|
||||
for (i = 0; i < n; i++)
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
for (l = 0; l < k; l++)
|
||||
if (transA == 'N' && transB == 'N')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float) AA[l * m + j] * (float)BB[l + k * i];
|
||||
} else if (transA == 'T' && transB == 'N')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[k * j + l] * (float)BB[l + k * i];
|
||||
} else if (transA == 'N' && transB == 'T')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[l * m + j] * (float)BB[i + l * n];
|
||||
} else if (transA == 'T' && transB == 'T')
|
||||
{
|
||||
DD[i * m + j] +=
|
||||
(float)AA[k * j + l] * (float)BB[i + l * n];
|
||||
}
|
||||
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
|
||||
fprintf(stderr,"CC %f C %f \n",(float)CC[i*m+j],C[i*m+j]);
|
||||
ret++;
|
||||
}
|
||||
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
|
||||
fprintf(stderr,"CC %f DD %f \n",(float)CC[i*m+j],(float)DD[i*m+j]);
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(AA);
|
||||
free(BB);
|
||||
free(DD);
|
||||
free(CC);
|
||||
}
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "SHGEMM FAILURES: %d!!!\n", ret);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
for (loop = 0; loop<6; loop++) {
|
||||
x=xvals[loop];
|
||||
for (alpha=0.;alpha<=1.;alpha+=0.5)
|
||||
{
|
||||
for (beta = 0.0; beta <=1.; beta+=0.5) {
|
||||
|
||||
m = k = n = x;
|
||||
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
|
||||
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
|
||||
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
|
||||
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(CC == NULL))
|
||||
return 1;
|
||||
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
AA[j * k + i] = (_Float16) A[j * k + i];
|
||||
}
|
||||
}
|
||||
for (j = 0; j < n; j++)
|
||||
{
|
||||
for (i = 0; i < k; i++)
|
||||
{
|
||||
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
BB[j * k + i] = (_Float16) B[j * k + i];
|
||||
}
|
||||
}
|
||||
|
||||
for (y = 0; y < 4; y++)
|
||||
{
|
||||
if ((y == 0) || (y == 2)) {
|
||||
transA = 'N';
|
||||
} else {
|
||||
transA = 'T';
|
||||
}
|
||||
if ((y == 0) || (y == 1)) {
|
||||
transB = 'N';
|
||||
} else {
|
||||
transB = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, m * n * sizeof(FLOAT));
|
||||
memset(C, 0, m * n * sizeof(FLOAT));
|
||||
|
||||
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
|
||||
&m, B, &k, &beta, C, &m);
|
||||
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
|
||||
&m, (_Float16*)BB, &k, &beta, CC, &m);
|
||||
|
||||
for (i = 0; i < n; i++)
|
||||
for (j = 0; j < m; j++)
|
||||
{
|
||||
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(AA);
|
||||
free(BB);
|
||||
free(CC);
|
||||
|
||||
if (ret != 0) {
|
||||
/*
|
||||
* fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
|
||||
*/
|
||||
rret++;
|
||||
ret=0;
|
||||
/* } else {
|
||||
fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if (rret > 0) return(1);
|
||||
return(0);
|
||||
}
|
||||
Reference in New Issue
Block a user