python_api: handle array-like args in approx() (#8137)

This commit is contained in:
Jakob van Santen
2020-12-15 12:49:29 +01:00
committed by Bruno Oliveira
parent 8b8b1214f4
commit 8354995abc
3 changed files with 67 additions and 6 deletions

View File

@@ -447,6 +447,36 @@ class TestApprox:
assert a12 != approx(a21)
assert a21 != approx(a12)
def test_numpy_array_protocol(self):
"""
array-like objects such as tensorflow's DeviceArray are handled like ndarray.
See issue #8132
"""
np = pytest.importorskip("numpy")
class DeviceArray:
def __init__(self, value, size):
self.value = value
self.size = size
def __array__(self):
return self.value * np.ones(self.size)
class DeviceScalar:
def __init__(self, value):
self.value = value
def __array__(self):
return np.array(self.value)
expected = 1
actual = 1 + 1e-6
assert approx(expected) == DeviceArray(actual, size=1)
assert approx(expected) == DeviceArray(actual, size=2)
assert approx(expected) == DeviceScalar(actual)
assert approx(DeviceScalar(expected)) == actual
assert approx(DeviceScalar(expected)) == DeviceScalar(actual)
def test_doctests(self, mocked_doctest_runner) -> None:
import doctest