diff --git a/test/test_plugins.py b/test/test_plugins.py index af4b77dab..12315c40b 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -116,7 +116,7 @@ class TestPlugins(unittest.TestCase): for module_name in tuple(sys.modules): if module_name.startswith(f'{PACKAGE_NAME}.extractor'): del sys.modules[module_name] - plugins_ie = load_plugins(PluginType.EXTRACTORS) + load_plugins(PluginType.EXTRACTORS) from yt_dlp.extractor.generic import GenericIE @@ -124,6 +124,11 @@ class TestPlugins(unittest.TestCase): self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override') self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override') + importlib.invalidate_caches() + # test that loading a second time doesn't wrap a second time + load_plugins(PluginType.EXTRACTORS) + from yt_dlp.extractor.generic import GenericIE + self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override') if __name__ == '__main__': diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py index 68909d9d6..1c8727504 100644 --- a/yt_dlp/extractor/common.py +++ b/yt_dlp/extractor/common.py @@ -6,6 +6,7 @@ import hashlib import http.client import http.cookiejar import http.cookies +import inspect import itertools import json import math @@ -21,6 +22,7 @@ import urllib.parse import urllib.request import xml.etree.ElementTree +from .._globals import plugin_overrides from ..compat import ( compat_etree_fromstring, compat_expanduser, @@ -3933,7 +3935,19 @@ class InfoExtractor: @classmethod def __init_subclass__(cls, *, plugin_name=None, **kwargs): if plugin_name is not None: - cls._plugin_name = plugin_name + mro = inspect.getmro(cls) + next_mro_class = super_class = mro[mro.index(cls) + 1] + + while getattr(super_class, '__wrapped__', None): + super_class = super_class.__wrapped__ + + if not any(override.PLUGIN_NAME == plugin_name for override in plugin_overrides.get()[super_class]): + cls.__wrapped__ = next_mro_class + cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key + cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}' + + setattr(sys.modules[super_class.__module__], super_class.__name__, cls) + plugin_overrides.get()[super_class].append(cls) return super().__init_subclass__(**kwargs) diff --git a/yt_dlp/plugins.py b/yt_dlp/plugins.py index 97ef41fa5..f09787e9d 100644 --- a/yt_dlp/plugins.py +++ b/yt_dlp/plugins.py @@ -11,7 +11,6 @@ import os import pkgutil import sys import traceback -import warnings import zipimport from contextvars import ContextVar from pathlib import Path @@ -22,7 +21,8 @@ from ._globals import ( plugin_dirs, plugin_ies, plugin_pps, - postprocessors, plugin_overrides, ALL_PLUGINS_LOADED, + postprocessors, + ALL_PLUGINS_LOADED, ) from .compat import functools # isort: split @@ -170,52 +170,24 @@ def get_regular_classes(module, module_name, suffix): and obj.__module__.startswith(module_name) and not obj.__name__.startswith('_') and obj.__name__ in getattr(module, '__all__', [obj.__name__]) - and getattr(obj, '_plugin_name', None) is None + and getattr(obj, 'PLUGIN_NAME', None) is None )) -def get_override_classes(module, module_name, suffix): - # Find override plugin classes - def predicate(obj): - if not inspect.isclass(obj): - return False - mro = inspect.getmro(obj) - return ( - obj.__module__.startswith(module_name) - and getattr(obj, '_plugin_name', None) is not None - and mro[mro.index(obj) + 1].__name__.endswith(suffix) - ) - return inspect.getmembers(module, predicate) - - -def configure_ie_override_class(klass, super_class, plugin_name): - ie_key = getattr(super_class, 'ie_key', None) - if not ie_key: - warnings.warn(f'Override plugin {klass} is not an extractor') - return False - klass.ie_key = ie_key - klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}' - - @dataclasses.dataclass class _PluginTypeConfig: destination: ContextVar plugin_destination: ContextVar - # Function to configure the override class. Return False to skip the class - # Takes (klass, super_class, plugin_name) as arguments - configure_override_func: callable = lambda *args: None _plugin_type_lookup = { PluginType.POSTPROCESSORS: _PluginTypeConfig( destination=postprocessors, plugin_destination=plugin_pps, - configure_override_func=None, ), PluginType.EXTRACTORS: _PluginTypeConfig( destination=extractors, plugin_destination=plugin_ies, - configure_override_func=configure_ie_override_class, ), } @@ -224,7 +196,6 @@ def load_plugins(plugin_type: PluginType): plugin_config = _plugin_type_lookup[plugin_type] name, suffix = plugin_type.value regular_classes = {} - override_classes = {} if os.environ.get('YTDLP_NO_PLUGINS'): return regular_classes @@ -248,7 +219,6 @@ def load_plugins(plugin_type: PluginType): ) continue regular_classes.update(get_regular_classes(module, module_name, suffix)) - override_classes.update(get_override_classes(module, module_name, suffix)) # Compat: old plugin system using __init__.py # Note: plugins imported this way do not show up in directories() @@ -264,25 +234,6 @@ def load_plugins(plugin_type: PluginType): spec.loader.exec_module(plugins) regular_classes.update(get_regular_classes(plugins, spec.name, suffix)) - # Configure override classes - for _, klass in override_classes.items(): - plugin_name = getattr(klass, '_plugin_name', None) - if not plugin_name: - # these should always have plugin_name - continue - - mro = inspect.getmro(klass) - super_class = klass.__wrapped__ = mro[mro.index(klass) + 1] - klass.PLUGIN_NAME = plugin_name - - if plugin_config.configure_override_func(klass, super_class, plugin_name) is False: - continue - - while getattr(super_class, '__wrapped__', None): - super_class = super_class.__wrapped__ - setattr(sys.modules[super_class.__module__], super_class.__name__, klass) - plugin_overrides.get()[super_class].append(klass) - # Add the classes into the global plugin lookup for that type plugin_config.plugin_destination.set(regular_classes) # We want to prepend to the main lookup for that type