Improve error message when invalid directory provided

pull/11305/head
coletdjnz 2 months ago
parent 600b2ece36
commit cd490eeeab
No known key found for this signature in database
GPG Key ID: 91984263BB39894A

@ -21,6 +21,7 @@ from yt_dlp.plugins import (
set_plugin_dirs, set_plugin_dirs,
disable_plugins, disable_plugins,
add_plugin_dirs, add_plugin_dirs,
get_plugin_spec,
) )
from yt_dlp._globals import ( from yt_dlp._globals import (
@ -50,24 +51,32 @@ POSTPROCESSOR_PLUGIN_SPEC = PluginSpec(
) )
class TestPlugins(unittest.TestCase): def reset_plugins():
TEST_PLUGIN_DIR = TEST_DATA_DIR / PACKAGE_NAME
def setUp(self):
plugin_ies.value = {} plugin_ies.value = {}
plugin_pps.value = {} plugin_pps.value = {}
plugin_dirs.value = ['external'] plugin_dirs.value = ['external']
plugin_specs.value = {} plugin_specs.value = {}
all_plugins_loaded.value = False all_plugins_loaded.value = False
plugins_enabled.value = True plugins_enabled.value = True
importlib.invalidate_caches()
# Clearing override plugins is probably difficult # Clearing override plugins is probably difficult
for module_name in tuple(sys.modules): for module_name in tuple(sys.modules):
for plugin_type in ('extractor', 'postprocessor'): for plugin_type in ('extractor', 'postprocessor'):
if module_name.startswith(f'{PACKAGE_NAME}.{plugin_type}.'): if module_name.startswith(f'{PACKAGE_NAME}.{plugin_type}.'):
del sys.modules[module_name] del sys.modules[module_name]
importlib.invalidate_caches()
class TestPlugins(unittest.TestCase):
TEST_PLUGIN_DIR = TEST_DATA_DIR / PACKAGE_NAME
def setUp(self):
reset_plugins()
def tearDown(self):
reset_plugins()
def test_directories_containing_plugins(self): def test_directories_containing_plugins(self):
self.assertIn(self.TEST_PLUGIN_DIR, map(Path, directories())) self.assertIn(self.TEST_PLUGIN_DIR, map(Path, directories()))
@ -207,6 +216,11 @@ class TestPlugins(unittest.TestCase):
self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
self.assertIn('PackagePluginIE', plugin_ies.value) self.assertIn('PackagePluginIE', plugin_ies.value)
def test_invalid_plugin_dir(self):
set_plugin_dirs('invalid_dir')
with self.assertRaises(ValueError):
load_plugins(EXTRACTOR_PLUGIN_SPEC)
def test_add_plugin_dirs(self): def test_add_plugin_dirs(self):
custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages') custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages')
@ -244,6 +258,14 @@ class TestPlugins(unittest.TestCase):
ies = load_plugins(EXTRACTOR_PLUGIN_SPEC) ies = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn('NormalPluginIE', ies) self.assertIn('NormalPluginIE', ies)
def test_get_plugin_spec(self):
register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)
self.assertEqual(get_plugin_spec('extractor'), EXTRACTOR_PLUGIN_SPEC)
self.assertEqual(get_plugin_spec('postprocessor'), POSTPROCESSOR_PLUGIN_SPEC)
self.assertIsNone(get_plugin_spec('invalid'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -118,7 +118,10 @@ def internal_plugin_paths():
def candidate_plugin_paths(candidate): def candidate_plugin_paths(candidate):
yield from Path(candidate).iterdir() candidate_path = Path(candidate)
if not candidate_path.is_dir():
raise ValueError(f'Invalid plugin directory: {candidate_path}')
yield from candidate_path.iterdir()
yield from internal_plugin_paths() yield from internal_plugin_paths()

Loading…
Cancel
Save