Fix tests and add test for reloading

pull/11305/head
Simon Sawicki 2 years ago
parent 8c297d184c
commit 00709d5e9a

@ -10,7 +10,8 @@ TEST_DATA_DIR = Path(os.path.dirname(os.path.abspath(__file__)), 'testdata')
sys.path.append(str(TEST_DATA_DIR)) sys.path.append(str(TEST_DATA_DIR))
importlib.invalidate_caches() importlib.invalidate_caches()
from yt_dlp.plugins import PACKAGE_NAME, directories, load_plugins from yt_dlp.plugins import PACKAGE_NAME, PluginType, directories, load_plugins
from yt_dlp.globals import extractors, postprocessors
class TestPlugins(unittest.TestCase): class TestPlugins(unittest.TestCase):
@ -24,7 +25,7 @@ class TestPlugins(unittest.TestCase):
for module_name in tuple(sys.modules): for module_name in tuple(sys.modules):
if module_name.startswith(f'{PACKAGE_NAME}.extractor'): if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
del sys.modules[module_name] del sys.modules[module_name]
plugins_ie = load_plugins('extractor', 'IE') plugins_ie = load_plugins(PluginType.EXTRACTORS)
self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertIn('NormalPluginIE', plugins_ie.keys()) self.assertIn('NormalPluginIE', plugins_ie.keys())
@ -43,7 +44,7 @@ class TestPlugins(unittest.TestCase):
self.assertIn('InAllPluginIE', plugins_ie.keys()) self.assertIn('InAllPluginIE', plugins_ie.keys())
def test_postprocessor_classes(self): def test_postprocessor_classes(self):
plugins_pp = load_plugins('postprocessor', 'PP') plugins_pp = load_plugins(PluginType.POSTPROCESSORS)
self.assertIn('NormalPluginPP', plugins_pp.keys()) self.assertIn('NormalPluginPP', plugins_pp.keys())
def test_importing_zipped_module(self): def test_importing_zipped_module(self):
@ -57,10 +58,10 @@ class TestPlugins(unittest.TestCase):
package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}') package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
self.assertIn(zip_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__)) self.assertIn(zip_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))
plugins_ie = load_plugins('extractor', 'IE') plugins_ie = load_plugins(PluginType.EXTRACTORS)
self.assertIn('ZippedPluginIE', plugins_ie.keys()) self.assertIn('ZippedPluginIE', plugins_ie.keys())
plugins_pp = load_plugins('postprocessor', 'PP') plugins_pp = load_plugins(PluginType.POSTPROCESSORS)
self.assertIn('ZippedPluginPP', plugins_pp.keys()) self.assertIn('ZippedPluginPP', plugins_pp.keys())
finally: finally:
@ -68,6 +69,45 @@ class TestPlugins(unittest.TestCase):
os.remove(zip_path) os.remove(zip_path)
importlib.invalidate_caches() # reset the import caches importlib.invalidate_caches() # reset the import caches
def test_reloading_plugins(self):
reload_plugins_path = TEST_DATA_DIR / 'reload_plugins'
for plugin_type in ('extractor', 'postprocessor'):
package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
load_plugins(PluginType.EXTRACTORS)
load_plugins(PluginType.POSTPROCESSORS)
# Remove default folder and add reload_plugin path
sys.path.remove(str(TEST_DATA_DIR))
sys.path.append(str(reload_plugins_path))
importlib.invalidate_caches()
try:
for plugin_type in ('extractor', 'postprocessor'):
package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
self.assertIn(reload_plugins_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))
plugins_ie = load_plugins(PluginType.EXTRACTORS)
self.assertIn('NormalPluginIE', plugins_ie.keys())
self.assertTrue(
plugins_ie['NormalPluginIE'].REPLACED,
msg='Reloading has not replaced original extractor plugin')
self.assertTrue(
extractors.get()['NormalPluginIE'].REPLACED,
msg='Reloading has not replaced original extractor plugin globally')
plugins_pp = load_plugins(PluginType.POSTPROCESSORS)
self.assertIn('NormalPluginPP', plugins_pp.keys())
self.assertTrue(plugins_pp['NormalPluginPP'].REPLACED,
msg='Reloading has not replaced original postprocessor plugin')
self.assertTrue(
postprocessors.get()['NormalPluginPP'].REPLACED,
msg='Reloading has not replaced original postprocessor plugin globally')
finally:
sys.path.remove(str(reload_plugins_path))
sys.path.append(str(TEST_DATA_DIR))
importlib.invalidate_caches()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -0,0 +1,9 @@
from yt_dlp.extractor.common import InfoExtractor
class NormalPluginIE(InfoExtractor):
REPLACED = True
class _IgnoreUnderscorePluginIE(InfoExtractor):
pass

@ -0,0 +1,5 @@
from yt_dlp.postprocessor.common import PostProcessor
class NormalPluginPP(PostProcessor):
REPLACED = True

@ -2,7 +2,7 @@ from yt_dlp.extractor.common import InfoExtractor
class NormalPluginIE(InfoExtractor): class NormalPluginIE(InfoExtractor):
pass REPLACED = False
class _IgnoreUnderscorePluginIE(InfoExtractor): class _IgnoreUnderscorePluginIE(InfoExtractor):

@ -2,4 +2,4 @@ from yt_dlp.postprocessor.common import PostProcessor
class NormalPluginPP(PostProcessor): class NormalPluginPP(PostProcessor):
pass REPLACED = False

Loading…
Cancel
Save