From 6606817a86b96cc66aaa1d567b7bfce0c75500a2 Mon Sep 17 00:00:00 2001 From: pukkandan Date: Sun, 11 Jul 2021 03:29:44 +0530 Subject: [PATCH] [utils] Add `variadic` --- yt_dlp/extractor/common.py | 14 ++++---------- yt_dlp/extractor/instagram.py | 5 ++--- yt_dlp/postprocessor/ffmpeg.py | 11 +++-------- yt_dlp/utils.py | 12 ++++++------ 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py index 07f413733..8ad657fe5 100644 --- a/yt_dlp/extractor/common.py +++ b/yt_dlp/extractor/common.py @@ -19,7 +19,6 @@ from ..compat import ( compat_etree_Element, compat_etree_fromstring, compat_getpass, - compat_integer_types, compat_http_client, compat_os_name, compat_str, @@ -79,6 +78,7 @@ from ..utils import ( urljoin, url_basename, url_or_none, + variadic, xpath_element, xpath_text, xpath_with_ns, @@ -628,14 +628,10 @@ class InfoExtractor(object): assert isinstance(err, compat_urllib_error.HTTPError) if expected_status is None: return False - if isinstance(expected_status, compat_integer_types): - return err.code == expected_status - elif isinstance(expected_status, (list, tuple)): - return err.code in expected_status elif callable(expected_status): return expected_status(err.code) is True else: - assert False + return err.code in variadic(expected_status) def _request_webpage(self, url_or_request, video_id, note=None, errnote=None, fatal=True, data=None, headers={}, query={}, expected_status=None): """ @@ -1207,8 +1203,7 @@ class InfoExtractor(object): [^>]+?content=(["\'])(?P.*?)\2''' % re.escape(prop) def _og_search_property(self, prop, html, name=None, **kargs): - if not isinstance(prop, (list, tuple)): - prop = [prop] + prop = variadic(prop) if name is None: name = 'OpenGraph %s' % prop[0] og_regexes = [] @@ -1238,8 +1233,7 @@ class InfoExtractor(object): return self._og_search_property('url', html, **kargs) def _html_search_meta(self, name, html, display_name=None, fatal=False, **kwargs): - if not isinstance(name, (list, tuple)): - name = [name] + name = variadic(name) if display_name is None: display_name = name[0] return self._html_search_regex( diff --git a/yt_dlp/extractor/instagram.py b/yt_dlp/extractor/instagram.py index 12e10143c..1261f438e 100644 --- a/yt_dlp/extractor/instagram.py +++ b/yt_dlp/extractor/instagram.py @@ -19,6 +19,7 @@ from ..utils import ( std_headers, try_get, url_or_none, + variadic, ) @@ -188,9 +189,7 @@ class InstagramIE(InfoExtractor): uploader_id = media.get('owner', {}).get('username') def get_count(keys, kind): - if not isinstance(keys, (list, tuple)): - keys = [keys] - for key in keys: + for key in variadic(keys): count = int_or_none(try_get( media, (lambda x: x['edge_media_%s' % key]['count'], lambda x: x['%ss' % kind]['count']))) diff --git a/yt_dlp/postprocessor/ffmpeg.py b/yt_dlp/postprocessor/ffmpeg.py index 0d5e78f3d..fcc32ca03 100644 --- a/yt_dlp/postprocessor/ffmpeg.py +++ b/yt_dlp/postprocessor/ffmpeg.py @@ -24,6 +24,7 @@ from ..utils import ( process_communicate_or_kill, replace_extension, traverse_obj, + variadic, ) @@ -533,15 +534,9 @@ class FFmpegMetadataPP(FFmpegPostProcessor): def add(meta_list, info_list=None): if not meta_list: return - if not info_list: - info_list = meta_list - if not isinstance(meta_list, (list, tuple)): - meta_list = (meta_list,) - if not isinstance(info_list, (list, tuple)): - info_list = (info_list,) - for info_f in info_list: + for info_f in variadic(info_list or meta_list): if isinstance(info.get(info_f), (compat_str, compat_numeric_types)): - for meta_f in meta_list: + for meta_f in variadic(meta_list): metadata[meta_f] = info[info_f] break diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index f0d0097bb..888cfbb7e 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -4289,9 +4289,7 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True): def try_get(src, getter, expected_type=None): - if not isinstance(getter, (list, tuple)): - getter = [getter] - for get in getter: + for get in variadic(getter): try: v = get(src) except (AttributeError, KeyError, TypeError, IndexError): @@ -4964,11 +4962,9 @@ def cli_configuration_args(argdict, keys, default=[], use_compat=True): assert isinstance(keys, (list, tuple)) for key_list in keys: - if isinstance(key_list, compat_str): - key_list = (key_list,) arg_list = list(filter( lambda x: x is not None, - [argdict.get(key.lower()) for key in key_list])) + [argdict.get(key.lower()) for key in variadic(key_list)])) if arg_list: return [arg for args in arg_list for arg in args] return default @@ -6265,3 +6261,7 @@ def traverse_dict(dictn, keys, casesense=True): ''' For backward compatibility. Do not use ''' return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True) + + +def variadic(x, allowed_types=str): + return x if isinstance(x, collections.Iterable) and not isinstance(x, allowed_types) else (x,)