Merge pull request #4419 from martin-frbg/issue4413

[WIP] Add fixes and utests for ZSCAL with NaN or Inf arguments
This commit is contained in:
Martin Kroeker 2024-01-12 14:27:08 +01:00 committed by GitHub
commit f31bea07dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 172 additions and 42 deletions

View File

@ -223,7 +223,7 @@ zscal_begin:
fcmp DA_I, #0.0 fcmp DA_I, #0.0
beq .Lzscal_kernel_RI_zero beq .Lzscal_kernel_RI_zero
b .Lzscal_kernel_R_zero // b .Lzscal_kernel_R_zero
.Lzscal_kernel_R_non_zero: .Lzscal_kernel_R_non_zero:

View File

@ -103,8 +103,10 @@ endif
ifdef HAVE_MSA ifdef HAVE_MSA
SSCALKERNEL = ../mips/sscal_msa.c SSCALKERNEL = ../mips/sscal_msa.c
DSCALKERNEL = ../mips/dscal_msa.c DSCALKERNEL = ../mips/dscal_msa.c
CSCALKERNEL = ../mips/cscal_msa.c #CSCALKERNEL = ../mips/cscal_msa.c
ZSCALKERNEL = ../mips/zscal_msa.c #ZSCALKERNEL = ../mips/zscal_msa.c
CSCALKERNEL = ../mips/zscal.c
ZSCALKERNEL = ../mips/zscal.c
else else
SSCALKERNEL = ../mips/scal.c SSCALKERNEL = ../mips/scal.c
DSCALKERNEL = ../mips/scal.c DSCALKERNEL = ../mips/scal.c

View File

@ -47,6 +47,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
else else
{ {
temp = - da_i * x[ip+1] ; temp = - da_i * x[ip+1] ;
if (isnan(x[ip]) || isinf(x[ip])) temp = NAN;
x[ip+1] = da_i * x[ip] ; x[ip+1] = da_i * x[ip] ;
} }
} }
@ -63,6 +64,9 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
x[ip+1] = da_r * x[ip+1] + da_i * x[ip] ; x[ip+1] = da_r * x[ip+1] + da_i * x[ip] ;
} }
} }
if ( da_r != da_r )
x[ip] = da_r;
else
x[ip] = temp; x[ip] = temp;
ip += inc_x2; ip += inc_x2;

View File

@ -60,6 +60,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
else else
{ {
temp = - da_i * x[ip+1] ; temp = - da_i * x[ip+1] ;
if (isnan(x[ip]) || isinf(x[ip])) temp = NAN;
x[ip+1] = da_i * x[ip] ; x[ip+1] = da_i * x[ip] ;
} }
} }

View File

