Merge pull request #3284 from martin-frbg/potrf_potri

Add lower thresholds for multithreading in POTRF/POTRI and improve the related benchmark
This commit is contained in:
Martin Kroeker 2021-06-30 07:42:45 +02:00 committed by GitHub
commit 623be6600a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 15 deletions

View File

@ -99,14 +99,15 @@ int main(int argc, char *argv[]){
char *p;
char btest = 'F';
blasint m, i, j, info, uplos=0;
double flops;
blasint m, i, j, l, info, uplos=0;
double flops = 0.;
int from = 1;
int to = 200;
int step = 1;
int loops = 1;
double time1;
double time1, timeg;
argc--;argv++;
@ -119,6 +120,8 @@ int main(int argc, char *argv[]){
if ((p = getenv("OPENBLAS_TEST"))) btest=*p;
if ((p = getenv("OPENBLAS_LOOPS"))) loops=*p;
fprintf(stderr, "From : %3d To : %3d Step = %3d Uplo = %c\n", from, to, step,*uplo[uplos]);
if (( a = (FLOAT *)malloc(sizeof(FLOAT) * to * to * COMPSIZE)) == NULL){
@ -129,19 +132,21 @@ int main(int argc, char *argv[]){
fprintf(stderr,"Out of Memory!!\n");exit(1);
}
for(m = from; m <= to; m += step){
for(m = from; m <= to; m += step){
timeg=0.;
for (l = 0; l < loops; l++) {
#ifndef COMPLEX
if (uplos & 1) {
for (j = 0; j < m; j++) {
for(i = 0; i < j; i++) a[(long)i + (long)j * (long)m] = 0.;
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
for(i = j + 1; i < m; i++) a[(long)i + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) - 0.5;
}
} else {
for (j = 0; j < m; j++) {
for(i = 0; i < j; i++) a[(long)i + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) - 0.5;
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
for(i = j + 1; i < m; i++) a[(long)i + (long)j * (long)m] = 0.;
}
}
@ -192,8 +197,8 @@ int main(int argc, char *argv[]){
exit(1);
}
time1 = getsec();
flops = COMPSIZE * COMPSIZE * (1.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 1.0/6.0* (double)m) / time1 * 1.e-6;
if ( btest == 'F')
timeg += getsec();
if ( btest == 'S' )
{
@ -214,9 +219,7 @@ int main(int argc, char *argv[]){
fprintf(stderr, "Potrs info = %d\n", info);
exit(1);
}
time1 = getsec();
flops = COMPSIZE * COMPSIZE * (2.0 * (double)m * (double)m *(double)m ) / time1 * 1.e-6;
timeg += getsec();
}
if ( btest == 'I' )
@ -232,11 +235,17 @@ int main(int argc, char *argv[]){
fprintf(stderr, "Potri info = %d\n", info);
exit(1);
}
time1 = getsec();
flops = COMPSIZE * COMPSIZE * (2.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 5.0/6.0* (double)m) / time1 * 1.e-6;
timeg += getsec();
}
} // loops
time1 = timeg/(double)loops;
if ( btest == 'F')
flops = COMPSIZE * COMPSIZE * (1.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 1.0/6.0* (double)m) / time1 * 1.e-6;
if ( btest == 'S')
flops = COMPSIZE * COMPSIZE * (2.0 * (double)m * (double)m *(double)m ) / time1 * 1.e-6;
if ( btest == 'I')
flops = COMPSIZE * COMPSIZE * (2.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 5.0/6.0* (double)m) / time1 * 1.e-6;
fprintf(stderr, "%8d : %10.2f MFlops : %10.3f Sec : Test=%c\n",m,flops ,time1,btest);

View File

@ -709,6 +709,13 @@ int BLASFUNC(cpotrf)(char *, blasint *, float *, blasint *, blasint *);
int BLASFUNC(zpotrf)(char *, blasint *, double *, blasint *, blasint *);
int BLASFUNC(xpotrf)(char *, blasint *, xdouble *, blasint *, blasint *);
int BLASFUNC(spotri)(char *, blasint *, float *, blasint *, blasint *);
int BLASFUNC(dpotri)(char *, blasint *, double *, blasint *, blasint *);
int BLASFUNC(qpotri)(char *, blasint *, xdouble *, blasint *, blasint *);
int BLASFUNC(cpotri)(char *, blasint *, float *, blasint *, blasint *);
int BLASFUNC(zpotri)(char *, blasint *, double *, blasint *, blasint *);
int BLASFUNC(xpotri)(char *, blasint *, xdouble *, blasint *, blasint *);
int BLASFUNC(spotrs)(char *, blasint *, blasint *, float *, blasint *, float *, blasint *, blasint *);
int BLASFUNC(dpotrs)(char *, blasint *, blasint *, double *, blasint *, double *, blasint *, blasint *);
int BLASFUNC(qpotrs)(char *, blasint *, blasint *, xdouble *, blasint *, xdouble *, blasint *, blasint *);

View File

@ -112,6 +112,13 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
#ifdef SMP
args.common = NULL;
#ifndef DOUBLE
if (args.n <128)
#else
if (args.n <64)
#endif
args.nthreads = 1;
else
args.nthreads = num_cpu_avail(4);
if (args.nthreads == 1) {

View File

@ -121,6 +121,9 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
#ifdef SMP
args.common = NULL;
if (args.n < 180)
args.nthreads = 1;
else
args.nthreads = num_cpu_avail(4);
if (args.nthreads == 1) {

View File

@ -112,6 +112,13 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
#ifdef SMP
args.common = NULL;
#ifndef DOUBLE
if (args.n < 64)
#else
if (args.n < 64)
#endif
args.nthreads = 1;
else
args.nthreads = num_cpu_avail(4);
if (args.nthreads == 1) {

View File

@ -121,6 +121,15 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
#ifdef SMP
args.nthreads = num_cpu_avail(4);
#ifndef DOUBLE
if (args.n < 200)
#else
if (args.n < 150)
#endif
args.nthreads=1;
else
#endif
args.nthreads = num_cpu_avail(4);
if (args.nthreads == 1) {
#endif