revert back to init_subclass, add guard against multiple imports of same plugin

pull/11305/head
coletdjnz 3 months ago
parent 42771dde1c
commit a19dd28fdc
No known key found for this signature in database
GPG Key ID: 91984263BB39894A

@ -116,7 +116,7 @@ class TestPlugins(unittest.TestCase):
for module_name in tuple(sys.modules):
if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
del sys.modules[module_name]
plugins_ie = load_plugins(PluginType.EXTRACTORS)
load_plugins(PluginType.EXTRACTORS)
from yt_dlp.extractor.generic import GenericIE
@ -124,6 +124,11 @@ class TestPlugins(unittest.TestCase):
self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
importlib.invalidate_caches()
# test that loading a second time doesn't wrap a second time
load_plugins(PluginType.EXTRACTORS)
from yt_dlp.extractor.generic import GenericIE
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
if __name__ == '__main__':

@ -6,6 +6,7 @@ import hashlib
import http.client
import http.cookiejar
import http.cookies
import inspect
import itertools
import json
import math
@ -21,6 +22,7 @@ import urllib.parse
import urllib.request
import xml.etree.ElementTree
from .._globals import plugin_overrides
from ..compat import (
compat_etree_fromstring,
compat_expanduser,
@ -3933,7 +3935,19 @@ class InfoExtractor:
@classmethod
def __init_subclass__(cls, *, plugin_name=None, **kwargs):
if plugin_name is not None:
cls._plugin_name = plugin_name
mro = inspect.getmro(cls)
next_mro_class = super_class = mro[mro.index(cls) + 1]
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
if not any(override.PLUGIN_NAME == plugin_name for override in plugin_overrides.get()[super_class]):
cls.__wrapped__ = next_mro_class
cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key
cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}'
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
plugin_overrides.get()[super_class].append(cls)
return super().__init_subclass__(**kwargs)

@ -11,7 +11,6 @@ import os
import pkgutil
import sys
import traceback
import warnings
import zipimport
from contextvars import ContextVar
from pathlib import Path
@ -22,7 +21,8 @@ from ._globals import (
plugin_dirs,
plugin_ies,
plugin_pps,
postprocessors, plugin_overrides, ALL_PLUGINS_LOADED,
postprocessors,
ALL_PLUGINS_LOADED,
)
from .compat import functools # isort: split
@ -170,52 +170,24 @@ def get_regular_classes(module, module_name, suffix):
and obj.__module__.startswith(module_name)
and not obj.__name__.startswith('_')
and obj.__name__ in getattr(module, '__all__', [obj.__name__])
and getattr(obj, '_plugin_name', None) is None
and getattr(obj, 'PLUGIN_NAME', None) is None
))
def get_override_classes(module, module_name, suffix):
# Find override plugin classes
def predicate(obj):
if not inspect.isclass(obj):
return False
mro = inspect.getmro(obj)
return (
obj.__module__.startswith(module_name)
and getattr(obj, '_plugin_name', None) is not None
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
)
return inspect.getmembers(module, predicate)
def configure_ie_override_class(klass, super_class, plugin_name):
ie_key = getattr(super_class, 'ie_key', None)
if not ie_key:
warnings.warn(f'Override plugin {klass} is not an extractor')
return False
klass.ie_key = ie_key
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
@dataclasses.dataclass
class _PluginTypeConfig:
destination: ContextVar
plugin_destination: ContextVar
# Function to configure the override class. Return False to skip the class
# Takes (klass, super_class, plugin_name) as arguments
configure_override_func: callable = lambda *args: None
_plugin_type_lookup = {
PluginType.POSTPROCESSORS: _PluginTypeConfig(
destination=postprocessors,
plugin_destination=plugin_pps,
configure_override_func=None,
),
PluginType.EXTRACTORS: _PluginTypeConfig(
destination=extractors,
plugin_destination=plugin_ies,
configure_override_func=configure_ie_override_class,
),
}
@ -224,7 +196,6 @@ def load_plugins(plugin_type: PluginType):
plugin_config = _plugin_type_lookup[plugin_type]
name, suffix = plugin_type.value
regular_classes = {}
override_classes = {}
if os.environ.get('YTDLP_NO_PLUGINS'):
return regular_classes
@ -248,7 +219,6 @@ def load_plugins(plugin_type: PluginType):
)
continue
regular_classes.update(get_regular_classes(module, module_name, suffix))
override_classes.update(get_override_classes(module, module_name, suffix))
# Compat: old plugin system using __init__.py
# Note: plugins imported this way do not show up in directories()
@ -264,25 +234,6 @@ def load_plugins(plugin_type: PluginType):
spec.loader.exec_module(plugins)
regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
# Configure override classes
for _, klass in override_classes.items():
plugin_name = getattr(klass, '_plugin_name', None)
if not plugin_name:
# these should always have plugin_name
continue
mro = inspect.getmro(klass)
super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
klass.PLUGIN_NAME = plugin_name
if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
continue
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
plugin_overrides.get()[super_class].append(klass)
# Add the classes into the global plugin lookup for that type
plugin_config.plugin_destination.set(regular_classes)
# We want to prepend to the main lookup for that type

Loading…
Cancel
Save