diff --git a/test/test_traversal.py b/test/test_traversal.py index 52215f5a7b..f78c0db4af 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -1,4 +1,5 @@ import http.cookies +import dataclasses import re import xml.etree.ElementTree @@ -439,6 +440,17 @@ class TestTraversal: assert traverse_obj(data, [..., filter]) == [True, 1, 1.1, 'str', {0: 0}, [1]], \ '`filter` should filter falsy values' + def test_traversal_dataclass(self): + @dataclasses.dataclass + class _TestDataclass: + val: str + + dc = _TestDataclass(val='yt-dlp') + assert traverse_obj(dc, 'val') == 'yt-dlp', \ + 'Dataclasses should be traversable' + assert traverse_obj({'dataclass': dc}, ('dataclass', 'val')) == 'yt-dlp', \ + 'Dataclasses inside other objects should be traversable' + class TestTraversalHelpers: def test_traversal_require(self): diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 76b51f53d1..65a94ccb9e 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -3,6 +3,7 @@ from __future__ import annotations import collections import collections.abc import contextlib +import dataclasses import functools import http.cookies import inspect @@ -116,6 +117,9 @@ def traverse_obj( branching = False result = None + if dataclasses.is_dataclass(obj): + obj = dataclasses.asdict(obj) + if obj is None and traverse_string: if key is ... or callable(key) or isinstance(key, slice): branching = True