diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 5fa219619..bd74ec296 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -450,9 +450,26 @@ class ApproxScalar(ApproxBase): # for compatibility with complex numbers. if math.isinf(abs(self.expected)): # type: ignore[arg-type] return False + + tolerance = self.tolerance + expected = self.expected + + if (isinstance(self.expected, Decimal) and not isinstance(actual, Decimal)): + try: + actual = Decimal(str(actual)) + tolerance = Decimal(str(tolerance)) + except TypeError: # + return False + elif (isinstance(actual, Decimal) and not isinstance(self.expected, Decimal)): + try : + expected = Decimal(str(expected)) + tolerance = Decimal(str(tolerance)) + except TypeError: + return False + # Return true if the two numbers are within the tolerance. - result: bool = abs(self.expected - actual) <= self.tolerance + result: bool = (abs(actual - expected) <= tolerance ) # type: ignore[arg-type] return result # Ignore type because of https://github.com/python/mypy/issues/4266. @@ -491,9 +508,15 @@ class ApproxScalar(ApproxBase): # we've made sure the user didn't ask for an absolute tolerance only, # because we don't want to raise errors about the relative tolerance if # we aren't even going to use it. - relative_tolerance = set_default( - self.rel, self.DEFAULT_RELATIVE_TOLERANCE - ) * abs(self.expected) + + if isinstance(self.rel, Decimal): + relative_tolerance = set_default( + self.rel, self.DEFAULT_RELATIVE_TOLERANCE + )*Decimal(str(abs(self.expected))) + else: + relative_tolerance = set_default( + self.rel, self.DEFAULT_RELATIVE_TOLERANCE + ) * abs(self.expected) if relative_tolerance < 0: raise ValueError( @@ -572,7 +595,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: >>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6}) True - The comparison will be true if both mappings have the same keys and their + The comparision will be true if both mappings have the same keys and their respective values match the expected tolerances. **Tolerances** diff --git a/testing/python/approx.py b/testing/python/approx.py index 7b4fbad15..eadaba1da 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -341,6 +341,8 @@ class TestApprox: (0.0, -0.0), (345678, 345678), (Decimal("1.0001"), Decimal("1.0001")), + (1.0001, Decimal("1.0001")), + (Decimal("1.0001"), 1.0001), (Fraction(1, 3), Fraction(-1, -3)), ] for a, x in examples: @@ -515,6 +517,10 @@ class TestApprox: within_1e6 = [ (Decimal("1.000001"), Decimal("1.0")), (Decimal("-1.000001"), Decimal("-1.0")), + (-1.000001, Decimal("-1.0")), + (1.000001, Decimal("1.0")), + (Decimal("1.000001"), 1.0), + (Decimal("-1.000001"), -1.0), ] for a, x in within_1e6: assert a == approx(x) @@ -562,6 +568,15 @@ class TestApprox: expected = [Decimal("1"), Decimal("2")] assert actual == approx(expected) + + expected = [1, 2] + + assert actual == approx(expected) + + expected = [Decimal("1"), Decimal("2")] + actual = [1.000001, 2.000001] + + assert actual == approx(expected) def test_list_wrong_len(self): assert [1, 2] != approx([1]) @@ -603,6 +618,15 @@ class TestApprox: expected = {"b": Decimal("2"), "a": Decimal("1")} assert actual == approx(expected) + + actual = {"a": 1.000001, "b": 2.000001} + + assert actual == approx(expected) + + actual = {"a": Decimal("1.000001"), "b": Decimal("2.000001")} + expected = {"b": 2, "a": 1} + + assert actual == approx(expected) def test_dict_wrong_len(self): assert {"a": 1, "b": 2} != approx({"a": 1}) @@ -878,3 +902,4 @@ class TestApprox: """pytest.approx() should raise an error on unordered sequences (#9692).""" with pytest.raises(TypeError, match="only supports ordered sequences"): assert {1, 2, 3} == approx({1, 2, 3}) + \ No newline at end of file diff --git a/testing/test_collection.py b/testing/test_collection.py index 9099ec57f..e92a40076 100644 --- a/testing/test_collection.py +++ b/testing/test_collection.py @@ -51,7 +51,7 @@ class TestCollector: fn3 = pytester.collect_by_name(modcol, "test_fail") assert isinstance(fn3, pytest.Function) - assert not (fn1 == fn3) + assert fn1 != fn3 assert fn1 != fn3 for fn in fn1, fn2, fn3: