From 7d18fed8f1983fe6de4ddc810dfb2761ba5744ac Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Mon, 3 Mar 2025 00:10:01 +0100 Subject: [PATCH] [networking] Add `keep_header_casing` extension (#11652) Authored by: coletdjnz, Grub4K Co-authored-by: coletdjnz --- test/test_networking.py | 13 +++ test/test_utils.py | 23 +++-- test/test_websockets.py | 22 +++-- yt_dlp/networking/_requests.py | 8 +- yt_dlp/networking/_urllib.py | 8 +- yt_dlp/networking/_websockets.py | 8 +- yt_dlp/networking/common.py | 19 ++++ yt_dlp/networking/impersonate.py | 22 ++++- yt_dlp/utils/networking.py | 146 +++++++++++++++++++++++++++---- 9 files changed, 229 insertions(+), 40 deletions(-) diff --git a/test/test_networking.py b/test/test_networking.py index d96624af1..63914bc4b 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -720,6 +720,15 @@ class TestHTTPRequestHandler(TestRequestHandlerBase): rh, Request( f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close() + @pytest.mark.skip_handlers_if(lambda _, handler: handler not in ['Urllib', 'CurlCFFI'], 'handler does not support keep_header_casing') + def test_keep_header_casing(self, handler): + with handler() as rh: + res = validate_and_send( + rh, Request( + f'http://127.0.0.1:{self.http_port}/headers', headers={'X-test-heaDer': 'test'}, extensions={'keep_header_casing': True})).read().decode() + + assert 'X-test-heaDer: test' in res + @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True) class TestClientCertificate: @@ -1289,6 +1298,7 @@ class TestRequestHandlerValidation: ({'legacy_ssl': False}, False), ({'legacy_ssl': True}, False), ({'legacy_ssl': 'notabool'}, AssertionError), + ({'keep_header_casing': True}, UnsupportedRequest), ]), ('Requests', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), @@ -1299,6 +1309,9 @@ class TestRequestHandlerValidation: ({'legacy_ssl': False}, False), ({'legacy_ssl': True}, False), ({'legacy_ssl': 'notabool'}, AssertionError), + ({'keep_header_casing': False}, False), + ({'keep_header_casing': True}, False), + ({'keep_header_casing': 'notabool'}, AssertionError), ]), ('CurlCFFI', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), diff --git a/test/test_utils.py b/test/test_utils.py index 8f81d0b1b..65f28db36 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,19 +3,20 @@ # Allow direct execution import os import sys -import unittest -import unittest.mock -import warnings -import datetime as dt sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import contextlib +import datetime as dt import io import itertools import json +import pickle import subprocess +import unittest +import unittest.mock +import warnings import xml.etree.ElementTree from yt_dlp.compat import ( @@ -2087,21 +2088,26 @@ Line 1 headers = HTTPHeaderDict() headers['ytdl-test'] = b'0' self.assertEqual(list(headers.items()), [('Ytdl-Test', '0')]) + self.assertEqual(list(headers.sensitive().items()), [('ytdl-test', '0')]) headers['ytdl-test'] = 1 self.assertEqual(list(headers.items()), [('Ytdl-Test', '1')]) + self.assertEqual(list(headers.sensitive().items()), [('ytdl-test', '1')]) headers['Ytdl-test'] = '2' self.assertEqual(list(headers.items()), [('Ytdl-Test', '2')]) + self.assertEqual(list(headers.sensitive().items()), [('Ytdl-test', '2')]) self.assertTrue('ytDl-Test' in headers) self.assertEqual(str(headers), str(dict(headers))) self.assertEqual(repr(headers), str(dict(headers))) headers.update({'X-dlp': 'data'}) self.assertEqual(set(headers.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data')}) + self.assertEqual(set(headers.sensitive().items()), {('Ytdl-test', '2'), ('X-dlp', 'data')}) self.assertEqual(dict(headers), {'Ytdl-Test': '2', 'X-Dlp': 'data'}) self.assertEqual(len(headers), 2) self.assertEqual(headers.copy(), headers) - headers2 = HTTPHeaderDict({'X-dlp': 'data3'}, **headers, **{'X-dlp': 'data2'}) + headers2 = HTTPHeaderDict({'X-dlp': 'data3'}, headers, **{'X-dlP': 'data2'}) self.assertEqual(set(headers2.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data2')}) + self.assertEqual(set(headers2.sensitive().items()), {('Ytdl-test', '2'), ('X-dlP', 'data2')}) self.assertEqual(len(headers2), 2) headers2.clear() self.assertEqual(len(headers2), 0) @@ -2109,16 +2115,23 @@ Line 1 # ensure we prefer latter headers headers3 = HTTPHeaderDict({'Ytdl-TeSt': 1}, {'Ytdl-test': 2}) self.assertEqual(set(headers3.items()), {('Ytdl-Test', '2')}) + self.assertEqual(set(headers3.sensitive().items()), {('Ytdl-test', '2')}) del headers3['ytdl-tesT'] self.assertEqual(dict(headers3), {}) headers4 = HTTPHeaderDict({'ytdl-test': 'data;'}) self.assertEqual(set(headers4.items()), {('Ytdl-Test', 'data;')}) + self.assertEqual(set(headers4.sensitive().items()), {('ytdl-test', 'data;')}) # common mistake: strip whitespace from values # https://github.com/yt-dlp/yt-dlp/issues/8729 headers5 = HTTPHeaderDict({'ytdl-test': ' data; '}) self.assertEqual(set(headers5.items()), {('Ytdl-Test', 'data;')}) + self.assertEqual(set(headers5.sensitive().items()), {('ytdl-test', 'data;')}) + + # test if picklable + headers6 = HTTPHeaderDict(a=1, b=2) + self.assertEqual(pickle.loads(pickle.dumps(headers6)), headers6) def test_extract_basic_auth(self): assert extract_basic_auth('http://:foo.bar') == ('http://:foo.bar', None) diff --git a/test/test_websockets.py b/test/test_websockets.py index 06112cc0b..dead5fe5c 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -44,7 +44,7 @@ def websocket_handler(websocket): return websocket.send('2') elif isinstance(message, str): if message == 'headers': - return websocket.send(json.dumps(dict(websocket.request.headers))) + return websocket.send(json.dumps(dict(websocket.request.headers.raw_items()))) elif message == 'path': return websocket.send(websocket.request.path) elif message == 'source_address': @@ -266,18 +266,18 @@ class TestWebsSocketRequestHandlerConformance: with handler(cookiejar=cookiejar) as rh: ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() with handler() as rh: ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets') @@ -287,7 +287,7 @@ class TestWebsSocketRequestHandlerConformance: ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()})) ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()})) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets') @@ -298,12 +298,12 @@ class TestWebsSocketRequestHandlerConformance: ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie')) ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() cookiejar.clear_session_cookies() ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() def test_source_address(self, handler): @@ -341,6 +341,14 @@ class TestWebsSocketRequestHandlerConformance: assert headers['test3'] == 'test3' ws.close() + def test_keep_header_casing(self, handler): + with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh: + ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True})) + ws.send('headers') + headers = json.loads(ws.recv()) + assert 'x-TeSt1' in headers + assert 'x-TeSt2' in headers + @pytest.mark.parametrize('client_cert', ( {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}, { diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 7de95ab3b..23775845d 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -296,6 +296,7 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): extensions.pop('cookiejar', None) extensions.pop('timeout', None) extensions.pop('legacy_ssl', None) + extensions.pop('keep_header_casing', None) def _create_instance(self, cookiejar, legacy_ssl_support=None): session = RequestsSession() @@ -312,11 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): session.trust_env = False # no need, we already load proxies from env return session - def _send(self, request): - - headers = self._merge_headers(request.headers) + def _prepare_headers(self, _, headers): add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + def _send(self, request): + + headers = self._get_headers(request) max_redirects_exceeded = False session = self._get_instance( diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index 510bb2a69..a188b35f5 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin): opener.addheaders = [] return opener - def _send(self, request): - headers = self._merge_headers(request.headers) + def _prepare_headers(self, _, headers): add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + + def _send(self, request): + headers = self._get_headers(request) urllib_req = urllib.request.Request( url=request.url, data=request.data, - headers=dict(headers), + headers=headers, method=request.method, ) diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index ec55567da..7e5ab4600 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler): extensions.pop('timeout', None) extensions.pop('cookiejar', None) extensions.pop('legacy_ssl', None) + extensions.pop('keep_header_casing', None) def close(self): # Remove the logging handler that contains a reference to our logger @@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler): for name, handler in self.__logging_handlers.items(): logging.getLogger(name).removeHandler(handler) - def _send(self, request): - timeout = self._calculate_timeout(request) - headers = self._merge_headers(request.headers) + def _prepare_headers(self, request, headers): if 'cookie' not in headers: cookiejar = self._get_cookiejar(request) cookie_header = cookiejar.get_cookie_header(request.url) if cookie_header: headers['cookie'] = cookie_header + def _send(self, request): + timeout = self._calculate_timeout(request) + headers = self._get_headers(request) wsuri = parse_uri(request.url) create_conn_kwargs = { 'source_address': (self.source_address, 0) if self.source_address else None, diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index e8951c7e7..ddceaa9a9 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -206,6 +206,7 @@ class RequestHandler(abc.ABC): - `cookiejar`: Cookiejar to use for this request. - `timeout`: socket timeout to use for this request. - `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support. + - `keep_header_casing`: Keep the casing of headers when sending the request. To enable these, add extensions.pop('', None) to _check_extensions Apart from the url protocol, proxies dict may contain the following keys: @@ -259,6 +260,23 @@ class RequestHandler(abc.ABC): def _merge_headers(self, request_headers): return HTTPHeaderDict(self.headers, request_headers) + def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027 + """Additional operations to prepare headers before building. To be extended by subclasses. + @param request: Request object + @param headers: Merged headers to prepare + """ + + def _get_headers(self, request: Request) -> dict[str, str]: + """ + Get headers for external use. + Subclasses may define a _prepare_headers method to modify headers after merge but before building. + """ + headers = self._merge_headers(request.headers) + self._prepare_headers(request, headers) + if request.extensions.get('keep_header_casing'): + return headers.sensitive() + return dict(headers) + def _calculate_timeout(self, request): return float(request.extensions.get('timeout') or self.timeout) @@ -317,6 +335,7 @@ class RequestHandler(abc.ABC): assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType)) assert isinstance(extensions.get('timeout'), (float, int, NoneType)) assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType)) + assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType)) def _validate(self, request): self._check_url_scheme(request) diff --git a/yt_dlp/networking/impersonate.py b/yt_dlp/networking/impersonate.py index 0626b3b49..b90d10b76 100644 --- a/yt_dlp/networking/impersonate.py +++ b/yt_dlp/networking/impersonate.py @@ -5,11 +5,11 @@ from abc import ABC from dataclasses import dataclass from typing import Any -from .common import RequestHandler, register_preference +from .common import RequestHandler, register_preference, Request from .exceptions import UnsupportedRequest from ..compat.types import NoneType from ..utils import classproperty, join_nonempty -from ..utils.networking import std_headers +from ..utils.networking import std_headers, HTTPHeaderDict @dataclass(order=True, frozen=True) @@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC): """Get the requested target for the request""" return self._resolve_target(request.extensions.get('impersonate') or self.impersonate) - def _get_impersonate_headers(self, request): + def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027 + """Additional operations to prepare headers before building. To be extended by subclasses. + @param request: Request object + @param headers: Merged headers to prepare + """ + + def _get_impersonate_headers(self, request: Request) -> dict[str, str]: + """ + Get headers for external impersonation use. + Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building. + """ headers = self._merge_headers(request.headers) if self._get_request_target(request) is not None: # remove all headers present in std_headers @@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC): for k, v in std_headers.items(): if headers.get(k) == v: headers.pop(k) - return headers + + self._prepare_impersonate_headers(request, headers) + if request.extensions.get('keep_header_casing'): + return headers.sensitive() + return dict(headers) @register_preference(ImpersonateRequestHandler) diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py index 933b164be..542abace8 100644 --- a/yt_dlp/utils/networking.py +++ b/yt_dlp/utils/networking.py @@ -1,9 +1,16 @@ +from __future__ import annotations + import collections +import collections.abc import random +import typing import urllib.parse import urllib.request -from ._utils import remove_start +if typing.TYPE_CHECKING: + T = typing.TypeVar('T') + +from ._utils import NO_DEFAULT, remove_start def random_user_agent(): @@ -51,32 +58,141 @@ def random_user_agent(): return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS) -class HTTPHeaderDict(collections.UserDict, dict): +class HTTPHeaderDict(dict): """ Store and access keys case-insensitively. The constructor can take multiple dicts, in which keys in the latter are prioritised. + + Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`. """ + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: + obj = dict.__new__(cls, *args, **kwargs) + obj.__sensitive_map = {} + return obj - def __init__(self, *args, **kwargs): + def __init__(self, /, *args, **kwargs): super().__init__() - for dct in args: - if dct is not None: - self.update(dct) - self.update(kwargs) + self.__sensitive_map = {} + + for dct in filter(None, args): + self.update(dct) + if kwargs: + self.update(kwargs) + + def sensitive(self, /) -> dict[str, str]: + return { + self.__sensitive_map[key]: value + for key, value in self.items() + } + + def __contains__(self, key: str, /) -> bool: + return super().__contains__(key.title() if isinstance(key, str) else key) + + def __delitem__(self, key: str, /) -> None: + key = key.title() + del self.__sensitive_map[key] + super().__delitem__(key) - def __setitem__(self, key, value): + def __getitem__(self, key, /) -> str: + return super().__getitem__(key.title()) + + def __ior__(self, other, /): + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + self.update(other) + return + return NotImplemented + + def __or__(self, other, /) -> typing.Self: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + return type(self)(self.sensitive(), other) + return NotImplemented + + def __ror__(self, other, /) -> typing.Self: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + return type(self)(other, self.sensitive()) + return NotImplemented + + def __setitem__(self, key: str, value, /) -> None: if isinstance(value, bytes): value = value.decode('latin-1') - super().__setitem__(key.title(), str(value).strip()) + key_title = key.title() + self.__sensitive_map[key_title] = key + super().__setitem__(key_title, str(value).strip()) - def __getitem__(self, key): - return super().__getitem__(key.title()) + def clear(self, /) -> None: + self.__sensitive_map.clear() + super().clear() - def __delitem__(self, key): - super().__delitem__(key.title()) + def copy(self, /) -> typing.Self: + return type(self)(self.sensitive()) - def __contains__(self, key): - return super().__contains__(key.title() if isinstance(key, str) else key) + @typing.overload + def get(self, key: str, /) -> str | None: ... + + @typing.overload + def get(self, key: str, /, default: T) -> str | T: ... + + def get(self, key, /, default=NO_DEFAULT): + key = key.title() + if default is NO_DEFAULT: + return super().get(key) + return super().get(key, default) + + @typing.overload + def pop(self, key: str, /) -> str: ... + + @typing.overload + def pop(self, key: str, /, default: T) -> str | T: ... + + def pop(self, key, /, default=NO_DEFAULT): + key = key.title() + if default is NO_DEFAULT: + self.__sensitive_map.pop(key) + return super().pop(key) + self.__sensitive_map.pop(key, default) + return super().pop(key, default) + + def popitem(self) -> tuple[str, str]: + self.__sensitive_map.popitem() + return super().popitem() + + @typing.overload + def setdefault(self, key: str, /) -> str: ... + + @typing.overload + def setdefault(self, key: str, /, default) -> str: ... + + def setdefault(self, key, /, default=None) -> str: + key = key.title() + if key in self.__sensitive_map: + return super().__getitem__(key) + + self[key] = default or '' + return self[key] + + def update(self, other, /, **kwargs) -> None: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, collections.abc.Mapping): + for key, value in other.items(): + self[key] = value + + elif hasattr(other, 'keys'): + for key in other.keys(): # noqa: SIM118 + self[key] = other[key] + + else: + for key, value in other: + self[key] = value + + for key, value in kwargs.items(): + self[key] = value std_headers = HTTPHeaderDict({