revert back to init_subclass, add guard against multiple imports of same plugin

pull/11305/head
coletdjnz 3 months ago
parent 42771dde1c
commit a19dd28fdc
No known key found for this signature in database
GPG Key ID: 91984263BB39894A

@ -116,7 +116,7 @@ class TestPlugins(unittest.TestCase):
for module_name in tuple(sys.modules): for module_name in tuple(sys.modules):
if module_name.startswith(f'{PACKAGE_NAME}.extractor'): if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
del sys.modules[module_name] del sys.modules[module_name]
plugins_ie = load_plugins(PluginType.EXTRACTORS) load_plugins(PluginType.EXTRACTORS)
from yt_dlp.extractor.generic import GenericIE from yt_dlp.extractor.generic import GenericIE
@ -124,6 +124,11 @@ class TestPlugins(unittest.TestCase):
self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override') self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override') self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
importlib.invalidate_caches()
# test that loading a second time doesn't wrap a second time
load_plugins(PluginType.EXTRACTORS)
from yt_dlp.extractor.generic import GenericIE
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
if __name__ == '__main__': if __name__ == '__main__':

@ -6,6 +6,7 @@ import hashlib
import http.client import http.client
import http.cookiejar import http.cookiejar
import http.cookies import http.cookies
import inspect
import itertools import itertools
import json import json
import math import math
@ -21,6 +22,7 @@ import urllib.parse
import urllib.request import urllib.request
import xml.etree.ElementTree import xml.etree.ElementTree
from .._globals import plugin_overrides
from ..compat import ( from ..compat import (
compat_etree_fromstring, compat_etree_fromstring,
compat_expanduser, compat_expanduser,
@ -3933,7 +3935,19 @@ class InfoExtractor:
@classmethod @classmethod
def __init_subclass__(cls, *, plugin_name=None, **kwargs): def __init_subclass__(cls, *, plugin_name=None, **kwargs):
if plugin_name is not None: if plugin_name is not None:
cls._plugin_name = plugin_name mro = inspect.getmro(cls)
next_mro_class = super_class = mro[mro.index(cls) + 1]
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
if not any(override.PLUGIN_NAME == plugin_name for override in plugin_overrides.get()[super_class]):
cls.__wrapped__ = next_mro_class
cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key
cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}'
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
plugin_overrides.get()[super_class].append(cls)
return super().__init_subclass__(**kwargs) return super().__init_subclass__(**kwargs)

@ -11,7 +11,6 @@ import os
import pkgutil import pkgutil
import sys import sys
import traceback import traceback
import warnings
import zipimport import zipimport
from contextvars import ContextVar from contextvars import ContextVar
from pathlib import Path from pathlib import Path
@ -22,7 +21,8 @@ from ._globals import (
plugin_dirs, plugin_dirs,
plugin_ies, plugin_ies,
plugin_pps, plugin_pps,
postprocessors, plugin_overrides, ALL_PLUGINS_LOADED, postprocessors,
ALL_PLUGINS_LOADED,
) )
from .compat import functools # isort: split from .compat import functools # isort: split
@ -170,52 +170,24 @@ def get_regular_classes(module, module_name, suffix):
and obj.__module__.startswith(module_name) and obj.__module__.startswith(module_name)
and not obj.__name__.startswith('_') and not obj.__name__.startswith('_')
and obj.__name__ in getattr(module, '__all__', [obj.__name__]) and obj.__name__ in getattr(module, '__all__', [obj.__name__])
and getattr(obj, '_plugin_name', None) is None and getattr(obj, 'PLUGIN_NAME', None) is None
)) ))
def get_override_classes(module, module_name, suffix):
# Find override plugin classes
def predicate(obj):
if not inspect.isclass(obj):
return False
mro = inspect.getmro(obj)
return (
obj.__module__.startswith(module_name)
and getattr(obj, '_plugin_name', None) is not None
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
)
return inspect.getmembers(module, predicate)
def configure_ie_override_class(klass, super_class, plugin_name):
ie_key = getattr(super_class, 'ie_key', None)
if not ie_key:
warnings.warn(f'Override plugin {klass} is not an extractor')
return False
klass.ie_key = ie_key
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
@dataclasses.dataclass @dataclasses.dataclass
class _PluginTypeConfig: class _PluginTypeConfig:
destination: ContextVar destination: ContextVar
plugin_destination: ContextVar plugin_destination: ContextVar
# Function to configure the override class. Return False to skip the class
# Takes (klass, super_class, plugin_name) as arguments
configure_override_func: callable = lambda *args: None
_plugin_type_lookup = { _plugin_type_lookup = {
PluginType.POSTPROCESSORS: _PluginTypeConfig( PluginType.POSTPROCESSORS: _PluginTypeConfig(
destination=postprocessors, destination=postprocessors,
plugin_destination=plugin_pps, plugin_destination=plugin_pps,
configure_override_func=None,
), ),
PluginType.EXTRACTORS: _PluginTypeConfig( PluginType.EXTRACTORS: _PluginTypeConfig(
destination=extractors, destination=extractors,
plugin_destination=plugin_ies, plugin_destination=plugin_ies,
configure_override_func=configure_ie_override_class,
), ),
} }
@ -224,7 +196,6 @@ def load_plugins(plugin_type: PluginType):
plugin_config = _plugin_type_lookup[plugin_type] plugin_config = _plugin_type_lookup[plugin_type]
name, suffix = plugin_type.value name, suffix = plugin_type.value
regular_classes = {} regular_classes = {}
override_classes = {}
if os.environ.get('YTDLP_NO_PLUGINS'): if os.environ.get('YTDLP_NO_PLUGINS'):
return regular_classes return regular_classes
@ -248,7 +219,6 @@ def load_plugins(plugin_type: PluginType):
) )
continue continue
regular_classes.update(get_regular_classes(module, module_name, suffix)) regular_classes.update(get_regular_classes(module, module_name, suffix))
override_classes.update(get_override_classes(module, module_name, suffix))
# Compat: old plugin system using __init__.py # Compat: old plugin system using __init__.py
# Note: plugins imported this way do not show up in directories() # Note: plugins imported this way do not show up in directories()
@ -264,25 +234,6 @@ def load_plugins(plugin_type: PluginType):
spec.loader.exec_module(plugins) spec.loader.exec_module(plugins)
regular_classes.update(get_regular_classes(plugins, spec.name, suffix)) regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
# Configure override classes
for _, klass in override_classes.items():
plugin_name = getattr(klass, '_plugin_name', None)
if not plugin_name:
# these should always have plugin_name
continue
mro = inspect.getmro(klass)
super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
klass.PLUGIN_NAME = plugin_name
if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
continue
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
plugin_overrides.get()[super_class].append(klass)
# Add the classes into the global plugin lookup for that type # Add the classes into the global plugin lookup for that type
plugin_config.plugin_destination.set(regular_classes) plugin_config.plugin_destination.set(regular_classes)
# We want to prepend to the main lookup for that type # We want to prepend to the main lookup for that type

Loading…
Cancel
Save