refactor: Small improvements to approx tuple bugfix

This commit is contained in:
Zach OBrien 2022-06-14 00:17:40 -04:00
parent 7ae22fcb88
commit 89fc472c3d
No known key found for this signature in database
GPG Key ID: 6716FFCF52B3A766
1 changed files with 5 additions and 5 deletions

View File

@ -135,8 +135,8 @@ class ApproxBase:
def _recursive_sequence_map(f, x):
"""Recursively map a function over a sequence of arbitary depth"""
seq_type = type(x)
if seq_type in (list, tuple):
if isinstance(x, (list, tuple)):
seq_type = type(x)
return seq_type(_recursive_sequence_map(f, xi) for xi in x)
else:
return f(x)
@ -168,7 +168,7 @@ class ApproxNumpy(ApproxBase):
return value
np_array_shape = self.expected.shape
approx_side_as_list = _recursive_sequence_map(
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
@ -183,7 +183,7 @@ class ApproxNumpy(ApproxBase):
max_rel_diff = -math.inf
different_ids = []
for index in itertools.product(*(range(i) for i in np_array_shape)):
approx_value = get_value_from_nested_list(approx_side_as_list, index)
approx_value = get_value_from_nested_list(approx_side_as_seq, index)
other_value = get_value_from_nested_list(other_side, index)
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
@ -198,7 +198,7 @@ class ApproxNumpy(ApproxBase):
(
str(index),
str(get_value_from_nested_list(other_side, index)),
str(get_value_from_nested_list(approx_side_as_list, index)),
str(get_value_from_nested_list(approx_side_as_seq, index)),
)
for index in different_ids
]