[utils] `traverse_obj`: Allow iterables in traversal (#6902)

Authored by: Grub4K
pull/6907/head
Simon Sawicki 2 years ago committed by GitHub
parent c16644642b
commit 21b5ec86c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2016,6 +2016,8 @@ Line 1
msg='nested `...` queries should work') msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
msg='`...` query result should be flattened') msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(range(4), ...), list(range(4)),
msg='`...` should accept iterables')
# Test function as key # Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
@ -2023,6 +2025,8 @@ Line 1
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be catched') msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables')
if __debug__: if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: ...) traverse_obj(_TEST_DATA, lambda a: ...)

@ -5528,7 +5528,6 @@ def traverse_obj(
If no `default` is given and the last path branches, a `list` of results If no `default` is given and the last path branches, a `list` of results
is always returned. If a path ends on a `dict` that result will always be a `dict`. is always returned. If a path ends on a `dict` that result will always be a `dict`.
""" """
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: k.casefold() if isinstance(k, str) else k casefold = lambda k: k.casefold() if isinstance(k, str) else k
if isinstance(expected_type, type): if isinstance(expected_type, type):
@ -5564,7 +5563,7 @@ def traverse_obj(
branching = True branching = True
if isinstance(obj, collections.abc.Mapping): if isinstance(obj, collections.abc.Mapping):
result = obj.values() result = obj.values()
elif is_sequence(obj): elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
result = obj result = obj
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
result = obj.groups() result = obj.groups()
@ -5578,7 +5577,7 @@ def traverse_obj(
branching = True branching = True
if isinstance(obj, collections.abc.Mapping): if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif is_sequence(obj): elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
iter_obj = enumerate(obj) iter_obj = enumerate(obj)
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
iter_obj = itertools.chain( iter_obj = itertools.chain(
@ -5614,7 +5613,7 @@ def traverse_obj(
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, (int, slice)): elif isinstance(key, (int, slice)):
if is_sequence(obj): if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
branching = isinstance(key, slice) branching = isinstance(key, slice)
with contextlib.suppress(IndexError): with contextlib.suppress(IndexError):
result = obj[key] result = obj[key]

Loading…
Cancel
Save