BENCH: add BLAS level 2 gemv and gbmv
This commit is contained in:
parent
400cf9f63d
commit
11a0c56166
|
@ -73,6 +73,62 @@ def test_daxpy(benchmark, n, variant):
|
|||
result = benchmark(run_daxpy, x, y, axpy)
|
||||
|
||||
|
||||
# ### BLAS level 2 ###
|
||||
|
||||
gemv_sizes = [100, 1000]
|
||||
|
||||
def run_gemv(a, x, y, func):
|
||||
res = func(1.0, a, x, y=y, overwrite_y=True)
|
||||
return res
|
||||
|
||||
|
||||
@pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
|
||||
@pytest.mark.parametrize('n', gemv_sizes)
|
||||
def test_dgemv(benchmark, n, variant):
|
||||
rndm = np.random.RandomState(1234)
|
||||
dtyp = dtype_map[variant]
|
||||
|
||||
x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
|
||||
y = np.empty(n, dtype=dtyp)
|
||||
|
||||
a = np.array(rndm.uniform(size=(n,n)), dtype=dtyp)
|
||||
x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
|
||||
y = np.zeros(n, dtype=dtyp)
|
||||
|
||||
gemv = ow.get_func('gemv', variant)
|
||||
result = benchmark(run_gemv, a, x, y, gemv)
|
||||
|
||||
assert result is y
|
||||
|
||||
|
||||
# dgbmv
|
||||
|
||||
dgbmv_sizes = [100, 1000]
|
||||
|
||||
def run_gbmv(m, n, kl, ku, a, x, y, func):
|
||||
res = func(m, n, kl, ku, 1.0, a, x, y=y, overwrite_y=True)
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
|
||||
@pytest.mark.parametrize('n', dgbmv_sizes)
|
||||
@pytest.mark.parametrize('kl', [1])
|
||||
def test_dgbmv(benchmark, n, kl, variant):
|
||||
rndm = np.random.RandomState(1234)
|
||||
dtyp = dtype_map[variant]
|
||||
|
||||
x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
|
||||
y = np.empty(n, dtype=dtyp)
|
||||
|
||||
m = n
|
||||
|
||||
a = rndm.uniform(size=(2*kl + 1, n))
|
||||
a = np.array(a, dtype=dtyp, order='F')
|
||||
|
||||
gbmv = ow.get_func('gbmv', variant)
|
||||
result = benchmark(run_gbmv, m, n, kl, kl, a, x, y, gbmv)
|
||||
assert result is y
|
||||
|
||||
|
||||
# ### BLAS level 3 ###
|
||||
|
|
Loading…
Reference in New Issue