python_api: handle array-like args in approx() (#8137)
This commit is contained in:
committed by
Bruno Oliveira
parent
8b8b1214f4
commit
8354995abc
@@ -447,6 +447,36 @@ class TestApprox:
|
||||
assert a12 != approx(a21)
|
||||
assert a21 != approx(a12)
|
||||
|
||||
def test_numpy_array_protocol(self):
|
||||
"""
|
||||
array-like objects such as tensorflow's DeviceArray are handled like ndarray.
|
||||
See issue #8132
|
||||
"""
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
class DeviceArray:
|
||||
def __init__(self, value, size):
|
||||
self.value = value
|
||||
self.size = size
|
||||
|
||||
def __array__(self):
|
||||
return self.value * np.ones(self.size)
|
||||
|
||||
class DeviceScalar:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __array__(self):
|
||||
return np.array(self.value)
|
||||
|
||||
expected = 1
|
||||
actual = 1 + 1e-6
|
||||
assert approx(expected) == DeviceArray(actual, size=1)
|
||||
assert approx(expected) == DeviceArray(actual, size=2)
|
||||
assert approx(expected) == DeviceScalar(actual)
|
||||
assert approx(DeviceScalar(expected)) == actual
|
||||
assert approx(DeviceScalar(expected)) == DeviceScalar(actual)
|
||||
|
||||
def test_doctests(self, mocked_doctest_runner) -> None:
|
||||
import doctest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user