diff --git a/test/test_sabr/conftest.py b/test/test_sabr/conftest.py index f693c576e..5876450d5 100644 --- a/test/test_sabr/conftest.py +++ b/test/test_sabr/conftest.py @@ -1,3 +1,20 @@ +from unittest.mock import MagicMock import pytest pytest.importorskip('protobug', reason='protobug is not installed') + +from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName +from yt_dlp.extractor.youtube._streaming.sabr.models import SabrLogger + + +@pytest.fixture +def logger(): + mock_logger = MagicMock() + mock_logger.LogLevel = SabrLogger.LogLevel + mock_logger.log_level = SabrLogger.LogLevel.TRACE + return mock_logger + + +@pytest.fixture +def client_info(): + return ClientInfo(client_name=ClientName.WEB) diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py index 5849b07ab..d0faa1b7e 100644 --- a/test/test_sabr/test_processor.py +++ b/test/test_sabr/test_processor.py @@ -2,7 +2,6 @@ import dataclasses import io import pytest -from unittest.mock import MagicMock from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError, MediaSegmentMismatchError from yt_dlp.extractor.youtube._streaming.sabr.part import ( @@ -43,17 +42,7 @@ from yt_dlp.extractor.youtube._proto.videostreaming import ( MediaHeader, TimeRange, ) -from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy, CompressionAlgorithm - - -@pytest.fixture -def logger(): - return MagicMock() - - -@pytest.fixture -def client_info(): - return ClientInfo() +from yt_dlp.extractor.youtube._proto.innertube import NextRequestPolicy, CompressionAlgorithm @pytest.fixture diff --git a/test/test_sabr/test_stream.py b/test/test_sabr/test_stream.py new file mode 100644 index 000000000..2e492f09e --- /dev/null +++ b/test/test_sabr/test_stream.py @@ -0,0 +1,306 @@ +from __future__ import annotations +import base64 +import dataclasses +import io +import protobug +from yt_dlp.extractor.youtube._streaming.sabr.models import AudioSelector, VideoSelector +from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream +from yt_dlp.networking import Request, Response +from yt_dlp.extractor.youtube._proto.videostreaming import ( + VideoPlaybackAbrRequest, + SabrError, + FormatId, + FormatInitializationMetadata, + MediaHeader, +) +from yt_dlp.extractor.youtube._streaming.ump import UMPEncoder, UMPPart, UMPPartId, write_varint + +VIDEO_PLAYBACK_USTREAMER_CONFIG = base64.urlsafe_b64encode(b'test-config').decode('utf-8') +VIDEO_ID = 'test_video_id' + +DEFAULT_NUM_AUDIO_SEGMENTS = 5 +DEFAULT_NUM_VIDEO_SEGMENTS = 10 +DEFAULT_MEDIA_SEGMENT_DATA = b'example-media-segment' +DEFAULT_DURATION_MS = 10000 +DEFAULT_INIT_SEGMENT_DATA = b'example-init-segment' + + +@dataclasses.dataclass +class SabrRequestDetails: + request: Request + parts: list = dataclasses.field(default_factory=list) + response: Response | None = None + vpabr: VideoPlaybackAbrRequest | None = None + error: Exception | None = None + + +class SabrRequestHandler: + def __init__(self, sabr_response_processor: SabrResponseProcessor): + self.sabr_response_processor = sabr_response_processor + self.request_history = [] + + def send(self, request: Request) -> Response: + try: + vpabr, parts, response_code = self.sabr_response_processor.process_request(request.data, request.url) + except Exception as e: + self.request_history.append( + SabrRequestDetails(request=request, error=e)) + raise e + + fp = io.BytesIO() + with UMPEncoder(fp) as encoder: + for part in parts: + encoder.write_part(part) + + response = Response( + url=request.url, + status=response_code, + headers={ + 'Content-Type': 'application/vnd.yt-ump', + 'Content-Length': str(fp.tell()), + }, + fp=fp, + ) + fp.seek(0) + + self.request_history.append(SabrRequestDetails( + request=request, + response=response, + parts=parts, + vpabr=vpabr, + )) + + return response + + +class SabrResponseProcessor: + def process_request(self, data: bytes, url: str) -> tuple[VideoPlaybackAbrRequest | None, list[UMPPart], int]: + try: + vpabr = protobug.loads(data, VideoPlaybackAbrRequest) + except Exception: + error_part = protobug.dumps(SabrError(type='sabr.malformed_request')) + # TODO: confirm GVS behaviour when VideoPlaybackAbrRequest is malformed + return None, [UMPPart(data=io.BytesIO(error_part), part_id=UMPPartId.SABR_ERROR, size=len(error_part))], 200 + + return vpabr, self.get_parts(vpabr, url), 200 + + def get_parts(self, vpabr: VideoPlaybackAbrRequest, url: str) -> list[UMPPart]: + raise NotImplementedError + + def determine_formats(self, vpabr: VideoPlaybackAbrRequest) -> tuple[FormatId, FormatId]: + # Check selected_audio_format_ids and selected_video_format_ids + # TODO: caption format ids, consider initialized_format_ids, enabled_track_types_bitfield + audio_format_ids = vpabr.selected_audio_format_ids + video_format_ids = vpabr.selected_video_format_ids + + audio_format_id = audio_format_ids[0] if audio_format_ids else FormatId(itag=140, lmt=123) + video_format_id = video_format_ids[0] if video_format_ids else FormatId(itag=248, lmt=456) + return audio_format_id, video_format_id + + def get_format_initialization_metadata_parts(self, + vpabr: VideoPlaybackAbrRequest, + audio_format_id: FormatId | None = None, + video_format_id: FormatId | None = None, + total_audio_segments: int = DEFAULT_NUM_AUDIO_SEGMENTS, + total_video_segments: int = DEFAULT_NUM_VIDEO_SEGMENTS, + audio_end_time_ms: int = DEFAULT_DURATION_MS, + video_end_time_ms: int = DEFAULT_DURATION_MS, + audio_duration_ms: int = DEFAULT_DURATION_MS, + video_duration_ms: int = DEFAULT_DURATION_MS, + ) -> list[UMPPart]: + parts = [] + + audio_buffered_segments = self.buffered_segments(vpabr, total_audio_segments, audio_format_id) + video_buffered_segments = self.buffered_segments(vpabr, total_video_segments, video_format_id) + + if audio_format_id and not audio_buffered_segments: + fim = protobug.dumps(FormatInitializationMetadata( + video_id=VIDEO_ID, + format_id=audio_format_id, + mime_type='audio/mp4', + total_segments=total_audio_segments, + end_time_ms=audio_end_time_ms, + duration_ticks=audio_duration_ms, + duration_timescale=1000, + )) + parts.append(UMPPart( + part_id=UMPPartId.FORMAT_INITIALIZATION_METADATA, + size=len(fim), + data=io.BytesIO(fim), + )) + + if video_format_id and not video_buffered_segments: + fim = protobug.dumps(FormatInitializationMetadata( + video_id=VIDEO_ID, + format_id=video_format_id, + mime_type='video/mp4', + total_segments=total_video_segments, + end_time_ms=video_end_time_ms, + duration_ticks=video_duration_ms, + duration_timescale=1000, + )) + parts.append(UMPPart( + part_id=UMPPartId.FORMAT_INITIALIZATION_METADATA, + size=len(fim), + data=io.BytesIO(fim), + )) + + return parts + + def buffered_segments(self, vpabr: VideoPlaybackAbrRequest, total_segments: int, format_id: FormatId): + return { + segment_index + for buffered_range in vpabr.buffered_ranges + if buffered_range.format_id == format_id + for segment_index in range(buffered_range.start_segment_index, min(total_segments + 1, buffered_range.end_segment_index + 1)) + } + + def get_init_segment_parts(self, header_id: int, format_id: FormatId): + media_header = protobug.dumps(MediaHeader( + header_id=header_id, + format_id=format_id, + is_init_segment=True, + video_id=VIDEO_ID, + content_length=len(DEFAULT_INIT_SEGMENT_DATA), + )) + + varint_fp = io.BytesIO() + write_varint(varint_fp, header_id) + header_id_varint = varint_fp.getvalue() + + return [ + UMPPart( + part_id=UMPPartId.MEDIA_HEADER, + size=len(media_header), + data=io.BytesIO(media_header), + ), + UMPPart( + part_id=UMPPartId.MEDIA, + size=len(DEFAULT_INIT_SEGMENT_DATA) + len(header_id_varint), + data=io.BytesIO(header_id_varint + DEFAULT_INIT_SEGMENT_DATA), + ), + UMPPart( + part_id=UMPPartId.MEDIA_END, + size=len(header_id_varint), + data=io.BytesIO(header_id_varint), + ), + ] + + def get_media_segments( + self, + buffered_segments: set[int], + total_segments: int, + max_segments: int, + player_time_ms: int, + start_header_id: int, + format_id: FormatId, + ) -> tuple[list[UMPPart], int]: + + segment_parts = [] + + if not buffered_segments: + segment_parts.append(self.get_init_segment_parts(header_id=start_header_id, format_id=format_id)) + + segment_duration = (DEFAULT_DURATION_MS // total_segments) + + for sequence_number in range(1, total_segments + 1): + if sequence_number in buffered_segments: + continue + if len(segment_parts) >= max_segments: + break + start_ms = (sequence_number - 1) * segment_duration + if start_ms: + start_ms += 1 # should be + 1 from previous segment end time + + # Basic server-side buffering logic to determine if the segment should be included + if ( + (player_time_ms >= start_ms + segment_duration) + or (player_time_ms < (start_ms - segment_duration * 2)) # allow to buffer 2 segments ahead + ): + continue + + header_id = len(segment_parts) + start_header_id + media_header = protobug.dumps(MediaHeader( + header_id=header_id, + format_id=format_id, + video_id=VIDEO_ID, + content_length=len(DEFAULT_MEDIA_SEGMENT_DATA), + sequence_number=sequence_number, + duration_ms=segment_duration, + start_ms=start_ms, + )) + + varint_fp = io.BytesIO() + write_varint(varint_fp, header_id) + header_id_varint = varint_fp.getvalue() + + segment_parts.append([ + UMPPart( + part_id=UMPPartId.MEDIA_HEADER, + size=len(media_header), + data=io.BytesIO(media_header), + ), + UMPPart( + part_id=UMPPartId.MEDIA, + size=len(DEFAULT_MEDIA_SEGMENT_DATA) + len(header_id_varint), + data=io.BytesIO(header_id_varint + DEFAULT_MEDIA_SEGMENT_DATA), + ), + UMPPart( + part_id=UMPPartId.MEDIA_END, + size=len(header_id_varint), + data=io.BytesIO(header_id_varint), + ), + ]) + return [item for sublist in segment_parts for item in sublist], len(segment_parts) + start_header_id + + +class BasicAudioVideoProfile(SabrResponseProcessor): + def get_parts(self, vpabr: VideoPlaybackAbrRequest, url: str) -> list[UMPPart]: + audio_format_id, video_format_id = self.determine_formats(vpabr) + fim_parts = self.get_format_initialization_metadata_parts( + audio_format_id=audio_format_id, + video_format_id=video_format_id, + vpabr=vpabr, + ) + + audio_segment_parts, next_header_id = self.get_media_segments( + buffered_segments=self.buffered_segments(vpabr, DEFAULT_NUM_AUDIO_SEGMENTS, audio_format_id), + total_segments=DEFAULT_NUM_AUDIO_SEGMENTS, + max_segments=2, + player_time_ms=vpabr.client_abr_state.player_time_ms, + start_header_id=0, + format_id=audio_format_id, + ) + video_segment_parts, next_header_id = self.get_media_segments( + buffered_segments=self.buffered_segments(vpabr, DEFAULT_NUM_VIDEO_SEGMENTS, video_format_id), + total_segments=DEFAULT_NUM_VIDEO_SEGMENTS, + max_segments=2, + player_time_ms=vpabr.client_abr_state.player_time_ms, + start_header_id=next_header_id, + format_id=video_format_id, + ) + return [ + *fim_parts, + *audio_segment_parts, + *video_segment_parts, + ] + + +class TestStream: + def test_sabr_request_handler(self, logger, client_info): + + rh = SabrRequestHandler(sabr_response_processor=BasicAudioVideoProfile()) + + sabr_stream = SabrStream( + urlopen=rh.send, + server_abr_streaming_url='https://example.com/sabr', + logger=logger, + video_playback_ustreamer_config=VIDEO_PLAYBACK_USTREAMER_CONFIG, + client_info=client_info, + audio_selection=AudioSelector(display_name='audio'), + video_selection=VideoSelector(display_name='video'), + ) + + for part in sabr_stream.iter_parts(): + print(part) + print(logger.mock_calls) diff --git a/yt_dlp/extractor/youtube/_streaming/ump.py b/yt_dlp/extractor/youtube/_streaming/ump.py index 0b3cb9c8b..b6ad31a86 100644 --- a/yt_dlp/extractor/youtube/_streaming/ump.py +++ b/yt_dlp/extractor/youtube/_streaming/ump.py @@ -88,13 +88,13 @@ class UMPEncoder: self.fp = fp def write_part(self, part: UMPPart) -> None: - if not isinstance(part.part_id, UMPPartId): - raise ValueError('part_id must be an instance of UMPPartId') - write_varint(self.fp, part.part_id.value) write_varint(self.fp, part.size) self.fp.write(part.data.read()) + __enter__ = lambda self: self + __exit__ = lambda self, exc_type, exc_value, traceback: None + def read_varint(fp: io.BufferedIOBase) -> int: # https://web.archive.org/web/20250430054327/https://github.com/gsuberland/UMP_Format/blob/main/UMP_Format.md