Get plugin overrides working

pull/11305/head
coletdjnz 2 months ago
parent 9f1f2c5410
commit 4266658602
No known key found for this signature in database
GPG Key ID: 91984263BB39894A

@ -6,7 +6,6 @@ import hashlib
import http.client import http.client
import http.cookiejar import http.cookiejar
import http.cookies import http.cookies
import inspect
import itertools import itertools
import json import json
import math import math
@ -31,7 +30,6 @@ from ..compat import (
from ..cookies import LenientSimpleCookie from ..cookies import LenientSimpleCookie
from ..downloader.f4m import get_base_url, remove_encrypted_media from ..downloader.f4m import get_base_url, remove_encrypted_media
from ..downloader.hls import HlsFD from ..downloader.hls import HlsFD
from ..globals import plugin_overrides
from ..networking import HEADRequest, Request from ..networking import HEADRequest, Request
from ..networking.exceptions import ( from ..networking.exceptions import (
HTTPError, HTTPError,
@ -3934,17 +3932,8 @@ class InfoExtractor:
@classmethod @classmethod
def __init_subclass__(cls, *, plugin_name=None, **kwargs): def __init_subclass__(cls, *, plugin_name=None, **kwargs):
if plugin_name: if plugin_name is not None:
mro = inspect.getmro(cls) cls._plugin_name = plugin_name
super_class = cls.__wrapped__ = mro[mro.index(cls) + 1]
cls.PLUGIN_NAME, cls.ie_key = plugin_name, super_class.ie_key
cls.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
plugin_overrides.get()[super_class].append(cls)
# if plugin_name is not None:
# cls._plugin_name = plugin_name
return super().__init_subclass__(**kwargs) return super().__init_subclass__(**kwargs)

@ -1,4 +1,5 @@
import contextlib import contextlib
import dataclasses
import enum import enum
import importlib import importlib
import importlib.abc import importlib.abc
@ -10,7 +11,9 @@ import os
import pkgutil import pkgutil
import sys import sys
import traceback import traceback
import warnings
import zipimport import zipimport
from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from zipfile import ZipFile from zipfile import ZipFile
@ -19,7 +22,7 @@ from .globals import (
plugin_dirs, plugin_dirs,
plugin_ies, plugin_ies,
plugin_pps, plugin_pps,
postprocessors, postprocessors, plugin_overrides,
) )
from .compat import functools # isort: split from .compat import functools # isort: split
@ -42,12 +45,6 @@ class PluginType(enum.Enum):
EXTRACTORS = ('extractor', 'IE') EXTRACTORS = ('extractor', 'IE')
_plugin_type_lookup = {
PluginType.POSTPROCESSORS: (postprocessors, plugin_pps),
PluginType.EXTRACTORS: (extractors, plugin_ies),
}
class PluginLoader(importlib.abc.Loader): class PluginLoader(importlib.abc.Loader):
"""Dummy loader for virtual namespace packages""" """Dummy loader for virtual namespace packages"""
@ -165,22 +162,74 @@ def iter_modules(subpackage):
yield from pkgutil.iter_modules(path=pkg.__path__, prefix=f'{fullname}.') yield from pkgutil.iter_modules(path=pkg.__path__, prefix=f'{fullname}.')
def load_module(module, module_name, suffix): def get_regular_modules(module, module_name, suffix):
result = inspect.getmembers(module, lambda obj: ( # Find standard public plugin classes (not overrides)
return inspect.getmembers(module, lambda obj: (
inspect.isclass(obj) inspect.isclass(obj)
and obj.__name__.endswith(suffix) and obj.__name__.endswith(suffix)
and obj.__module__.startswith(module_name) and obj.__module__.startswith(module_name)
and not obj.__name__.startswith('_') and not obj.__name__.startswith('_')
and obj.__name__ in getattr(module, '__all__', [obj.__name__]))) and obj.__name__ in getattr(module, '__all__', [obj.__name__])
return result and getattr(obj, '_plugin_name', None) is None
))
load_module = get_regular_modules
def get_override_modules(module, module_name, suffix):
# Find override plugin classes
def predicate(obj):
if not inspect.isclass(obj):
return False
mro = inspect.getmro(obj)
return (
obj.__module__.startswith(module_name)
and getattr(obj, '_plugin_name', None) is not None
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
)
return inspect.getmembers(module, predicate)
def configure_ie_override_class(klass, super_class, plugin_name):
ie_key = getattr(super_class, 'ie_key', None)
if not ie_key:
warnings.warn(f'Override plugin {klass} is not an extractor')
return False
klass.ie_key = ie_key
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
@dataclasses.dataclass
class _PluginTypeConfig:
destination: ContextVar
plugin_destination: ContextVar
# Function to configure the override class. Return False to skip the class
# Takes (klass, super_class, plugin_name) as arguments
configure_override_func: callable = lambda *args: None
_plugin_type_lookup = {
PluginType.POSTPROCESSORS: _PluginTypeConfig(
destination=postprocessors,
plugin_destination=plugin_pps,
configure_override_func=None,
),
PluginType.EXTRACTORS: _PluginTypeConfig(
destination=extractors,
plugin_destination=plugin_ies,
configure_override_func=configure_ie_override_class,
),
}
def load_plugins(plugin_type: PluginType): def load_plugins(plugin_type: PluginType):
destination, plugin_destination = _plugin_type_lookup[plugin_type] plugin_config = _plugin_type_lookup[plugin_type]
name, suffix = plugin_type.value name, suffix = plugin_type.value
classes = {} regular_classes = {}
override_classes = {}
if os.environ.get('YTDLP_NO_PLUGINS'): if os.environ.get('YTDLP_NO_PLUGINS'):
return classes return regular_classes
for finder, module_name, _ in iter_modules(name): for finder, module_name, _ in iter_modules(name):
if any(x.startswith('_') for x in module_name.split('.')): if any(x.startswith('_') for x in module_name.split('.')):
@ -201,7 +250,8 @@ def load_plugins(plugin_type: PluginType):
f'Error while importing module {module_name!r}\n{traceback.format_exc(limit=-1)}', f'Error while importing module {module_name!r}\n{traceback.format_exc(limit=-1)}',
) )
continue continue
classes.update(load_module(module, module_name, suffix)) regular_classes.update(get_regular_modules(module, module_name, suffix))
override_classes.update(get_override_modules(module, module_name, suffix))
# Compat: old plugin system using __init__.py # Compat: old plugin system using __init__.py
# Note: plugins imported this way do not show up in directories() # Note: plugins imported this way do not show up in directories()
@ -215,41 +265,38 @@ def load_plugins(plugin_type: PluginType):
plugins = importlib.util.module_from_spec(spec) plugins = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = plugins sys.modules[spec.name] = plugins
spec.loader.exec_module(plugins) spec.loader.exec_module(plugins)
classes.update(load_module(plugins, spec.name, suffix)) regular_classes.update(get_regular_modules(plugins, spec.name, suffix))
# regular_plugins = {} # Configure override classes
# __init_subclass__ was removed so we manually add overrides for name, klass in override_classes.items():
# for name, klass in classes.items(): plugin_name = getattr(klass, '_plugin_name', None)
# plugin_name = getattr(klass, '_plugin_name', None) if not plugin_name:
# if not plugin_name: # these should always have plugin_name
# regular_plugins[name] = klass continue
# continue
# FIXME: Most likely something wrong here mro = inspect.getmro(klass)
# This does not work as plugin overrides are not available here. They are not imported in plugin_ies. super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
klass.PLUGIN_NAME = plugin_name
# mro = inspect.getmro(klass) if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
# super_class = klass.__wrapped__ = mro[mro.index(klass) + 1] continue
# klass.PLUGIN_NAME, klass.ie_key = plugin_name, super_class.ie_key
# klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
# while getattr(super_class, '__wrapped__', None):
# super_class = super_class.__wrapped__
# setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
# plugin_overrides.get()[super_class].append(klass)
# Add the classes into the global plugin lookup while getattr(super_class, '__wrapped__', None):
plugin_destination.set(classes) super_class = super_class.__wrapped__
# # We want to prepend to the main lookup setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
destination.set(merge_dicts(destination.get(), classes)) plugin_overrides.get()[super_class].append(klass)
return classes # Add the classes into the global plugin lookup for that type
plugin_config.plugin_destination.set(regular_classes)
# We want to prepend to the main lookup for that type
plugin_config.destination.set(merge_dicts(plugin_config.destination.get(), regular_classes))
return regular_classes
def load_all_plugin_types():
# for plugin_type in PluginType:
# load_plugins(plugin_type)
load_plugins(PluginType.EXTRACTORS)
def load_all_plugin_types():
for plugin_type in PluginType:
load_plugins(plugin_type)
sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.extractor', f'{PACKAGE_NAME}.postprocessor')) sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.extractor', f'{PACKAGE_NAME}.postprocessor'))

Loading…
Cancel
Save