From 11a0c56166d837698f019a6485f7159946194eaf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 27 Jun 2024 11:09:51 +0300 Subject: [PATCH] BENCH: add BLAS level 2 gemv and gbmv --- benchmark/pybench/benchmarks/bench_blas.py | 56 ++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/benchmark/pybench/benchmarks/bench_blas.py b/benchmark/pybench/benchmarks/bench_blas.py index cface1685..70ea03073 100644 --- a/benchmark/pybench/benchmarks/bench_blas.py +++ b/benchmark/pybench/benchmarks/bench_blas.py @@ -73,6 +73,62 @@ def test_daxpy(benchmark, n, variant): result = benchmark(run_daxpy, x, y, axpy) +# ### BLAS level 2 ### + +gemv_sizes = [100, 1000] + +def run_gemv(a, x, y, func): + res = func(1.0, a, x, y=y, overwrite_y=True) + return res + + +@pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z']) +@pytest.mark.parametrize('n', gemv_sizes) +def test_dgemv(benchmark, n, variant): + rndm = np.random.RandomState(1234) + dtyp = dtype_map[variant] + + x = np.array(rndm.uniform(size=(n,)), dtype=dtyp) + y = np.empty(n, dtype=dtyp) + + a = np.array(rndm.uniform(size=(n,n)), dtype=dtyp) + x = np.array(rndm.uniform(size=(n,)), dtype=dtyp) + y = np.zeros(n, dtype=dtyp) + + gemv = ow.get_func('gemv', variant) + result = benchmark(run_gemv, a, x, y, gemv) + + assert result is y + + +# dgbmv + +dgbmv_sizes = [100, 1000] + +def run_gbmv(m, n, kl, ku, a, x, y, func): + res = func(m, n, kl, ku, 1.0, a, x, y=y, overwrite_y=True) + return res + + + +@pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z']) +@pytest.mark.parametrize('n', dgbmv_sizes) +@pytest.mark.parametrize('kl', [1]) +def test_dgbmv(benchmark, n, kl, variant): + rndm = np.random.RandomState(1234) + dtyp = dtype_map[variant] + + x = np.array(rndm.uniform(size=(n,)), dtype=dtyp) + y = np.empty(n, dtype=dtyp) + + m = n + + a = rndm.uniform(size=(2*kl + 1, n)) + a = np.array(a, dtype=dtyp, order='F') + + gbmv = ow.get_func('gbmv', variant) + result = benchmark(run_gbmv, m, n, kl, kl, a, x, y, gbmv) + assert result is y # ### BLAS level 3 ###