diff --git a/test/test_socks.py b/test/test_socks.py index 95ffce275..211ee814d 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -281,17 +281,13 @@ class TestSocks4Proxy: rh, proxies={'all': f'socks4://user:@{server_address}'}) assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='socks4a implementation currently broken when destination is not a domain name')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_socks4a_ipv4_target(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: response = ctx.socks_info_request(rh, target_domain='127.0.0.1') assert response['version'] == 4 - assert response['ipv4_address'] == '127.0.0.1' - assert response['domain_address'] is None + assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1') @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_socks4a_domain_target(self, handler, ctx): @@ -302,10 +298,7 @@ class TestSocks4Proxy: assert response['ipv4_address'] is None assert response['domain_address'] == 'localhost' - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='source_address is not yet supported for socks4 proxies')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -327,10 +320,7 @@ class TestSocks4Proxy: with pytest.raises(ProxyError): ctx.socks_info_request(rh) - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='IPv6 socks4 proxies are not yet supported')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_ipv6_socks4_proxy(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks4://{server_address}'}) as rh: @@ -342,7 +332,7 @@ class TestSocks4Proxy: @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_timeout(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address: - with handler(proxies={'all': f'socks4://{server_address}'}, timeout=1) as rh: + with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh: with pytest.raises(TransportError): ctx.socks_info_request(rh) @@ -383,7 +373,7 @@ class TestSocks5Proxy: with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: response = ctx.socks_info_request(rh, target_domain='localhost') - assert response['ipv4_address'] == '127.0.0.1' + assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1') assert response['version'] == 5 @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) @@ -404,22 +394,15 @@ class TestSocks5Proxy: assert response['domain_address'] is None assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='IPv6 destination addresses are not yet supported')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_socks5_ipv6_destination(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: response = ctx.socks_info_request(rh, target_domain='[::1]') assert response['ipv6_address'] == '::1' - assert response['port'] == 80 assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='IPv6 socks5 proxies are not yet supported')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_ipv6_socks5_proxy(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -430,10 +413,7 @@ class TestSocks5Proxy: # XXX: is there any feasible way of testing IPv6 source addresses? # Same would go for non-proxy source_address test... - @pytest.mark.parametrize('handler,ctx', [ - pytest.param('Urllib', 'http', marks=pytest.mark.xfail( - reason='source_address is not yet supported for socks5 proxies')) - ], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py index a43c57bb4..4c9dbf25d 100644 --- a/yt_dlp/networking/_helper.py +++ b/yt_dlp/networking/_helper.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib import functools +import socket import ssl import sys import typing @@ -206,3 +207,59 @@ def wrap_request_errors(func): e.handler = self raise return wrapper + + +def _socket_connect(ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = ip_addr + sock = socket.socket(af, socktype, proto) + try: + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + except socket.error: + sock.close() + raise + + +def create_connection( + address, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, + *, + _create_socket_func=_socket_connect +): + # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6. + # This filters the addresses based on the given source_address. + # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810 + host, port = address + ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) + if not ip_addrs: + raise socket.error('getaddrinfo returns an empty list') + if source_address is not None: + af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6 + ip_addrs = [addr for addr in ip_addrs if addr[0] == af] + if not ip_addrs: + raise OSError( + f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. ' + f'Can\'t use "{source_address[0]}" as source address') + + err = None + for ip_addr in ip_addrs: + try: + sock = _create_socket_func(ip_addr, timeout, source_address) + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None + return sock + except socket.error as e: + err = e + + try: + raise err + finally: + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index 3c0647ecf..c327f7744 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -23,6 +23,7 @@ from urllib.request import ( from ._helper import ( InstanceStoreMixin, add_accept_encoding_header, + create_connection, get_redirect_method, make_socks_proxy_opts, select_proxy, @@ -54,44 +55,10 @@ if brotli: def _create_http_connection(http_class, source_address, *args, **kwargs): hc = http_class(*args, **kwargs) + if hasattr(hc, '_create_connection'): + hc._create_connection = create_connection + if source_address is not None: - # This is to workaround _create_connection() from socket where it will try all - # address data from getaddrinfo() including IPv6. This filters the result from - # getaddrinfo() based on the source_address value. - # This is based on the cpython socket.create_connection() function. - # https://github.com/python/cpython/blob/master/Lib/socket.py#L691 - def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): - host, port = address - err = None - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) - af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6 - ip_addrs = [addr for addr in addrs if addr[0] == af] - if addrs and not ip_addrs: - ip_version = 'v4' if af == socket.AF_INET else 'v6' - raise OSError( - "No remote IP%s addresses available for connect, can't use '%s' as source address" - % (ip_version, source_address[0])) - for res in ip_addrs: - af, socktype, proto, canonname, sa = res - sock = None - try: - sock = socket.socket(af, socktype, proto) - if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: - sock.settimeout(timeout) - sock.bind(source_address) - sock.connect(sa) - err = None # Explicitly break reference cycle - return sock - except OSError as _: - err = _ - if sock is not None: - sock.close() - if err is not None: - raise err - else: - raise OSError('getaddrinfo returns an empty list') - if hasattr(hc, '_create_connection'): - hc._create_connection = _create_connection hc.source_address = (source_address, 0) return hc @@ -220,13 +187,28 @@ def make_socks_conn_class(base_class, socks_proxy): proxy_args = make_socks_proxy_opts(socks_proxy) class SocksConnection(base_class): - def connect(self): - self.sock = sockssocket() - self.sock.setproxy(**proxy_args) - if type(self.timeout) in (int, float): # noqa: E721 - self.sock.settimeout(self.timeout) - self.sock.connect((self.host, self.port)) + _create_connection = create_connection + def connect(self): + def sock_socket_connect(ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = ip_addr + sock = sockssocket(af, socktype, proto) + try: + connect_proxy_args = proxy_args.copy() + connect_proxy_args.update({'addr': sa[0], 'port': sa[1]}) + sock.setproxy(**connect_proxy_args) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721 + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect((self.host, self.port)) + return sock + except socket.error: + sock.close() + raise + self.sock = create_connection( + (proxy_args['addr'], proxy_args['port']), timeout=self.timeout, + source_address=self.source_address, _create_socket_func=sock_socket_connect) if isinstance(self, http.client.HTTPSConnection): self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host) diff --git a/yt_dlp/socks.py b/yt_dlp/socks.py index f93328f63..e7f41d7e2 100644 --- a/yt_dlp/socks.py +++ b/yt_dlp/socks.py @@ -134,26 +134,31 @@ class sockssocket(socket.socket): self.close() raise InvalidVersionError(expected_version, got_version) - def _resolve_address(self, destaddr, default, use_remote_dns): - try: - return socket.inet_aton(destaddr) - except OSError: - if use_remote_dns and self._proxy.remote_dns: - return default - else: - return socket.inet_aton(socket.gethostbyname(destaddr)) + def _resolve_address(self, destaddr, default, use_remote_dns, family=None): + for f in (family,) if family else (socket.AF_INET, socket.AF_INET6): + try: + return f, socket.inet_pton(f, destaddr) + except OSError: + continue + + if use_remote_dns and self._proxy.remote_dns: + return 0, default + else: + res = socket.getaddrinfo(destaddr, None, family=family or 0) + f, _, _, _, ipaddr = res[0] + return f, socket.inet_pton(f, ipaddr[0]) def _setup_socks4(self, address, is_4a=False): destaddr, port = address - ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a) + _, ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a, family=socket.AF_INET) packet = struct.pack('!BBH', SOCKS4_VERSION, Socks4Command.CMD_CONNECT, port) + ipaddr username = (self._proxy.username or '').encode() packet += username + b'\x00' - if is_4a and self._proxy.remote_dns: + if is_4a and self._proxy.remote_dns and ipaddr == SOCKS4_DEFAULT_DSTIP: packet += destaddr.encode() + b'\x00' self.sendall(packet) @@ -210,7 +215,7 @@ class sockssocket(socket.socket): def _setup_socks5(self, address): destaddr, port = address - ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True) + family, ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True) self._socks5_auth() @@ -220,8 +225,10 @@ class sockssocket(socket.socket): destaddr = destaddr.encode() packet += struct.pack('!B', Socks5AddressType.ATYP_DOMAINNAME) packet += self._len_and_data(destaddr) - else: + elif family == socket.AF_INET: packet += struct.pack('!B', Socks5AddressType.ATYP_IPV4) + ipaddr + elif family == socket.AF_INET6: + packet += struct.pack('!B', Socks5AddressType.ATYP_IPV6) + ipaddr packet += struct.pack('!H', port) self.sendall(packet)