From 476d9d07d6bfb3e83896abf6e85eb95c6dad405d Mon Sep 17 00:00:00 2001 From: poulami-sau <109125687+poulami-sau@users.noreply.github.com> Date: Mon, 22 Apr 2024 09:52:35 -0400 Subject: [PATCH] expanded the type annotation to include objects which may cast to a array and renamed other_side to other_side_as_array and asserted that it is not none --- src/_pytest/python_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 750769af7..7d89fdd80 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -142,7 +142,7 @@ class ApproxNumpy(ApproxBase): ) return f"approx({list_scalars!r})" - def _repr_compare(self, other_side: "ndarray") -> List[str]: + def _repr_compare(self, other_side: Union["ndarray", List[Any]]) -> List[str]: import itertools import math @@ -164,12 +164,13 @@ class ApproxNumpy(ApproxBase): ) # convert other_side to numpy array to ensure shape attribute is available - other_side = _as_numpy_array(other_side) + other_side_as_array = _as_numpy_array(other_side) + assert other_side_as_array is not None - if np_array_shape != other_side.shape: + if np_array_shape != other_side_as_array.shape: return [ "Impossible to compare arrays with different shapes.", - f"Shapes: {np_array_shape} and {other_side.shape}", + f"Shapes: {np_array_shape} and {other_side_as_array.shape}", ] number_of_elements = self.expected.size @@ -178,7 +179,7 @@ class ApproxNumpy(ApproxBase): different_ids = [] for index in itertools.product(*(range(i) for i in np_array_shape)): approx_value = get_value_from_nested_list(approx_side_as_seq, index) - other_value = get_value_from_nested_list(other_side, index) + other_value = get_value_from_nested_list(other_side_as_array, index) if approx_value != other_value: abs_diff = abs(approx_value.expected - other_value) max_abs_diff = max(max_abs_diff, abs_diff) @@ -191,7 +192,7 @@ class ApproxNumpy(ApproxBase): message_data = [ ( str(index), - str(get_value_from_nested_list(other_side, index)), + str(get_value_from_nested_list(other_side_as_array, index)), str(get_value_from_nested_list(approx_side_as_seq, index)), ) for index in different_ids