Merge pull request #3052 from ashwinyes/arm64_fix_nrm2
arm64: Fix nrm2 for input vectors with Inf
This commit is contained in:
commit
d6c97cf010
|
@ -91,10 +91,10 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
|
||||||
ICAMAXKERNEL = izamax_thunderx2t99.c
|
ICAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
IZAMAXKERNEL = izamax_thunderx2t99.c
|
IZAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
|
|
||||||
SNRM2KERNEL = nrm2.S
|
SNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
DNRM2KERNEL = nrm2.S
|
DNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
CNRM2KERNEL = znrm2.S
|
CNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
ZNRM2KERNEL = znrm2.S
|
ZNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
|
|
||||||
DDOTKERNEL = dot_thunderx2t99.c
|
DDOTKERNEL = dot_thunderx2t99.c
|
||||||
SDOTKERNEL = dot_thunderx2t99.c
|
SDOTKERNEL = dot_thunderx2t99.c
|
||||||
|
|
|
@ -153,12 +153,12 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
|
||||||
ICAMAXKERNEL = izamax_thunderx2t99.c
|
ICAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
IZAMAXKERNEL = izamax_thunderx2t99.c
|
IZAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
|
|
||||||
SNRM2KERNEL = nrm2.S
|
SNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
CNRM2KERNEL = nrm2.S
|
CNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
||||||
#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
||||||
DNRM2KERNEL = znrm2.S
|
DNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
ZNRM2KERNEL = znrm2.S
|
ZNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
|
|
||||||
|
|
||||||
DDOTKERNEL = dot_thunderx2t99.c
|
DDOTKERNEL = dot_thunderx2t99.c
|
||||||
|
|
|
@ -153,16 +153,13 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
|
||||||
ICAMAXKERNEL = izamax_thunderx2t99.c
|
ICAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
IZAMAXKERNEL = izamax_thunderx2t99.c
|
IZAMAXKERNEL = izamax_thunderx2t99.c
|
||||||
|
|
||||||
#SNRM2KERNEL = scnrm2_thunderx2t99.c
|
SNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
#CNRM2KERNEL = scnrm2_thunderx2t99.c
|
CNRM2KERNEL = scnrm2_thunderx2t99.c
|
||||||
##DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
||||||
##ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
|
||||||
#DNRM2KERNEL = dznrm2_thunderx2t99.c
|
DNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
#ZNRM2KERNEL = dznrm2_thunderx2t99.c
|
ZNRM2KERNEL = dznrm2_thunderx2t99.c
|
||||||
SNRM2KERNEL = nrm2.S
|
|
||||||
DNRM2KERNEL = nrm2.S
|
|
||||||
CNRM2KERNEL = znrm2.S
|
|
||||||
ZNRM2KERNEL = znrm2.S
|
|
||||||
|
|
||||||
DDOTKERNEL = dot_thunderx2t99.c
|
DDOTKERNEL = dot_thunderx2t99.c
|
||||||
SDOTKERNEL = dot_thunderx2t99.c
|
SDOTKERNEL = dot_thunderx2t99.c
|
||||||
|
|
|
@ -58,6 +58,7 @@ extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n
|
||||||
#define CUR_MAXINV "d8"
|
#define CUR_MAXINV "d8"
|
||||||
#define CUR_MAXINV_V "v8.2d"
|
#define CUR_MAXINV_V "v8.2d"
|
||||||
#define CUR_MAX_V "v8.2d"
|
#define CUR_MAX_V "v8.2d"
|
||||||
|
#define REGINF "d9"
|
||||||
|
|
||||||
static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
double *ssq, double *scale)
|
double *ssq, double *scale)
|
||||||
|
@ -79,8 +80,10 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" ble 9f //nrm2_kernel_L999 \n"
|
" ble 9f //nrm2_kernel_L999 \n"
|
||||||
|
|
||||||
"1: //nrm2_kernel_F_BEGIN: \n"
|
"1: //nrm2_kernel_F_BEGIN: \n"
|
||||||
|
" mov x6, #0x7FF0000000000000 //+Infinity \n"
|
||||||
" fmov "REGZERO", xzr \n"
|
" fmov "REGZERO", xzr \n"
|
||||||
" fmov "REGONE", #1.0 \n"
|
" fmov "REGONE", #1.0 \n"
|
||||||
|
" fmov "REGINF", x6 \n"
|
||||||
" lsl "INC_X", "INC_X", #"INC_SHIFT" \n"
|
" lsl "INC_X", "INC_X", #"INC_SHIFT" \n"
|
||||||
" mov "J", "N" \n"
|
" mov "J", "N" \n"
|
||||||
" cmp "J", xzr \n"
|
" cmp "J", xzr \n"
|
||||||
|
@ -104,6 +107,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" ldr d4, ["X"] \n"
|
" ldr d4, ["X"] \n"
|
||||||
" fabs d4, d4 \n"
|
" fabs d4, d4 \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d4 \n"
|
" fmax "CUR_MAX", "SCALE", d4 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
" fmul "SCALE", "SCALE", "SCALE" \n"
|
" fmul "SCALE", "SCALE", "SCALE" \n"
|
||||||
" fmul "SSQ", "SSQ", "SCALE" \n"
|
" fmul "SSQ", "SSQ", "SCALE" \n"
|
||||||
|
@ -116,6 +121,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" ldr d3, ["X", #8] \n"
|
" ldr d3, ["X", #8] \n"
|
||||||
" fabs d3, d3 \n"
|
" fabs d3, d3 \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d3 \n"
|
" fmax "CUR_MAX", "SCALE", d3 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
" fmul "SCALE", "SCALE", "SCALE" \n"
|
" fmul "SCALE", "SCALE", "SCALE" \n"
|
||||||
" fmul "SSQ", "SSQ", "SCALE" \n"
|
" fmul "SSQ", "SSQ", "SCALE" \n"
|
||||||
|
@ -158,6 +165,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" fmaxp v24.2d, v24.2d, v26.2d \n"
|
" fmaxp v24.2d, v24.2d, v26.2d \n"
|
||||||
" fmaxp v24.2d, v24.2d, v24.2d \n"
|
" fmaxp v24.2d, v24.2d, v24.2d \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d24 \n"
|
" fmax "CUR_MAX", "SCALE", d24 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
|
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
|
||||||
" //dup "CUR_MAX_V", v7.d[0] \n"
|
" //dup "CUR_MAX_V", v7.d[0] \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
|
@ -217,6 +226,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" fmaxp v24.2d, v24.2d, v26.2d \n"
|
" fmaxp v24.2d, v24.2d, v26.2d \n"
|
||||||
" fmaxp v24.2d, v24.2d, v24.2d \n"
|
" fmaxp v24.2d, v24.2d, v24.2d \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d24 \n"
|
" fmax "CUR_MAX", "SCALE", d24 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
|
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
|
||||||
" //dup "CUR_MAX_V", v7.d[0] \n"
|
" //dup "CUR_MAX_V", v7.d[0] \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
|
@ -265,6 +276,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" ldr d4, ["X"] \n"
|
" ldr d4, ["X"] \n"
|
||||||
" fabs d4, d4 \n"
|
" fabs d4, d4 \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d4 \n"
|
" fmax "CUR_MAX", "SCALE", d4 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
" fmul "SCALE", "SCALE", "SCALE" \n"
|
" fmul "SCALE", "SCALE", "SCALE" \n"
|
||||||
" fmul "SSQ", "SSQ", "SCALE" \n"
|
" fmul "SSQ", "SSQ", "SCALE" \n"
|
||||||
|
@ -276,6 +289,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
" ldr d3, ["X", #8] \n"
|
" ldr d3, ["X", #8] \n"
|
||||||
" fabs d3, d3 \n"
|
" fabs d3, d3 \n"
|
||||||
" fmax "CUR_MAX", "SCALE", d3 \n"
|
" fmax "CUR_MAX", "SCALE", d3 \n"
|
||||||
|
" fcmp "CUR_MAX", "REGINF" \n"
|
||||||
|
" beq 10f \n"
|
||||||
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
|
||||||
" fmul "SCALE", "SCALE", "SCALE" \n"
|
" fmul "SCALE", "SCALE", "SCALE" \n"
|
||||||
" fmul "SSQ", "SSQ", "SCALE" \n"
|
" fmul "SSQ", "SSQ", "SCALE" \n"
|
||||||
|
@ -291,6 +306,11 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
"9: //nrm2_kernel_L999: \n"
|
"9: //nrm2_kernel_L999: \n"
|
||||||
" str "SSQ", [%[SSQ_]] \n"
|
" str "SSQ", [%[SSQ_]] \n"
|
||||||
" str "SCALE", [%[SCALE_]] \n"
|
" str "SCALE", [%[SCALE_]] \n"
|
||||||
|
" b 11f \n"
|
||||||
|
"10: \n"
|
||||||
|
" str "REGINF", [%[SSQ_]] \n"
|
||||||
|
" str "REGINF", [%[SCALE_]] \n"
|
||||||
|
"11: \n"
|
||||||
|
|
||||||
:
|
:
|
||||||
: [SSQ_] "r" (ssq), //%0
|
: [SSQ_] "r" (ssq), //%0
|
||||||
|
@ -300,7 +320,7 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
|
||||||
[INCX_] "r" (inc_x) //%4
|
[INCX_] "r" (inc_x) //%4
|
||||||
: "cc",
|
: "cc",
|
||||||
"memory",
|
"memory",
|
||||||
"x0", "x1", "x2", "x3", "x4", "x5",
|
"x0", "x1", "x2", "x3", "x4", "x5", "x6",
|
||||||
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"
|
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -359,6 +379,12 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
|
||||||
cur_ssq = *ptr;
|
cur_ssq = *ptr;
|
||||||
cur_scale = *(ptr + 1);
|
cur_scale = *(ptr + 1);
|
||||||
|
|
||||||
|
if (cur_ssq == INFINITY) {
|
||||||
|
ssq = INFINITY;
|
||||||
|
scale = INFINITY;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (cur_scale != 0) {
|
if (cur_scale != 0) {
|
||||||
if (cur_scale > scale) {
|
if (cur_scale > scale) {
|
||||||
scale = (scale / cur_scale);
|
scale = (scale / cur_scale);
|
||||||
|
|
Loading…
Reference in New Issue