@ -80,6 +80,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
j += gvl; j += gvl;
ix += inc_x * 2 * gvl; ix += inc_x * 2 * gvl;
} }
#if 0
}else if(da_r == 0.0){ }else if(da_r == 0.0){
gvl = VSETVL(n); gvl = VSETVL(n);
BLASLONG stride_x = inc_x * 2 * sizeof(FLOAT); BLASLONG stride_x = inc_x * 2 * sizeof(FLOAT);
@ -97,6 +98,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
j += gvl; j += gvl;
ix += inc_xv; ix += inc_xv;
} }
#endif
if(j < n){ if(j < n){
gvl = VSETVL(n-j); gvl = VSETVL(n-j);
v0 = VLSEV_FLOAT(&x[ix], stride_x, gvl); v0 = VLSEV_FLOAT(&x[ix], stride_x, gvl);

View File

@ -98,7 +98,7 @@
fcomip %st(1), %st fcomip %st(1), %st
ffreep %st(0) ffreep %st(0)
jne .L30 jne .L30
jp .L30
EMMS EMMS
pxor %mm0, %mm0 pxor %mm0, %mm0

View File

@ -87,6 +87,7 @@
xorps %xmm7, %xmm7 xorps %xmm7, %xmm7
comiss %xmm0, %xmm7 comiss %xmm0, %xmm7
jne .L100 # Alpha_r != ZERO jne .L100 # Alpha_r != ZERO
jp .L100 # Alpha_r NaN
comiss %xmm1, %xmm7 comiss %xmm1, %xmm7
jne .L100 # Alpha_i != ZERO jne .L100 # Alpha_i != ZERO

View File

@ -98,6 +98,7 @@
xorps %xmm7, %xmm7 xorps %xmm7, %xmm7
comisd %xmm0, %xmm7 comisd %xmm0, %xmm7
jne .L100 jne .L100
jp .L100
comisd %xmm1, %xmm7 comisd %xmm1, %xmm7
jne .L100 jne .L100

View File

@ -39,7 +39,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif #endif
#include "common.h" #include "common.h"
#include <float.h>
#if defined (SKYLAKEX) || defined (COOPERLAKE) || defined (SAPPHIRERAPIDS) #if defined (SKYLAKEX) || defined (COOPERLAKE) || defined (SAPPHIRERAPIDS)
#include "zscal_microk_skylakex-2.c" #include "zscal_microk_skylakex-2.c"
@ -222,12 +222,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
if ( da_r == 0.0 ) if ( da_r == 0.0 )
{ {
BLASLONG n1 = n & -2; BLASLONG n1 = n & -2;
if ( da_i == 0.0 ) if ( da_i == 0.0 )
{ {
while(j < n1) while(j < n1)
{ {
@ -253,7 +251,6 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
} }
else else
{ {
while(j < n1) while(j < n1)
{ {
@ -361,43 +358,53 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
if ( da_i == 0 ) if ( da_i == 0 )
zscal_kernel_8_zero(n1 , alpha , x); zscal_kernel_8_zero(n1 , alpha , x);
else else
zscal_kernel_8_zero_r(n1 , alpha , x); // zscal_kernel_8_zero_r(n1 , alpha , x);
zscal_kernel_8(n1 , alpha , x);
else else
if ( da_i == 0 ) if ( da_i == 0 && da_r == da_r)
zscal_kernel_8_zero_i(n1 , alpha , x); zscal_kernel_8_zero_i(n1 , alpha , x);
else else
zscal_kernel_8(n1 , alpha , x); zscal_kernel_8(n1 , alpha , x);
}
i = n1 << 1; i = n1 << 1;
j = n1; j = n1;
}
if ( da_r == 0.0 || da_r != da_r )
if ( da_r == 0.0 )
{ {
if ( da_i == 0.0 ) if ( da_i == 0.0 )
{ {
FLOAT res=0.0;
if (da_r != da_r) res= da_r;
while(j < n) while(j < n)
{ {
x[i]=res;
x[i]=0.0; x[i+1]=res;
x[i+1]=0.0;
i += 2 ; i += 2 ;
j++; j++;
} }
} }
else else if (da_r < -FLT_MAX || da_r > FLT_MAX) {
while(j < n)
{
x[i]= NAN;
x[i+1] = da_r;
i += 2 ;
j++;
}
} else
{ {
while(j < n) while(j < n)
{ {
temp0 = -da_i * x[i+1]; temp0 = -da_i * x[i+1];
if (x[i] < -FLT_MAX || x[i] > FLT_MAX)
temp0 = NAN;
x[i+1] = da_i * x[i]; x[i+1] = da_i * x[i];
if ( x[i] == x[i]) //preserve NaN
x[i] = temp0; x[i] = temp0;
i += 2 ; i += 2 ;
j++; j++;
@ -409,10 +416,8 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
} }
else else
{ {
if (da_i == 0.0) if (da_i == 0.0)
{ {
while(j < n) while(j < n)
{ {
@ -423,14 +428,12 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
j++; j++;
} }
} }
else else
{ {
while(j < n) while(j < n)
{ {
temp0 = da_r * x[i] - da_i * x[i+1]; temp0 = da_r * x[i] - da_i * x[i+1];
x[i+1] = da_r * x[i+1] + da_i * x[i]; x[i+1] = da_r * x[i+1] + da_i * x[i];
x[i] = temp0; x[i] = temp0;
@ -445,5 +448,3 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
return(0); return(0);
} }

View File

@ -82,6 +82,7 @@
pxor %xmm15, %xmm15 pxor %xmm15, %xmm15
comisd %xmm0, %xmm15 comisd %xmm0, %xmm15
jne .L100 jne .L100
jp .L100
comisd %xmm1, %xmm15 comisd %xmm1, %xmm15
jne .L100 jne .L100

View File

@ -233,9 +233,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
while (j < n1) { while (j < n1) {
if (isnan(x[i]) || isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1]; temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i]; x[i + 1] = da_i * x[i];
x[i] = temp0; x[i] = temp0;
if (isnan(x[i + inc_x]) || isinf(x[i + inc_x]))
temp1 = NAN;
else
temp1 = -da_i * x[i + 1 + inc_x]; temp1 = -da_i * x[i + 1 + inc_x];
x[i + 1 + inc_x] = da_i * x[i + inc_x]; x[i + 1 + inc_x] = da_i * x[i + inc_x];
x[i + inc_x] = temp1; x[i + inc_x] = temp1;
@ -246,6 +252,9 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
while (j < n) { while (j < n) {
if (isnan(x[i]) || isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1]; temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i]; x[i + 1] = da_i * x[i];
x[i] = temp0; x[i] = temp0;
@ -320,7 +329,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
if (da_i == 0) if (da_i == 0)
zscal_kernel_8_zero(n1, x); zscal_kernel_8_zero(n1, x);
else else
zscal_kernel_8_zero_r(n1, alpha, x); zscal_kernel_8(n1, da_r, da_i, x);
else if (da_i == 0) else if (da_i == 0)
zscal_kernel_8_zero_i(n1, alpha, x); zscal_kernel_8_zero_i(n1, alpha, x);
else else
@ -347,6 +356,9 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
while (j < n) { while (j < n) {
if (isnan(x[i]) || isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1]; temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i]; x[i + 1] = da_i * x[i];
x[i] = temp0; x[i] = temp0;

View File

@ -15,6 +15,7 @@ else ()
test_dsdot.c test_dsdot.c
test_dnrm2.c test_dnrm2.c
test_swap.c test_swap.c
test_zscal.c
) )
endif () endif ()

View File

@ -11,7 +11,7 @@ UTESTBIN=openblas_utest
include $(TOPDIR)/Makefile.system include $(TOPDIR)/Makefile.system
OBJS=utest_main.o test_min.o test_amax.o test_ismin.o test_rotmg.o test_axpy.o test_dotu.o test_dsdot.o test_swap.o test_rot.o test_dnrm2.o OBJS=utest_main.o test_min.o test_amax.o test_ismin.o test_rotmg.o test_axpy.o test_dotu.o test_dsdot.o test_swap.o test_rot.o test_dnrm2.o test_zscal.o
#test_rot.o test_swap.o test_axpy.o test_dotu.o test_dsdot.o test_fork.o #test_rot.o test_swap.o test_axpy.o test_dotu.o test_dsdot.o test_fork.o
ifneq ($(NO_LAPACK), 1) ifneq ($(NO_LAPACK), 1)

56
utest/test_zscal.c Normal file
View File

@ -0,0 +1,56 @@
#include "openblas_utest.h"
#include <cblas.h>
#ifdef BUILD_COMPLEX16
/* Fallback definitions for toolchains whose headers do not provide NAN or
 * INFINITY. The expansions are parenthesized so the macros behave as single
 * operands inside larger expressions (e.g. 1.0 / INFINITY). */
#ifndef NAN
#define NAN (0.0/0.0)
#endif
#ifndef INFINITY
#define INFINITY (1.0/0.0)
#endif
/* zscal with alpha = 0+1i applied to a vector of (NaN + 0i) entries.
 * Both components of the first and of the last scaled element are expected
 * to be NaN afterwards. */
CTEST(zscal, i_nan)
{
    const int count = 9;                /* complex elements to scale */
    const int last = 2 * (count - 1);   /* index of the final element's real part */
    double alpha[] = {
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1
    };
    double x[] = {
        NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,
        NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0
    };

    cblas_zscal(count, alpha, x, 1);

    ASSERT_TRUE(isnan(x[0]));
    ASSERT_TRUE(isnan(x[1]));
    ASSERT_TRUE(isnan(x[last]));
    ASSERT_TRUE(isnan(x[last + 1]));
}
/* zscal with alpha = NaN+0i applied to a vector of (0 + 1i) entries.
 * Both components of the first and of the last scaled element are expected
 * to be NaN afterwards. */
CTEST(zscal, nan_i)
{
    const int count = 9;                /* complex elements to scale */
    const int last = 2 * (count - 1);   /* index of the final element's real part */
    double x[] = {
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1
    };
    double alpha[] = {
        NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,
        NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0,  NAN, 0
    };

    cblas_zscal(count, alpha, x, 1);

    ASSERT_TRUE(isnan(x[0]));
    ASSERT_TRUE(isnan(x[1]));
    ASSERT_TRUE(isnan(x[last]));
    ASSERT_TRUE(isnan(x[last + 1]));
}
/* zscal with alpha = 0+1i applied to a vector of (Inf + 0i) entries.
 * (Inf + 0i) * (0 + 1i) has real part Inf*0 - 0*1 = NaN and imaginary part
 * Inf*1 + 0*0 = Inf; the first and last scaled elements are checked. */
CTEST(zscal, i_inf)
{
    const int count = 9;                /* complex elements to scale */
    const int last = 2 * (count - 1);   /* index of the final element's real part */
    double alpha[] = {
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1
    };
    double x[] = {
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0,
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0,
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0
    };

    cblas_zscal(count, alpha, x, 1);

    ASSERT_TRUE(isnan(x[0]));
    ASSERT_TRUE(isinf(x[1]));
    ASSERT_TRUE(isnan(x[last]));
    ASSERT_TRUE(isinf(x[last + 1]));
}
/* zscal with alpha = Inf+0i applied to a vector of (0 + 1i) entries.
 * (0 + 1i) * (Inf + 0i) has real part 0*Inf - 1*0 = NaN and imaginary part
 * 0*0 + 1*Inf = Inf; the first and last scaled elements are checked. */
CTEST(zscal, inf_i)
{
    const int count = 9;                /* complex elements to scale */
    const int last = 2 * (count - 1);   /* index of the final element's real part */
    double x[] = {
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1,
        0, 1,  0, 1,  0, 1
    };
    double alpha[] = {
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0,
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0,
        INFINITY, 0,  INFINITY, 0,  INFINITY, 0
    };

    cblas_zscal(count, alpha, x, 1);

    ASSERT_TRUE(isnan(x[0]));
    ASSERT_TRUE(isinf(x[1]));
    ASSERT_TRUE(isnan(x[last]));
    ASSERT_TRUE(isinf(x[last + 1]));
}
#endif

View File

@ -617,6 +617,51 @@ CTEST(max, smax_zero){
ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), SINGLE_EPS); ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), SINGLE_EPS);
} }
/* zscal with alpha = 0+1i applied to a vector of (NaN + 0i) entries.
 * Both components of the first and of the last scaled element must be NaN. */
CTEST(zscal, i_nan)
{
    double i[] = {0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1 };
    double nan[] = {NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0};
    cblas_zscal(9, i, &nan, 1);
    /* Each assertion previously lacked its closing parenthesis; ASSERT_TRUE
     * matches the same checks in utest/test_zscal.c. */
    ASSERT_TRUE(isnan(nan[0]));
    ASSERT_TRUE(isnan(nan[1]));
    ASSERT_TRUE(isnan(nan[16]));
    ASSERT_TRUE(isnan(nan[17]));
}
/* zscal with alpha = NaN+0i applied to a vector of (0 + 1i) entries.
 * Both components of the first and of the last scaled element must be NaN. */
CTEST(zscal, nan_i)
{
    double i[] = {0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1 };
    double nan[] = {NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0, NAN,0};
    cblas_zscal(9, &nan, &i, 1);
    /* Each assertion previously lacked its closing parenthesis; ASSERT_TRUE
     * matches the same checks in utest/test_zscal.c. */
    ASSERT_TRUE(isnan(i[0]));
    ASSERT_TRUE(isnan(i[1]));
    ASSERT_TRUE(isnan(i[16]));
    ASSERT_TRUE(isnan(i[17]));
}
/* zscal with alpha = 0+1i applied to a vector of (Inf + 0i) entries.
 * The product has real part Inf*0 - 0*1 = NaN and imaginary part
 * Inf*1 + 0*0 = Inf; the first and last scaled elements are checked. */
CTEST(zscal, i_inf)
{
    double i[] = {0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1 };
    double inf[] = {INFINITY, 0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0};
    cblas_zscal(9, i, &inf, 1);
    /* Each assertion previously lacked its closing parenthesis; ASSERT_TRUE
     * matches the same checks in utest/test_zscal.c. */
    ASSERT_TRUE(isnan(inf[0]));
    ASSERT_TRUE(isinf(inf[1]));
    ASSERT_TRUE(isnan(inf[16]));
    ASSERT_TRUE(isinf(inf[17]));
}
/* zscal with alpha = Inf+0i applied to a vector of (0 + 1i) entries.
 * The product has real part 0*Inf - 1*0 = NaN and imaginary part
 * 0*0 + 1*Inf = Inf; the first and last scaled elements are checked. */
CTEST(zscal, inf_i)
{
    double i[] = {0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1, 0,1 };
    double inf[] = {INFINITY, 0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0, INFINITY,0};
    cblas_zscal(9, &inf, &i, 1);
    /* Each assertion previously lacked its closing parenthesis; ASSERT_TRUE
     * matches the same checks in utest/test_zscal.c. */
    ASSERT_TRUE(isnan(i[0]));
    ASSERT_TRUE(isinf(i[1]));
    ASSERT_TRUE(isnan(i[16]));
    ASSERT_TRUE(isinf(i[17]));
}
int main(int argc, const char ** argv){ int main(int argc, const char ** argv){
CTEST_ADD (amax, samax); CTEST_ADD (amax, samax);
@ -648,7 +693,10 @@ int main(int argc, const char ** argv){
CTEST_ADD (swap,zswap_inc_0); CTEST_ADD (swap,zswap_inc_0);
CTEST_ADD (swap,sswap_inc_0); CTEST_ADD (swap,sswap_inc_0);
CTEST_ADD (swap,cswap_inc_0); CTEST_ADD (swap,cswap_inc_0);
CTEST_ADD (zscal, i_nan);
CTEST_ADD (zscal, nan_i);
CTEST_ADD (zscal, i_inf);
CTEST_ADD (zscal, inf_i);
int num_fail=0; int num_fail=0;
num_fail=ctest_main(argc, argv); num_fail=ctest_main(argc, argv);