[utils] Generalize `traverse_dict` to `traverse_obj`

pull/310/head
pukkandan 4 years ago
parent beb982bead
commit 324ad82006
No known key found for this signature in database
GPG Key ID: 0F00D95A001F4698

@ -101,7 +101,7 @@ from .utils import (
strftime_or_none, strftime_or_none,
subtitles_filename, subtitles_filename,
to_high_limit_path, to_high_limit_path,
traverse_dict, traverse_obj,
UnavailableVideoError, UnavailableVideoError,
url_basename, url_basename,
version_tuple, version_tuple,
@ -855,7 +855,7 @@ class YoutubeDL(object):
def get_value(mdict): def get_value(mdict):
# Object traversal # Object traversal
fields = mdict['fields'].split('.') fields = mdict['fields'].split('.')
value = traverse_dict(info_dict, fields) value = traverse_obj(info_dict, fields)
# Negative # Negative
if mdict['negate']: if mdict['negate']:
value = float_or_none(value) value = float_or_none(value)
@ -872,7 +872,7 @@ class YoutubeDL(object):
item, multiplier = (item[1:], -1) if item[0] == '-' else (item, 1) item, multiplier = (item[1:], -1) if item[0] == '-' else (item, 1)
offset = float_or_none(item) offset = float_or_none(item)
if offset is None: if offset is None:
offset = float_or_none(traverse_dict(info_dict, item.split('.'))) offset = float_or_none(traverse_obj(info_dict, item.split('.')))
try: try:
value = operator(value, multiplier * offset) value = operator(value, multiplier * offset)
except (TypeError, ZeroDivisionError): except (TypeError, ZeroDivisionError):

@ -23,7 +23,7 @@ from ..utils import (
ISO639Utils, ISO639Utils,
process_communicate_or_kill, process_communicate_or_kill,
replace_extension, replace_extension,
traverse_dict, traverse_obj,
) )
@ -229,7 +229,7 @@ class FFmpegPostProcessor(PostProcessor):
def get_stream_number(self, path, keys, value): def get_stream_number(self, path, keys, value):
streams = self.get_metadata_object(path)['streams'] streams = self.get_metadata_object(path)['streams']
num = next( num = next(
(i for i, stream in enumerate(streams) if traverse_dict(stream, keys, casesense=False) == value), (i for i, stream in enumerate(streams) if traverse_obj(stream, keys, casesense=False) == value),
None) None)
return num, len(streams) return num, len(streams)

@ -6181,21 +6181,38 @@ def load_plugins(name, suffix, namespace):
return classes return classes
def traverse_dict(dictn, keys, casesense=True): def traverse_obj(obj, keys, *, casesense=True, is_user_input=False, traverse_string=False):
''' Traverse nested list/dict/tuple
@param casesense Whether to consider dictionary keys as case sensitive
@param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary
@param traverse_string Whether to traverse inside strings. If True, any
non-compatible object will also be converted into a string
'''
keys = list(keys)[::-1] keys = list(keys)[::-1]
while keys: while keys:
key = keys.pop() key = keys.pop()
if isinstance(dictn, dict): if isinstance(obj, dict):
assert isinstance(key, compat_str)
if not casesense: if not casesense:
dictn = {k.lower(): v for k, v in dictn.items()} obj = {k.lower(): v for k, v in obj.items()}
key = key.lower() key = key.lower()
dictn = dictn.get(key) obj = obj.get(key)
elif isinstance(dictn, (list, tuple, compat_str)):
if ':' in key:
key = slice(*map(int_or_none, key.split(':')))
else: else:
key = int_or_none(key) if is_user_input:
dictn = try_get(dictn, lambda x: x[key]) key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if not isinstance(obj, (list, tuple)):
if traverse_string:
obj = compat_str(obj)
else: else:
return None return None
return dictn assert isinstance(key, (int, slice))
obj = try_get(obj, lambda x: x[key])
return obj
def traverse_dict(dictn, keys, casesense=True):
''' For backward compatibility. Do not use '''
return traverse_obj(dictn, keys, casesense=casesense,
is_user_input=True, traverse_string=True)

Loading…
Cancel
Save