mirror of https://github.com/yt-dlp/yt-dlp
Merge ba869a0901 into 404bd889d0
commit 280d2fbb8d
@@ -0,0 +1,20 @@
from unittest.mock import MagicMock

import pytest

pytest.importorskip('protobug', reason='protobug is not installed')

from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName
from yt_dlp.extractor.youtube._streaming.sabr.models import SabrLogger


@pytest.fixture
def logger():
    mock_logger = MagicMock()
    mock_logger.LogLevel = SabrLogger.LogLevel
    mock_logger.log_level = SabrLogger.LogLevel.TRACE
    return mock_logger


@pytest.fixture
def client_info():
    return ClientInfo(client_name=ClientName.WEB)
File diff suppressed because it is too large
@@ -0,0 +1,306 @@
from __future__ import annotations

import base64
import dataclasses
import io

import protobug

from yt_dlp.extractor.youtube._streaming.sabr.models import AudioSelector, VideoSelector
from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream
from yt_dlp.networking import Request, Response
from yt_dlp.extractor.youtube._proto.videostreaming import (
    VideoPlaybackAbrRequest,
    SabrError,
    FormatId,
    FormatInitializationMetadata,
    MediaHeader,
)
from yt_dlp.extractor.youtube._streaming.ump import UMPEncoder, UMPPart, UMPPartId, write_varint

VIDEO_PLAYBACK_USTREAMER_CONFIG = base64.urlsafe_b64encode(b'test-config').decode('utf-8')
VIDEO_ID = 'test_video_id'

DEFAULT_NUM_AUDIO_SEGMENTS = 5
DEFAULT_NUM_VIDEO_SEGMENTS = 10
DEFAULT_MEDIA_SEGMENT_DATA = b'example-media-segment'
DEFAULT_DURATION_MS = 10000
DEFAULT_INIT_SEGMENT_DATA = b'example-init-segment'


@dataclasses.dataclass
class SabrRequestDetails:
    request: Request
    parts: list = dataclasses.field(default_factory=list)
    response: Response | None = None
    vpabr: VideoPlaybackAbrRequest | None = None
    error: Exception | None = None

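# SabrRequestHandler stands in for a real urlopen: it feeds each request to a
# SabrResponseProcessor, encodes the resulting parts as a UMP response, and
# records every exchange in request_history for later inspection.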
class SabrRequestHandler:
    def __init__(self, sabr_response_processor: SabrResponseProcessor):
        self.sabr_response_processor = sabr_response_processor
        self.request_history = []

    def send(self, request: Request) -> Response:
        try:
            vpabr, parts, response_code = self.sabr_response_processor.process_request(request.data, request.url)
        except Exception as e:
            self.request_history.append(
                SabrRequestDetails(request=request, error=e))
            raise

        fp = io.BytesIO()
        with UMPEncoder(fp) as encoder:
            for part in parts:
                encoder.write_part(part)

        response = Response(
            url=request.url,
            status=response_code,
            headers={
                'Content-Type': 'application/vnd.yt-ump',
                'Content-Length': str(fp.tell()),
            },
            fp=fp,
        )
        fp.seek(0)

        self.request_history.append(SabrRequestDetails(
            request=request,
            response=response,
            parts=parts,
            vpabr=vpabr,
        ))

        return response

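# A SabrResponseProcessor models the server side of a SABR session: subclasses
# implement get_parts() to decide which UMP parts to return for a decoded
# VideoPlaybackAbrRequest.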
class SabrResponseProcessor:
    def process_request(self, data: bytes, url: str) -> tuple[VideoPlaybackAbrRequest | None, list[UMPPart], int]:
        try:
            vpabr = protobug.loads(data, VideoPlaybackAbrRequest)
        except Exception:
            error_part = protobug.dumps(SabrError(type='sabr.malformed_request'))
            # TODO: confirm GVS behaviour when VideoPlaybackAbrRequest is malformed
            return None, [UMPPart(data=io.BytesIO(error_part), part_id=UMPPartId.SABR_ERROR, size=len(error_part))], 200

        return vpabr, self.get_parts(vpabr, url), 200

    def get_parts(self, vpabr: VideoPlaybackAbrRequest, url: str) -> list[UMPPart]:
        raise NotImplementedError

    def determine_formats(self, vpabr: VideoPlaybackAbrRequest) -> tuple[FormatId, FormatId]:
        # Check selected_audio_format_ids and selected_video_format_ids
        # TODO: caption format ids, consider initialized_format_ids, enabled_track_types_bitfield
        audio_format_ids = vpabr.selected_audio_format_ids
        video_format_ids = vpabr.selected_video_format_ids

        audio_format_id = audio_format_ids[0] if audio_format_ids else FormatId(itag=140, lmt=123)
        video_format_id = video_format_ids[0] if video_format_ids else FormatId(itag=248, lmt=456)
        return audio_format_id, video_format_id

    def get_format_initialization_metadata_parts(
            self,
            vpabr: VideoPlaybackAbrRequest,
            audio_format_id: FormatId | None = None,
            video_format_id: FormatId | None = None,
            total_audio_segments: int = DEFAULT_NUM_AUDIO_SEGMENTS,
            total_video_segments: int = DEFAULT_NUM_VIDEO_SEGMENTS,
            audio_end_time_ms: int = DEFAULT_DURATION_MS,
            video_end_time_ms: int = DEFAULT_DURATION_MS,
            audio_duration_ms: int = DEFAULT_DURATION_MS,
            video_duration_ms: int = DEFAULT_DURATION_MS,
    ) -> list[UMPPart]:
        parts = []

        audio_buffered_segments = self.buffered_segments(vpabr, total_audio_segments, audio_format_id)
        video_buffered_segments = self.buffered_segments(vpabr, total_video_segments, video_format_id)

        if audio_format_id and not audio_buffered_segments:
            fim = protobug.dumps(FormatInitializationMetadata(
                video_id=VIDEO_ID,
                format_id=audio_format_id,
                mime_type='audio/mp4',
                total_segments=total_audio_segments,
                end_time_ms=audio_end_time_ms,
                duration_ticks=audio_duration_ms,
                duration_timescale=1000,
            ))
            parts.append(UMPPart(
                part_id=UMPPartId.FORMAT_INITIALIZATION_METADATA,
                size=len(fim),
                data=io.BytesIO(fim),
            ))

        if video_format_id and not video_buffered_segments:
            fim = protobug.dumps(FormatInitializationMetadata(
                video_id=VIDEO_ID,
                format_id=video_format_id,
                mime_type='video/mp4',
                total_segments=total_video_segments,
                end_time_ms=video_end_time_ms,
                duration_ticks=video_duration_ms,
                duration_timescale=1000,
            ))
            parts.append(UMPPart(
                part_id=UMPPartId.FORMAT_INITIALIZATION_METADATA,
                size=len(fim),
                data=io.BytesIO(fim),
            ))

        return parts

    def buffered_segments(self, vpabr: VideoPlaybackAbrRequest, total_segments: int, format_id: FormatId):
        return {
            segment_index
            for buffered_range in vpabr.buffered_ranges
            if buffered_range.format_id == format_id
            for segment_index in range(buffered_range.start_segment_index, min(total_segments + 1, buffered_range.end_segment_index + 1))
        }

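    # A segment is delivered as three UMP parts: a MEDIA_HEADER carrying the
    # protobuf metadata, a MEDIA part whose payload is the header id varint
    # followed by the raw bytes, and a MEDIA_END part containing just the
    # header id varint.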
    def get_init_segment_parts(self, header_id: int, format_id: FormatId):
        media_header = protobug.dumps(MediaHeader(
            header_id=header_id,
            format_id=format_id,
            is_init_segment=True,
            video_id=VIDEO_ID,
            content_length=len(DEFAULT_INIT_SEGMENT_DATA),
        ))

        varint_fp = io.BytesIO()
        write_varint(varint_fp, header_id)
        header_id_varint = varint_fp.getvalue()

        return [
            UMPPart(
                part_id=UMPPartId.MEDIA_HEADER,
                size=len(media_header),
                data=io.BytesIO(media_header),
            ),
            UMPPart(
                part_id=UMPPartId.MEDIA,
                size=len(DEFAULT_INIT_SEGMENT_DATA) + len(header_id_varint),
                data=io.BytesIO(header_id_varint + DEFAULT_INIT_SEGMENT_DATA),
            ),
            UMPPart(
                part_id=UMPPartId.MEDIA_END,
                size=len(header_id_varint),
                data=io.BytesIO(header_id_varint),
            ),
        ]

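    # Emits up to max_segments unbuffered segments around the client's reported
    # player time, mimicking how the server only sends media within a small
    # window ahead of playback.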
    def get_media_segments(
            self,
            buffered_segments: set[int],
            total_segments: int,
            max_segments: int,
            player_time_ms: int,
            start_header_id: int,
            format_id: FormatId,
    ) -> tuple[list[UMPPart], int]:

        segment_parts = []

        if not buffered_segments:
            segment_parts.append(self.get_init_segment_parts(header_id=start_header_id, format_id=format_id))

        segment_duration = (DEFAULT_DURATION_MS // total_segments)

        for sequence_number in range(1, total_segments + 1):
            if sequence_number in buffered_segments:
                continue
            if len(segment_parts) >= max_segments:
                break
            start_ms = (sequence_number - 1) * segment_duration
            if start_ms:
                start_ms += 1  # should be + 1 from previous segment end time

            # Basic server-side buffering logic to determine if the segment should be included
            if (
                (player_time_ms >= start_ms + segment_duration)
                or (player_time_ms < (start_ms - segment_duration * 2))  # allow to buffer 2 segments ahead
            ):
                continue

            header_id = len(segment_parts) + start_header_id
            media_header = protobug.dumps(MediaHeader(
                header_id=header_id,
                format_id=format_id,
                video_id=VIDEO_ID,
                content_length=len(DEFAULT_MEDIA_SEGMENT_DATA),
                sequence_number=sequence_number,
                duration_ms=segment_duration,
                start_ms=start_ms,
            ))

            varint_fp = io.BytesIO()
            write_varint(varint_fp, header_id)
            header_id_varint = varint_fp.getvalue()

            segment_parts.append([
                UMPPart(
                    part_id=UMPPartId.MEDIA_HEADER,
                    size=len(media_header),
                    data=io.BytesIO(media_header),
                ),
                UMPPart(
                    part_id=UMPPartId.MEDIA,
                    size=len(DEFAULT_MEDIA_SEGMENT_DATA) + len(header_id_varint),
                    data=io.BytesIO(header_id_varint + DEFAULT_MEDIA_SEGMENT_DATA),
                ),
                UMPPart(
                    part_id=UMPPartId.MEDIA_END,
                    size=len(header_id_varint),
                    data=io.BytesIO(header_id_varint),
                ),
            ])
        return [item for sublist in segment_parts for item in sublist], len(segment_parts) + start_header_id


class BasicAudioVideoProfile(SabrResponseProcessor):
    def get_parts(self, vpabr: VideoPlaybackAbrRequest, url: str) -> list[UMPPart]:
        audio_format_id, video_format_id = self.determine_formats(vpabr)
        fim_parts = self.get_format_initialization_metadata_parts(
            audio_format_id=audio_format_id,
            video_format_id=video_format_id,
            vpabr=vpabr,
        )

        audio_segment_parts, next_header_id = self.get_media_segments(
            buffered_segments=self.buffered_segments(vpabr, DEFAULT_NUM_AUDIO_SEGMENTS, audio_format_id),
            total_segments=DEFAULT_NUM_AUDIO_SEGMENTS,
            max_segments=2,
            player_time_ms=vpabr.client_abr_state.player_time_ms,
            start_header_id=0,
            format_id=audio_format_id,
        )
        video_segment_parts, next_header_id = self.get_media_segments(
            buffered_segments=self.buffered_segments(vpabr, DEFAULT_NUM_VIDEO_SEGMENTS, video_format_id),
            total_segments=DEFAULT_NUM_VIDEO_SEGMENTS,
            max_segments=2,
            player_time_ms=vpabr.client_abr_state.player_time_ms,
            start_header_id=next_header_id,
            format_id=video_format_id,
        )
        return [
            *fim_parts,
            *audio_segment_parts,
            *video_segment_parts,
        ]


class TestStream:
    def test_sabr_request_handler(self, logger, client_info):
        rh = SabrRequestHandler(sabr_response_processor=BasicAudioVideoProfile())

        sabr_stream = SabrStream(
            urlopen=rh.send,
            server_abr_streaming_url='https://example.com/sabr',
            logger=logger,
            video_playback_ustreamer_config=VIDEO_PLAYBACK_USTREAMER_CONFIG,
            client_info=client_info,
            audio_selection=AudioSelector(display_name='audio'),
            video_selection=VideoSelector(display_name='video'),
        )

        for part in sabr_stream.iter_parts():
            print(part)
        print(logger.mock_calls)
@@ -0,0 +1,132 @@
import io

import pytest

from yt_dlp.extractor.youtube._streaming.ump import (
    varint_size,
    read_varint,
    UMPDecoder,
    UMPPartId,
    write_varint,
    UMPEncoder,
    UMPPart,
)

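# UMP varints encode their total size in the first byte's high bits:
# 0xxxxxxx = 1 byte, 10xxxxxx = 2, 110xxxxx = 3, 1110xxxx = 4 and
# 1111xxxx = 5, with the remaining bits plus the following bytes holding
# the value little-endian (e.g. b'\xad\x05' -> 45 + (5 << 6) = 365).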
@pytest.mark.parametrize('data, expected', [
    (0x01, 1),
    (0x4F, 1),
    (0x80, 2),
    (0xBF, 2),
    (0xC0, 3),
    (0xDF, 3),
    (0xE0, 4),
    (0xEF, 4),
    (0xF0, 5),
    (0xFF, 5),
])
def test_varint_size(data, expected):
    assert varint_size(data) == expected


@pytest.mark.parametrize('data, expected', [
    (b'\x01', 1),
    (b'\xad\x05', 365),
    (b'\xd5\x22\x05', 42069),
    (b'\xe0\x68\x89\x09', 10000000),
    (b'\xf0\xff\xc9\x9a\x3b', 999999999),
    (b'\xf0\xff\xff\xff\xff', 4294967295),
])
def test_readvarint(data, expected):
    assert read_varint(io.BytesIO(data)) == expected


@pytest.mark.parametrize('value, expected_bytes', [
    (1, b'\x01'),
    (365, b'\xad\x05'),
    (42069, b'\xd5\x22\x05'),
    (10000000, b'\xe0\x68\x89\x09'),
    (999999999, b'\xf0\xff\xc9\x9a\x3b'),
    (4294967295, b'\xf0\xff\xff\xff\xff'),
])
def test_writevarint(value, expected_bytes):
    fp = io.BytesIO()
    write_varint(fp, value)
    assert fp.getvalue() == expected_bytes


class TestUMPDecoder:
    EXAMPLE_PART_DATA = [
        {
            # Part 1: Part type of 20, part size of 127
            'part_type_bytes': b'\x14',
            'part_size_bytes': b'\x7F',
            'part_data_bytes': b'\x01' * 127,
            'part_id': UMPPartId.MEDIA_HEADER,
            'part_size': 127,
        },
        # Part 2: Part type of 4294967295, part size of 0
        {
            'part_type_bytes': b'\xFF\xFF\xFF\xFF\xFF',
            'part_size_bytes': b'\x00',
            'part_data_bytes': b'',
            'part_id': UMPPartId.UNKNOWN,
            'part_size': 0,
        },
        # Part 3: Part type of 21, part size of 1574912
        {
            'part_type_bytes': b'\x15',
            'part_size_bytes': b'\xE0\x80\x80\x01',
            'part_data_bytes': b'\x01' * 1574912,
            'part_id': UMPPartId.MEDIA,
            'part_size': 1574912,
        },
    ]

    COMBINED_PART_DATA = b''.join(
        part['part_type_bytes'] + part['part_size_bytes'] + part['part_data_bytes']
        for part in EXAMPLE_PART_DATA)

    def test_iter_parts(self):
        # Create a mock file-like object
        mock_file = io.BytesIO(self.COMBINED_PART_DATA)

        # Create an instance of UMPDecoder with the mock file
        decoder = UMPDecoder(mock_file)

        # Iterate over the parts and check the values
        for idx, part in enumerate(decoder.iter_parts()):
            assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id']
            assert part.size == self.EXAMPLE_PART_DATA[idx]['part_size']
            assert part.data.read() == self.EXAMPLE_PART_DATA[idx]['part_data_bytes']

        assert mock_file.closed

    def test_unexpected_eof(self):
        # Unexpected trailing byte at the end of the file
        mock_file = io.BytesIO(self.COMBINED_PART_DATA + b'\x00')
        decoder = UMPDecoder(mock_file)

        # Iterate over the parts and check the values
        with pytest.raises(EOFError, match='Unexpected EOF while reading part size'):
            for idx, part in enumerate(decoder.iter_parts()):
                assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id']
                part.data.read()

        assert mock_file.closed


class TestUMPEncoder:
    def test_write_part(self):
        fp = io.BytesIO()
        encoder = UMPEncoder(fp)
        part = UMPPart(
            part_id=UMPPartId.MEDIA_HEADER,
            size=127,
            data=io.BytesIO(b'\x01' * 127),
        )

        encoder.write_part(part)

        part_type = b'\x14'  # MEDIA_HEADER part type
        part_size = b'\x7F'  # Part size of 127
        expected_data = part_type + part_size + b'\x01' * 127
        assert fp.getvalue() == expected_data
@@ -0,0 +1,23 @@
import pytest

from yt_dlp.extractor.youtube._streaming.sabr.utils import ticks_to_ms, broadcast_id_from_url

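# ticks_to_ms converts a tick count at a given timescale (ticks per second)
# into milliseconds, returning None when either input is missing.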
@pytest.mark.parametrize(
    'ticks, timescale, expected_ms',
    [
        (1000, 1000, 1000),
        (5000, 10000, 500),
        (234234, 44100, 5312),
        (1, 1, 1000),
        (None, 1000, None),
        (1000, None, None),
        (None, None, None),
    ],
)
def test_ticks_to_ms(ticks, timescale, expected_ms):
    assert ticks_to_ms(ticks, timescale) == expected_ms


def test_broadcast_id_from_url():
    assert broadcast_id_from_url('https://example.com/path?other=param&id=example.1~243&other2=param2') == 'example.1~243'
    assert broadcast_id_from_url('https://example.com/path?other=param&other2=param2') is None
@@ -0,0 +1,24 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug
from yt_dlp.utils import DownloadError
from yt_dlp.downloader import FileDownloader

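# When protobug is unavailable, expose a stub SabrFD whose can_download still
# recognises SABR-only format selections, but aborts with a clear error
# instead of attempting the download.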
if not protobug:
    class SabrFD(FileDownloader):

        @classmethod
        def can_download(cls, info_dict):
            is_sabr = (
                info_dict.get('requested_formats')
                and all(
                    format_info.get('protocol') == 'sabr'
                    for format_info in info_dict['requested_formats']))

            if is_sabr:
                raise DownloadError('SABRFD requires protobug to be installed')

            return is_sabr

else:
    from ._fd import SabrFD  # noqa: F401
@@ -0,0 +1,335 @@
from __future__ import annotations
import collections
import itertools

from yt_dlp.networking.exceptions import TransportError, HTTPError
from yt_dlp.utils import traverse_obj, int_or_none, DownloadError, join_nonempty
from yt_dlp.downloader import FileDownloader

from ._writer import SabrFDFormatWriter
from ._logger import create_sabrfd_logger

from yt_dlp.extractor.youtube._streaming.sabr.part import (
    MediaSegmentEndSabrPart,
    MediaSegmentDataSabrPart,
    MediaSegmentInitSabrPart,
    PoTokenStatusSabrPart,
    RefreshPlayerResponseSabrPart,
    MediaSeekSabrPart,
    FormatInitializedSabrPart,
)
from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream
from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange, AudioSelector, VideoSelector, CaptionSelector
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId


class SabrFD(FileDownloader):

    @classmethod
    def can_download(cls, info_dict):
        return (
            info_dict.get('requested_formats')
            and all(
                format_info.get('protocol') == 'sabr'
                for format_info in info_dict['requested_formats']))

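    # Formats served from the same innertube client share one SABR stream, so
    # group them by client and sanity-check that the streaming URL and
    # ustreamer config match within each group.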
    def _group_formats_by_client(self, filename, info_dict):
        format_groups = collections.defaultdict(dict)
        requested_formats = info_dict.get('requested_formats') or [info_dict]

        for f in requested_formats:
            sabr_config = f.get('_sabr_config')
            client_name = sabr_config.get('client_name')
            client_info = sabr_config.get('client_info')
            server_abr_streaming_url = f.get('url')
            video_playback_ustreamer_config = sabr_config.get('video_playback_ustreamer_config')

            if not video_playback_ustreamer_config:
                raise DownloadError('Video playback ustreamer config not found')

            sabr_format_group_config = format_groups.get(client_name)

            if not sabr_format_group_config:
                sabr_format_group_config = format_groups[client_name] = {
                    'server_abr_streaming_url': server_abr_streaming_url,
                    'video_playback_ustreamer_config': video_playback_ustreamer_config,
                    'formats': [],
                    'initial_po_token': sabr_config.get('po_token'),
                    'fetch_po_token_fn': fn if callable(fn := sabr_config.get('fetch_po_token_fn')) else None,
                    'reload_config_fn': fn if callable(fn := sabr_config.get('reload_config_fn')) else None,
                    'live_status': sabr_config.get('live_status'),
                    'video_id': sabr_config.get('video_id'),
                    'client_info': ClientInfo(
                        client_name=traverse_obj(client_info, ('clientName', {lambda x: ClientName[x]})),
                        client_version=traverse_obj(client_info, 'clientVersion'),
                        os_version=traverse_obj(client_info, 'osVersion'),
                        os_name=traverse_obj(client_info, 'osName'),
                        device_model=traverse_obj(client_info, 'deviceModel'),
                        device_make=traverse_obj(client_info, 'deviceMake'),
                    ),
                    'target_duration_sec': sabr_config.get('target_duration_sec'),
                    # Number.MAX_SAFE_INTEGER
                    'start_time_ms': ((2**53) - 1) if info_dict.get('live_status') == 'is_live' and not f.get('is_from_start') else 0,
                }

            else:
                if sabr_format_group_config['server_abr_streaming_url'] != server_abr_streaming_url:
                    raise DownloadError('Server ABR streaming URL mismatch')

                if sabr_format_group_config['video_playback_ustreamer_config'] != video_playback_ustreamer_config:
                    raise DownloadError('Video playback ustreamer config mismatch')

            itag = int_or_none(sabr_config.get('itag'))
            sabr_format_group_config['formats'].append({
                'display_name': f.get('format_id'),
                'format_id': itag and FormatId(
                    itag=itag, lmt=int_or_none(sabr_config.get('last_modified')), xtags=sabr_config.get('xtags')),
                'format_type': format_type(f),
                'quality': sabr_config.get('quality'),
                'height': sabr_config.get('height'),
                'filename': f.get('filepath', filename),
                'info_dict': f,
            })

        return format_groups

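    # Each pass downloads at most one audio, one video and one caption format
    # together; zip_longest pads the shorter lists with None so leftover
    # formats still get their own pass.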
    def real_download(self, filename, info_dict):
        format_groups = self._group_formats_by_client(filename, info_dict)

        is_test = self.params.get('test', False)
        resume = self.params.get('continuedl', True)

        for client_name, format_group in format_groups.items():
            formats = format_group['formats']
            audio_formats = (f for f in formats if f['format_type'] == 'audio')
            video_formats = (f for f in formats if f['format_type'] == 'video')
            caption_formats = (f for f in formats if f['format_type'] == 'caption')
            for audio_format, video_format, caption_format in itertools.zip_longest(audio_formats, video_formats, caption_formats):
                format_str = join_nonempty(*[
                    traverse_obj(audio_format, 'display_name'),
                    traverse_obj(video_format, 'display_name'),
                    traverse_obj(caption_format, 'display_name')], delim='+')
                self.write_debug(f'Downloading formats: {format_str} ({client_name} client)')
                self._download_sabr_stream(
                    info_dict=info_dict,
                    video_format=video_format,
                    audio_format=audio_format,
                    caption_format=caption_format,
                    resume=resume,
                    is_test=is_test,
                    server_abr_streaming_url=format_group['server_abr_streaming_url'],
                    video_playback_ustreamer_config=format_group['video_playback_ustreamer_config'],
                    initial_po_token=format_group['initial_po_token'],
                    fetch_po_token_fn=format_group['fetch_po_token_fn'],
                    reload_config_fn=format_group['reload_config_fn'],
                    client_info=format_group['client_info'],
                    start_time_ms=format_group['start_time_ms'],
                    target_duration_sec=format_group.get('target_duration_sec'),
                    live_status=format_group.get('live_status'),
                    video_id=format_group.get('video_id'),
                )
        return True

    def _download_sabr_stream(
            self,
            video_id: str,
            info_dict: dict,
            video_format: dict,
            audio_format: dict,
            caption_format: dict,
            resume: bool,
            is_test: bool,
            server_abr_streaming_url: str,
            video_playback_ustreamer_config: str,
            initial_po_token: str,
            fetch_po_token_fn: callable | None = None,
            reload_config_fn: callable | None = None,
            client_info: ClientInfo | None = None,
            start_time_ms: int = 0,
            target_duration_sec: int | None = None,
            live_status: str | None = None,
    ):

        writers = {}
        audio_selector = None
        video_selector = None
        caption_selector = None

        if audio_format:
            audio_selector = AudioSelector(
                display_name=audio_format['display_name'], format_ids=[audio_format['format_id']])
            writers[audio_selector.display_name] = SabrFDFormatWriter(
                self, audio_format.get('filename'),
                audio_format['info_dict'], len(writers), resume=resume)

        if video_format:
            video_selector = VideoSelector(
                display_name=video_format['display_name'], format_ids=[video_format['format_id']])
            writers[video_selector.display_name] = SabrFDFormatWriter(
                self, video_format.get('filename'),
                video_format['info_dict'], len(writers), resume=resume)

        if caption_format:
            caption_selector = CaptionSelector(
                display_name=caption_format['display_name'], format_ids=[caption_format['format_id']])
            writers[caption_selector.display_name] = SabrFDFormatWriter(
                self, caption_format.get('filename'),
                caption_format['info_dict'], len(writers), resume=resume)

        # Report the destination files before we start downloading instead of when we initialize the writers,
        # as the formats may not all start at the same time (leading to messy output)
        for writer in writers.values():
            self.report_destination(writer.filename)

        stream = SabrStream(
            urlopen=self.ydl.urlopen,
            logger=create_sabrfd_logger(self.ydl, prefix='sabr:stream'),
            server_abr_streaming_url=server_abr_streaming_url,
            video_playback_ustreamer_config=video_playback_ustreamer_config,
            po_token=initial_po_token,
            video_selection=video_selector,
            audio_selection=audio_selector,
            caption_selection=caption_selector,
            start_time_ms=start_time_ms,
            client_info=client_info,
            live_segment_target_duration_sec=target_duration_sec,
            post_live=live_status == 'post_live',
            video_id=video_id,
            retry_sleep_func=self.params.get('retry_sleep_functions', {}).get('http'),
        )

        self._prepare_multiline_status(len(writers) + 1)

        try:
            total_bytes = 0
            for part in stream:
                if is_test and total_bytes >= self._TEST_FILE_SIZE:
                    stream.close()
                    break
                if isinstance(part, PoTokenStatusSabrPart):
                    if not fetch_po_token_fn:
                        self.report_warning(
                            'No fetch PO token function found - this can happen if you use --load-info-json.'
                            ' The download will fail if a valid PO token is required.', only_once=True)
                        continue
                    if part.status in (
                        part.PoTokenStatus.INVALID,
                        part.PoTokenStatus.PENDING,
                    ):
                        # Fetch a PO token with bypass_cache=True
                        # (ensure we create a new one)
                        po_token = fetch_po_token_fn(bypass_cache=True)
                        if po_token:
                            stream.processor.po_token = po_token
                    elif part.status in (
                        part.PoTokenStatus.MISSING,
                        part.PoTokenStatus.PENDING_MISSING,
                    ):
                        # Fetch a PO Token, bypass_cache=False
                        po_token = fetch_po_token_fn()
                        if po_token:
                            stream.processor.po_token = po_token

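                # A format announced by the server: rebuild local writer state
                # (init segment + consumed ranges) so media already on disk is
                # not requested again.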
                elif isinstance(part, FormatInitializedSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown format selector: {part.format_selector}')
                        continue

                    writer.initialize_format(part.format_id)
                    initialized_format = stream.processor.initialized_formats[str(part.format_id)]
                    if writer.state.init_sequence:
                        initialized_format.init_segment = True
                        initialized_format.current_segment = None  # allow a seek

                    # Build consumed ranges from the sequences
                    consumed_ranges = []
                    for sequence in writer.state.sequences:
                        consumed_ranges.append(ConsumedRange(
                            start_time_ms=sequence.first_segment.start_time_ms,
                            duration_ms=(sequence.last_segment.start_time_ms + sequence.last_segment.duration_ms) - sequence.first_segment.start_time_ms,
                            start_sequence_number=sequence.first_segment.sequence_number,
                            end_sequence_number=sequence.last_segment.sequence_number,
                        ))
                    if consumed_ranges:
                        initialized_format.consumed_ranges = consumed_ranges
                        initialized_format.current_segment = None  # allow a seek
                        self.to_screen(f'[download] Resuming download for format {part.format_selector.display_name}')

                elif isinstance(part, MediaSegmentInitSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown init format selector: {part.format_selector}')
                        continue
                    writer.initialize_segment(part)

                elif isinstance(part, MediaSegmentDataSabrPart):
                    total_bytes += len(part.data)  # TODO: not reliable
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown data format selector: {part.format_selector}')
                        continue
                    writer.write_segment_data(part)

                elif isinstance(part, MediaSegmentEndSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown end format selector: {part.format_selector}')
                        continue
                    writer.end_segment(part)

                elif isinstance(part, RefreshPlayerResponseSabrPart):
                    self.to_screen(f'Refreshing player response; Reason: {part.reason}')
                    # In-place refresh - not ideal but should work in most cases
                    # TODO: handle case where live stream changes to non-livestream on refresh?
                    # TODO: if live, allow a seek as for non-DVR streams the reload may be longer than the buffer duration
                    # TODO: handle po token function change
                    if not reload_config_fn:
                        self.report_warning(
                            'No reload config function found - cannot refresh SABR streaming URL.'
                            ' The url will expire soon and the download will fail.')
                        continue
                    try:
                        stream.url, stream.processor.video_playback_ustreamer_config = reload_config_fn(part.reload_playback_token)
                    except (TransportError, HTTPError) as e:
                        self.report_warning(f'Failed to refresh SABR streaming URL: {e}')

                elif isinstance(part, MediaSeekSabrPart):
                    if (
                        not info_dict.get('is_live')
                        and live_status not in ('post_live', 'is_live')
                        and not stream.processor.is_live
                        and part.reason == MediaSeekSabrPart.Reason.SERVER_SEEK
                    ):
                        raise DownloadError('Server tried to seek a non-live video')
                else:
                    self.to_screen(f'Unhandled part type: {part.__class__.__name__}')

            for writer in writers.values():
                writer.finish()
        except SabrStreamError as e:
            raise DownloadError(str(e)) from e
        except KeyboardInterrupt:
            if (
                not info_dict.get('is_live')
                and live_status != 'is_live'
                and not stream.processor.is_live
            ):
                raise
            self.to_screen('Interrupted by user')
            for writer in writers.values():
                writer.finish()
        finally:
            # TODO: for livestreams, since we cannot resume them, should we finish the writers?
            for writer in writers.values():
                writer.close()


def format_type(f):
    if f.get('acodec') == 'none':
        return 'video'
    elif f.get('vcodec') == 'none':
        return 'audio'
    elif f.get('vcodec') is None and f.get('acodec') is None:
        return 'caption'
    return None
@@ -0,0 +1,196 @@
from __future__ import annotations

import dataclasses
from yt_dlp.utils import DownloadError
from ._io import DiskFormatIOBackend, MemoryFormatIOBackend


@dataclasses.dataclass
class Segment:
    segment_id: str
    content_length: int | None = None
    content_length_estimated: bool = False
    sequence_number: int | None = None
    start_time_ms: int | None = None
    duration_ms: int | None = None
    duration_estimated: bool = False
    is_init_segment: bool = False


@dataclasses.dataclass
class Sequence:
    sequence_id: str
    # The segments may not have a start byte range, so to keep it simple we will track the
    # length of the sequence. We can infer from this and the segment's content_length where they should end and begin.
    sequence_content_length: int = 0
    first_segment: Segment | None = None
    last_segment: Segment | None = None

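# A SequenceFile accumulates a contiguous run of segments into a single
# .sq<N>.sabr.part file on disk; individual segments are staged in their own
# SegmentFile and appended once complete.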
class SequenceFile:

    def __init__(self, fd, format_filename, sequence: Sequence, resume=False):
        self.fd = fd
        self.format_filename = format_filename
        self.sequence = sequence
        self.file = DiskFormatIOBackend(
            fd=self.fd,
            filename=self.format_filename + f'.sq{self.sequence_id}.sabr.part',
        )
        self.current_segment: SegmentFile | None = None
        self.resume = resume

        sequence_file_exists = self.file.exists()

        if not resume and sequence_file_exists:
            self.file.remove()

        elif not self.sequence.last_segment and sequence_file_exists:
            self.file.remove()

        if self.sequence.last_segment and not sequence_file_exists:
            raise DownloadError(f'Cannot find existing sequence {self.sequence_id} file')

        if self.sequence.last_segment and not self.file.validate_length(self.sequence.sequence_content_length):
            self.file.remove()
            raise DownloadError(f'Existing sequence {self.sequence_id} file is not valid; removing')

    @property
    def sequence_id(self):
        return self.sequence.sequence_id

    @property
    def current_length(self):
        total = self.sequence.sequence_content_length
        if self.current_segment:
            total += self.current_segment.current_length
        return total

    def is_next_segment(self, segment: Segment):
        if self.current_segment:
            return False
        latest_segment = self.sequence.last_segment or self.sequence.first_segment
        if not latest_segment:
            return True
        if segment.is_init_segment and latest_segment.is_init_segment:
            # Only one segment allowed for init segments
            return False
        return segment.sequence_number == latest_segment.sequence_number + 1

    def is_current_segment(self, segment_id: str):
        if not self.current_segment:
            return False
        return self.current_segment.segment_id == segment_id

    def initialize_segment(self, segment: Segment):
        if self.current_segment and not self.is_current_segment(segment.segment_id):
            raise ValueError('Cannot reinitialize a segment that does not match the current segment')

        if not self.current_segment and not self.is_next_segment(segment):
            raise ValueError('Cannot initialize a segment that does not match the next segment')

        self.current_segment = SegmentFile(
            fd=self.fd,
            format_filename=self.format_filename,
            segment=segment,
        )

    def write_segment_data(self, data, segment_id: str):
        if not self.is_current_segment(segment_id):
            raise ValueError('Cannot write to a segment that does not match the current segment')

        self.current_segment.write(data)

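    # Finalising a segment: verify the received byte count against the declared
    # content length, fold the segment into the sequence bookkeeping, append its
    # bytes to the sequence file, then delete the staging segment file.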
    def end_segment(self, segment_id):
        if not self.is_current_segment(segment_id):
            raise ValueError('Cannot end a segment that does not exist')

        self.current_segment.finish_write()

        if (
            self.current_segment.segment.content_length
            and not self.current_segment.segment.content_length_estimated
            and self.current_segment.current_length != self.current_segment.segment.content_length
        ):
            raise DownloadError(
                f'Filesize mismatch for segment {self.current_segment.segment_id}: '
                f'Expected {self.current_segment.segment.content_length} bytes, got {self.current_segment.current_length} bytes')

        self.current_segment.segment.content_length = self.current_segment.current_length
        self.current_segment.segment.content_length_estimated = False

        if not self.sequence.first_segment:
            self.sequence.first_segment = self.current_segment.segment

        self.sequence.last_segment = self.current_segment.segment
        self.sequence.sequence_content_length += self.current_segment.current_length

        if not self.file.mode:
            self.file.initialize_writer(self.resume)

        self.current_segment.read_into(self.file)
        self.current_segment.remove()
        self.current_segment = None

    def read_into(self, backend):
        self.file.initialize_reader()
        self.file.read_into(backend)
        self.file.close()

    def remove(self):
        self.close()
        self.file.remove()

    def close(self):
        self.file.close()

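# A SegmentFile stages a single in-flight segment, keeping small segments
# (those with a known content length under the limit) in memory and spilling
# larger or unsized ones to a .sg<N>.sabr.part file on disk.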
class SegmentFile:

    def __init__(self, fd, format_filename, segment: Segment, memory_file_limit=2 * 1024 * 1024):
        self.fd = fd
        self.format_filename = format_filename
        self.segment: Segment = segment
        self.current_length = 0

        filename = format_filename + f'.sg{segment.sequence_number}.sabr.part'
        # Store the segment in memory if it is small enough
        if segment.content_length and segment.content_length <= memory_file_limit:
            self.file = MemoryFormatIOBackend(
                fd=self.fd,
                filename=filename,
            )
        else:
            self.file = DiskFormatIOBackend(
                fd=self.fd,
                filename=filename,
            )

        # Never resume a segment
        if self.file.exists():
            self.file.remove()

    @property
    def segment_id(self):
        return self.segment.segment_id

    def write(self, data):
        if not self.file.mode:
            self.file.initialize_writer(resume=False)
        self.current_length += self.file.write(data)

    def read_into(self, file):
        self.file.initialize_reader()
        self.file.read_into(file)
        self.file.close()

    def remove(self):
        self.close()
        self.file.remove()

    def finish_write(self):
        self.close()

    def close(self):
        self.file.close()
@@ -0,0 +1,162 @@
from __future__ import annotations
import abc
import io
import os
import shutil
import typing

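# FormatIOBackend is a small storage abstraction: one underlying file object
# at a time, opened in either 'write' or 'read' mode via initialize_writer()
# / initialize_reader(), with subclasses deciding where the bytes live.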
class FormatIOBackend(abc.ABC):
    def __init__(self, fd, filename, buffer=1024 * 1024):
        self.fd = fd
        self.filename = filename
        self.write_buffer = buffer
        self._fp = None
        self._fp_mode = None

    @property
    def writer(self):
        if self._fp is None or self._fp_mode != 'write':
            return None
        return self._fp

    @property
    def reader(self):
        if self._fp is None or self._fp_mode != 'read':
            return None
        return self._fp

    def initialize_writer(self, resume=False):
        if self._fp is not None:
            raise ValueError('Backend already initialized')

        self._fp = self._create_writer(resume)
        self._fp_mode = 'write'

    @abc.abstractmethod
    def _create_writer(self, resume=False) -> typing.IO:
        pass

    def initialize_reader(self):
        if self._fp is not None:
            raise ValueError('Backend already initialized')
        self._fp = self._create_reader()
        self._fp_mode = 'read'

    @abc.abstractmethod
    def _create_reader(self) -> typing.IO:
        pass

    def close(self):
        if self._fp and not self._fp.closed:
            self._fp.flush()
            self._fp.close()
        self._fp = None
        self._fp_mode = None

    @abc.abstractmethod
    def validate_length(self, expected_length):
        pass

    def remove(self):
        self.close()
        self._remove()

    @abc.abstractmethod
    def _remove(self):
        pass

    @abc.abstractmethod
    def exists(self):
        pass

    @property
    def mode(self):
        if self._fp is None:
            return None
        return self._fp_mode

    def write(self, data: io.BufferedIOBase | bytes):
        if not self.writer:
            raise ValueError('Backend writer not initialized')

        if isinstance(data, bytes):
            bytes_written = self.writer.write(data)
        elif isinstance(data, io.BufferedIOBase):
            bytes_written = self.writer.tell()
            shutil.copyfileobj(data, self.writer, length=self.write_buffer)
            bytes_written = self.writer.tell() - bytes_written
        else:
            raise TypeError('Data must be bytes or a BufferedIOBase object')

        self.writer.flush()

        return bytes_written

    def read_into(self, backend):
        if not backend.writer:
            raise ValueError('Backend writer not initialized')
        if not self.reader:
            raise ValueError('Backend reader not initialized')
        shutil.copyfileobj(self.reader, backend.writer, length=self.write_buffer)
        backend.writer.flush()


class DiskFormatIOBackend(FormatIOBackend):
    def _create_writer(self, resume=False) -> typing.IO:
        if resume and self.exists():
            write_fp, self.filename = self.fd.sanitize_open(self.filename, 'ab')
        else:
            write_fp, self.filename = self.fd.sanitize_open(self.filename, 'wb')
        return write_fp

    def _create_reader(self) -> typing.IO:
        read_fp, self.filename = self.fd.sanitize_open(self.filename, 'rb')
        return read_fp

    def validate_length(self, expected_length):
        return os.path.getsize(self.filename) == expected_length

    def _remove(self):
        self.fd.try_remove(self.filename)

    def exists(self):
        return os.path.isfile(self.filename)

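# MemoryFormatIOBackend mirrors the disk backend's semantics over a BytesIO
# buffer; the non-closing wrapper classes keep the underlying buffer alive
# across close() calls so it can be re-opened for reading.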
class MemoryFormatIOBackend(FormatIOBackend):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._memory_store = io.BytesIO()

    def _create_writer(self, resume=False) -> typing.IO:
        class NonClosingBufferedWriter(io.BufferedWriter):
            def close(self):
                self.flush()
                # Do not close the underlying buffer

        if resume and self.exists():
            self._memory_store.seek(0, io.SEEK_END)
        else:
            self._memory_store.seek(0)
            self._memory_store.truncate(0)

        return NonClosingBufferedWriter(self._memory_store)

    def _create_reader(self) -> typing.IO:
        class NonClosingBufferedReader(io.BufferedReader):
            def close(self):
                self.flush()

        # Seek to the beginning of the buffer
        self._memory_store.seek(0)
        return NonClosingBufferedReader(self._memory_store)

    def validate_length(self, expected_length):
        return self._memory_store.getbuffer().nbytes == expected_length

    def _remove(self):
        self._memory_store = io.BytesIO()

    def exists(self):
        return self._memory_store.getbuffer().nbytes > 0
@@ -0,0 +1,46 @@
from __future__ import annotations

from yt_dlp.utils import format_field, traverse_obj
from yt_dlp.extractor.youtube._streaming.sabr.models import SabrLogger
from yt_dlp.utils._utils import _YDLLogger

# TODO: create a logger that logs to a file rather than the console.
# Might be useful for debugging SABR issues from users.


class SabrFDLogger(SabrLogger):
    def __init__(self, ydl, prefix, log_level: SabrLogger.LogLevel | None = None):
        self._ydl_logger = _YDLLogger(ydl)
        self.prefix = prefix
        self.log_level = log_level if log_level is not None else self.LogLevel.INFO

    def _format_msg(self, message: str):
        prefixstr = format_field(self.prefix, None, '[%s] ')
        return f'{prefixstr}{message}'

    def trace(self, message: str):
        if self.log_level <= self.LogLevel.TRACE:
            self._ydl_logger.debug(self._format_msg('TRACE: ' + message))

    def debug(self, message: str):
        if self.log_level <= self.LogLevel.DEBUG:
            self._ydl_logger.debug(self._format_msg(message))

    def info(self, message: str):
        if self.log_level <= self.LogLevel.INFO:
            self._ydl_logger.info(self._format_msg(message))

    def warning(self, message: str, *, once=False):
        if self.log_level <= self.LogLevel.WARNING:
            self._ydl_logger.warning(self._format_msg(message), once=once)

    def error(self, message: str):
        if self.log_level <= self.LogLevel.ERROR:
            self._ydl_logger.error(self._format_msg(message), is_error=False)


def create_sabrfd_logger(ydl, prefix):
    return SabrFDLogger(
        ydl, prefix=prefix,
        log_level=SabrFDLogger.LogLevel(traverse_obj(
            ydl.params, ('extractor_args', 'youtube', 'sabr_log_level', 0, {str}), get_all=False)))
@@ -0,0 +1,77 @@
from __future__ import annotations
import contextlib
import os
import tempfile

from yt_dlp.dependencies import protobug
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId

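# The SABR state file is a small protobuf sidecar (<format filename>.sabr.state)
# recording which sequences have been fully written, so an interrupted download
# can be resumed without re-fetching media we already have.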
@protobug.message
class SabrStateSegment:
    sequence_number: protobug.Int32 = protobug.field(1)
    start_time_ms: protobug.Int64 = protobug.field(2)
    duration_ms: protobug.Int64 = protobug.field(3)
    duration_estimated: protobug.Bool = protobug.field(4)
    content_length: protobug.Int64 = protobug.field(5)


@protobug.message
class SabrStateSequence:
    sequence_start_number: protobug.Int32 = protobug.field(1)
    sequence_content_length: protobug.Int64 = protobug.field(2)
    first_segment: SabrStateSegment = protobug.field(3)
    last_segment: SabrStateSegment = protobug.field(4)


@protobug.message
class SabrStateInitSegment:
    content_length: protobug.Int64 = protobug.field(2)


@protobug.message
class SabrState:
    format_id: FormatId = protobug.field(1)
    init_segment: SabrStateInitSegment | None = protobug.field(2, default=None)
    sequences: list[SabrStateSequence] = protobug.field(3, default_factory=list)

class SabrStateFile:

    def __init__(self, format_filename, fd):
        self.filename = format_filename + '.sabr.state'
        self.fd = fd

    @property
    def exists(self):
        return os.path.isfile(self.filename)

    def retrieve(self):
        stream, self.filename = self.fd.sanitize_open(self.filename, 'rb')
        try:
            return self.deserialize(stream.read())
        finally:
            stream.close()

    def update(self, sabr_document):
        # Attempt to write progress document somewhat atomically to avoid corruption
        with tempfile.NamedTemporaryFile('wb', delete=False, dir=os.path.dirname(self.filename)) as tf:
            tf.write(self.serialize(sabr_document))
            tf.flush()
            os.fsync(tf.fileno())

        try:
            os.replace(tf.name, self.filename)
        finally:
            if os.path.exists(tf.name):
                with contextlib.suppress(FileNotFoundError, OSError):
                    os.unlink(tf.name)

    def serialize(self, sabr_document):
        return protobug.dumps(sabr_document)

    def deserialize(self, data):
        return protobug.loads(data, SabrState)

    def remove(self):
        self.fd.try_remove(self.filename)
@@ -0,0 +1,355 @@
from __future__ import annotations

import dataclasses

from ._io import DiskFormatIOBackend
from ._file import SequenceFile, Sequence, Segment
from ._state import (
    SabrStateSegment,
    SabrStateSequence,
    SabrStateInitSegment,
    SabrState,
    SabrStateFile,
)
from yt_dlp.extractor.youtube._streaming.sabr.part import (
    MediaSegmentInitSabrPart,
    MediaSegmentDataSabrPart,
    MediaSegmentEndSabrPart,
)

from yt_dlp.utils import DownloadError
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.utils.progress import ProgressCalculator

INIT_SEGMENT_ID = 'i'


@dataclasses.dataclass
class SabrFormatState:
    format_id: FormatId
    init_sequence: Sequence | None = None
    sequences: list[Sequence] = dataclasses.field(default_factory=list)

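# SabrFDFormatWriter owns everything for a single output format: the final
# output file, the staging sequence files, the resume state sidecar and the
# per-format progress reporting.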
class SabrFDFormatWriter:
    def __init__(self, fd, filename, infodict, progress_idx=0, resume=False):
        self.fd = fd
        self.info_dict = infodict
        self.filename = filename
        self.progress_idx = progress_idx
        self.resume = resume

        self._progress = None
        self._downloaded_bytes = 0
        self._state = {}
        self._format_id = None

        self.file = DiskFormatIOBackend(
            fd=self.fd,
            filename=self.fd.temp_name(filename),
        )
        self._sabr_state_file = SabrStateFile(format_filename=self.filename, fd=fd)
        self._sequence_files: list[SequenceFile] = []
        self._init_sequence: SequenceFile | None = None

    @property
    def state(self):
        return SabrFormatState(
            format_id=self._format_id,
            init_sequence=self._init_sequence.sequence if self._init_sequence else None,
            sequences=[sf.sequence for sf in self._sequence_files],
        )

    @property
    def downloaded_bytes(self):
        return (sum(
            sequence.current_length for sequence in self._sequence_files)
            + (self._init_sequence.current_length if self._init_sequence else 0))

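    # On resume, reload the persisted SabrState and rebuild the SequenceFile
    # objects from it; any sequence whose backing file is missing or invalid
    # is dropped with a warning and will be re-downloaded.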
    def initialize_format(self, format_id):
        if self._format_id:
            raise ValueError('Already initialized')
        self._format_id = format_id

        if not self.resume:
            if self._sabr_state_file.exists:
                self._sabr_state_file.remove()
            return

        document = self._load_sabr_state()

        if document.init_segment:
            init_segment = Segment(
                segment_id=INIT_SEGMENT_ID,
                content_length=document.init_segment.content_length,
                is_init_segment=True,
            )

            try:
                self._init_sequence = SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=True,
                    sequence=Sequence(
                        sequence_id=INIT_SEGMENT_ID,
                        sequence_content_length=init_segment.content_length,
                        first_segment=init_segment,
                        last_segment=init_segment,
                    ))
            except DownloadError as e:
                self.fd.report_warning(f'Failed to resume init segment for format {self.info_dict.get("format_id")}: {e}')

        for sabr_sequence in list(document.sequences):
            try:
                self._sequence_files.append(SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=True,
                    sequence=Sequence(
                        sequence_id=str(sabr_sequence.sequence_start_number),
                        sequence_content_length=sabr_sequence.sequence_content_length,
                        first_segment=Segment(
                            segment_id=str(sabr_sequence.first_segment.sequence_number),
                            sequence_number=sabr_sequence.first_segment.sequence_number,
                            content_length=sabr_sequence.first_segment.content_length,
                            start_time_ms=sabr_sequence.first_segment.start_time_ms,
                            duration_ms=sabr_sequence.first_segment.duration_ms,
                            is_init_segment=False,
                        ),
                        last_segment=Segment(
                            segment_id=str(sabr_sequence.last_segment.sequence_number),
                            sequence_number=sabr_sequence.last_segment.sequence_number,
                            content_length=sabr_sequence.last_segment.content_length,
                            start_time_ms=sabr_sequence.last_segment.start_time_ms,
                            duration_ms=sabr_sequence.last_segment.duration_ms,
                            is_init_segment=False,
                        ),
                    ),
                ))
            except DownloadError as e:
                self.fd.report_warning(
                    f'Failed to resume sequence {sabr_sequence.sequence_start_number} '
                    f'for format {self.info_dict.get("format_id")}: {e}')

    @property
    def initialized(self):
        return self._format_id is not None

    def close(self):
        if not self.file:
            raise ValueError('Already closed')
        for sequence in self._sequence_files:
            sequence.close()
        self._sequence_files.clear()
        if self._init_sequence:
            self._init_sequence.close()
            self._init_sequence = None
        self.file.close()

    def _find_sequence_file(self, predicate):
        match = None
        for sequence in self._sequence_files:
            if predicate(sequence):
                if match is not None:
                    raise DownloadError('Multiple sequence files found for segment')
                match = sequence
        return match

    def find_next_sequence_file(self, next_segment: Segment):
        return self._find_sequence_file(lambda sequence: sequence.is_next_segment(next_segment))

    def find_current_sequence_file(self, segment_id: str):
        return self._find_sequence_file(lambda sequence: sequence.is_current_segment(segment_id))

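    # Route an incoming segment to the sequence file it extends; if no
    # existing sequence can accept it (e.g. after a seek or a gap), start a
    # new sequence file.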
    def initialize_segment(self, part: MediaSegmentInitSabrPart):
        if not self._progress:
            self._progress = ProgressCalculator(part.start_bytes)

        if not self._format_id:
            raise ValueError('not initialized')

        if part.is_init_segment:
            if not self._init_sequence:
                self._init_sequence = SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=False,
                    sequence=Sequence(
                        sequence_id=INIT_SEGMENT_ID,
                    ))

            self._init_sequence.initialize_segment(Segment(
                segment_id=INIT_SEGMENT_ID,
                content_length=part.content_length,
                content_length_estimated=part.content_length_estimated,
                is_init_segment=True,
            ))
            return True

        segment = Segment(
            segment_id=str(part.sequence_number),
            sequence_number=part.sequence_number,
            start_time_ms=part.start_time_ms,
            duration_ms=part.duration_ms,
            duration_estimated=part.duration_estimated,
            content_length=part.content_length,
            content_length_estimated=part.content_length_estimated,
        )

        sequence_file = self.find_current_sequence_file(segment.segment_id) or self.find_next_sequence_file(segment)

        if not sequence_file:
            sequence_file = SequenceFile(
                fd=self.fd,
                format_filename=self.filename,
                resume=False,
                sequence=Sequence(sequence_id=str(part.sequence_number)),
            )
            self._sequence_files.append(sequence_file)

        sequence_file.initialize_segment(segment)
        return True

    def write_segment_data(self, part: MediaSegmentDataSabrPart):
        if part.is_init_segment:
            sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID
        else:
            segment_id = str(part.sequence_number)
            sequence_file = self.find_current_sequence_file(segment_id)

        if not sequence_file:
            raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?')

        sequence_file.write_segment_data(part.data, segment_id)

        # TODO: Handling of disjointed segments (e.g. when downloading segments out of order / concurrently)
        self._progress.total = self.info_dict.get('filesize')
        self._state = {
            'status': 'downloading',
            'downloaded_bytes': self.downloaded_bytes,
            'total_bytes': self.info_dict.get('filesize'),
            'filename': self.filename,
            'eta': self._progress.eta.smooth,
            'speed': self._progress.speed.smooth,
            'elapsed': self._progress.elapsed,
            'progress_idx': self.progress_idx,
            'fragment_count': part.total_segments,
            'fragment_index': part.sequence_number,
        }

        self._progress.update(self._state['downloaded_bytes'])
        self.fd._hook_progress(self._state, self.info_dict)

    def end_segment(self, part: MediaSegmentEndSabrPart):
        if part.is_init_segment:
            sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID
        else:
            segment_id = str(part.sequence_number)
            sequence_file = self.find_current_sequence_file(segment_id)

        if not sequence_file:
            raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?')

        sequence_file.end_segment(segment_id)
        self._write_sabr_state()

def _load_sabr_state(self):
|
||||
sabr_state = None
|
||||
if self._sabr_state_file.exists:
|
||||
try:
|
||||
sabr_state = self._sabr_state_file.retrieve()
|
||||
except Exception:
|
||||
self.fd.report_warning(
|
||||
f'Corrupted state file for format {self.info_dict.get("format_id")}, restarting download')
|
||||
|
||||
if sabr_state and sabr_state.format_id != self._format_id:
|
||||
self.fd.report_warning(
|
||||
f'Format ID mismatch in state file for {self.info_dict.get("format_id")}, restarting download')
|
||||
sabr_state = None
|
||||
|
||||
if not sabr_state:
|
||||
sabr_state = SabrState(format_id=self._format_id)
|
||||
|
||||
return sabr_state
|
||||
|
||||
def _write_sabr_state(self):
|
||||
sabr_state = SabrState(format_id=self._format_id)
|
||||
|
||||
if not self._init_sequence:
|
||||
sabr_state.init_segment = None
|
||||
else:
|
||||
sabr_state.init_segment = SabrStateInitSegment(
|
||||
content_length=self._init_sequence.sequence.sequence_content_length,
|
||||
)
|
||||
|
||||
sabr_state.sequences = []
|
||||
for sequence_file in self._sequence_files:
|
||||
# Ignore partial sequences
|
||||
if not sequence_file.sequence.first_segment or not sequence_file.sequence.last_segment:
|
||||
continue
|
||||
sabr_state.sequences.append(SabrStateSequence(
|
||||
sequence_start_number=sequence_file.sequence.first_segment.sequence_number,
|
||||
sequence_content_length=sequence_file.sequence.sequence_content_length,
|
||||
first_segment=SabrStateSegment(
|
||||
sequence_number=sequence_file.sequence.first_segment.sequence_number,
|
||||
start_time_ms=sequence_file.sequence.first_segment.start_time_ms,
|
||||
duration_ms=sequence_file.sequence.first_segment.duration_ms,
|
||||
duration_estimated=sequence_file.sequence.first_segment.duration_estimated,
|
||||
content_length=sequence_file.sequence.first_segment.content_length,
|
||||
),
|
||||
last_segment=SabrStateSegment(
|
||||
sequence_number=sequence_file.sequence.last_segment.sequence_number,
|
||||
start_time_ms=sequence_file.sequence.last_segment.start_time_ms,
|
||||
duration_ms=sequence_file.sequence.last_segment.duration_ms,
|
||||
duration_estimated=sequence_file.sequence.last_segment.duration_estimated,
|
||||
content_length=sequence_file.sequence.last_segment.content_length,
|
||||
),
|
||||
))
|
||||
|
||||
self._sabr_state_file.update(sabr_state)
|
||||
|
||||
def finish(self):
|
||||
self._state['status'] = 'finished'
|
||||
self.fd._hook_progress(self._state, self.info_dict)
|
||||
|
||||
for sequence_file in self._sequence_files:
|
||||
sequence_file.close()
|
||||
|
||||
if self._init_sequence:
|
||||
self._init_sequence.close()
|
||||
|
||||
# Now merge all the sequences together
|
||||
self.file.initialize_writer(resume=False)
|
||||
|
||||
# Note: May not always be an init segment, e.g for live streams
|
||||
if self._init_sequence:
|
||||
self._init_sequence.read_into(self.file)
|
||||
self._init_sequence.close()
|
||||
|
||||
# TODO: handling of disjointed segments
|
||||
previous_seq_number = None
|
||||
for sequence_file in sorted(
|
||||
(sf for sf in self._sequence_files if sf.sequence.first_segment),
|
||||
key=lambda s: s.sequence.first_segment.sequence_number):
|
||||
if previous_seq_number and previous_seq_number + 1 != sequence_file.sequence.first_segment.sequence_number:
|
||||
self.fd.report_warning(f'Disjointed sequences found in SABR format {self.info_dict.get("format_id")}')
|
||||
previous_seq_number = sequence_file.sequence.last_segment.sequence_number
|
||||
sequence_file.read_into(self.file)
|
||||
sequence_file.close()
|
||||
|
||||
# Format temp file should have all the segments, rename it to the final name
|
||||
self.file.close()
|
||||
self.fd.try_rename(self.file.filename, self.fd.undo_temp_name(self.file.filename))
|
||||
|
||||
# Remove the state file
|
||||
self._sabr_state_file.remove()
|
||||
|
||||
# Remove sequence files
|
||||
for sf in self._sequence_files:
|
||||
sf.close()
|
||||
sf.remove()
|
||||
|
||||
if self._init_sequence:
|
||||
self._init_sequence.close()
|
||||
self._init_sequence.remove()
|
||||
self.close()
|
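# Illustrative sketch (not from the PR): the resume-state round trip guarded by
# _load_sabr_state()/_write_sabr_state() above. A state file is only trusted if
# it parses and its format_id matches. `state_file` and `format_id` are
# hypothetical stand-ins for self._sabr_state_file and self._format_id.
def load_or_restart(state_file, format_id):
    state = None
    if state_file.exists:
        try:
            state = state_file.retrieve()
        except Exception:
            state = None  # corrupted state file: restart the download
    if state and state.format_id != format_id:
        state = None  # state belongs to a different format: restart
    return state or SabrState(format_id=format_id)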
@ -0,0 +1,14 @@
import dataclasses
import typing


def unknown_fields(obj: typing.Any, path=()) -> typing.Iterable[tuple[tuple[str, ...], dict[int, list]]]:
    if not dataclasses.is_dataclass(obj):
        return

    if unknown := getattr(obj, '_unknown', None):
        yield path, unknown

    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        yield from unknown_fields(value, (*path, field.name))
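# Illustrative sketch (not from the PR): how unknown_fields() walks nested
# dataclasses. `_unknown` is assumed to be where protobug stashes unparsed
# field numbers; the dataclasses below are hypothetical.
import dataclasses


@dataclasses.dataclass
class _Inner:
    value: int = 0


@dataclasses.dataclass
class _Outer:
    inner: _Inner = dataclasses.field(default_factory=_Inner)


_outer = _Outer()
_outer.inner._unknown = {99: [b'raw']}  # simulate an unrecognized field number
assert list(unknown_fields(_outer)) == [(('inner',), {99: [b'raw']})]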
@ -0,0 +1,5 @@
from .client_info import ClientInfo, ClientName  # noqa: F401
from .compression_algorithm import CompressionAlgorithm  # noqa: F401
from .next_request_policy import NextRequestPolicy  # noqa: F401
from .range import Range  # noqa: F401
from .seek_source import SeekSource  # noqa: F401
@ -0,0 +1,105 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


class ClientName(protobug.Enum, strict=False):
    UNKNOWN_INTERFACE = 0
    WEB = 1
    MWEB = 2
    ANDROID = 3
    IOS = 5
    TVHTML5 = 7
    TVLITE = 8
    TVANDROID = 10
    XBOX = 11
    CLIENTX = 12
    XBOXONEGUIDE = 13
    ANDROID_CREATOR = 14
    IOS_CREATOR = 15
    TVAPPLE = 16
    IOS_INSTANT = 17
    ANDROID_KIDS = 18
    IOS_KIDS = 19
    ANDROID_INSTANT = 20
    ANDROID_MUSIC = 21
    IOS_TABLOID = 22
    ANDROID_TV = 23
    ANDROID_GAMING = 24
    IOS_GAMING = 25
    IOS_MUSIC = 26
    MWEB_TIER_2 = 27
    ANDROID_VR = 28
    ANDROID_UNPLUGGED = 29
    ANDROID_TESTSUITE = 30
    WEB_MUSIC_ANALYTICS = 31
    WEB_GAMING = 32
    IOS_UNPLUGGED = 33
    ANDROID_WITNESS = 34
    IOS_WITNESS = 35
    ANDROID_SPORTS = 36
    IOS_SPORTS = 37
    ANDROID_LITE = 38
    IOS_EMBEDDED_PLAYER = 39
    IOS_DIRECTOR = 40
    WEB_UNPLUGGED = 41
    WEB_EXPERIMENTS = 42
    TVHTML5_CAST = 43
    WEB_EMBEDDED_PLAYER = 56
    TVHTML5_AUDIO = 57
    TV_UNPLUGGED_CAST = 58
    TVHTML5_KIDS = 59
    WEB_HEROES = 60
    WEB_MUSIC = 61
    WEB_CREATOR = 62
    TV_UNPLUGGED_ANDROID = 63
    IOS_LIVE_CREATION_EXTENSION = 64
    TVHTML5_UNPLUGGED = 65
    IOS_MESSAGES_EXTENSION = 66
    WEB_REMIX = 67
    IOS_UPTIME = 68
    WEB_UNPLUGGED_ONBOARDING = 69
    WEB_UNPLUGGED_OPS = 70
    WEB_UNPLUGGED_PUBLIC = 71
    TVHTML5_VR = 72
    WEB_LIVE_STREAMING = 73
    ANDROID_TV_KIDS = 74
    TVHTML5_SIMPLY = 75
    WEB_KIDS = 76
    MUSIC_INTEGRATIONS = 77
    TVHTML5_YONGLE = 80
    GOOGLE_ASSISTANT = 84
    TVHTML5_SIMPLY_EMBEDDED_PLAYER = 85
    WEB_MUSIC_EMBEDDED_PLAYER = 86
    WEB_INTERNAL_ANALYTICS = 87
    WEB_PARENT_TOOLS = 88
    GOOGLE_MEDIA_ACTIONS = 89
    WEB_PHONE_VERIFICATION = 90
    ANDROID_PRODUCER = 91
    IOS_PRODUCER = 92
    TVHTML5_FOR_KIDS = 93
    GOOGLE_LIST_RECS = 94
    MEDIA_CONNECT_FRONTEND = 95
    WEB_EFFECT_MAKER = 98
    WEB_SHOPPING_EXTENSION = 99
    WEB_PLAYABLES_PORTAL = 100
    VISIONOS = 101
    WEB_LIVE_APPS = 102
    WEB_MUSIC_INTEGRATIONS = 103
    ANDROID_MUSIC_AOSP = 104


@protobug.message
class ClientInfo:
    hl: protobug.String | None = protobug.field(1, default=None)
    gl: protobug.String | None = protobug.field(2, default=None)
    remote_host: protobug.String | None = protobug.field(4, default=None)

    device_make: protobug.String | None = protobug.field(12, default=None)
    device_model: protobug.String | None = protobug.field(13, default=None)
    visitor_data: protobug.String | None = protobug.field(14, default=None)
    user_agent: protobug.String | None = protobug.field(15, default=None)
    client_name: ClientName | None = protobug.field(16, default=None)
    client_version: protobug.String | None = protobug.field(17, default=None)
    os_name: protobug.String | None = protobug.field(18, default=None)
    os_version: protobug.String | None = protobug.field(19, default=None)
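# Illustrative sketch (not from the PR): round-tripping ClientInfo through
# protobug, the same way other messages are encoded/decoded elsewhere in this
# PR. The client_version string is a made-up example value.
from yt_dlp.dependencies import protobug

_info = ClientInfo(client_name=ClientName.WEB, client_version='2.20240101.00.00')
_payload = protobug.dumps(_info)  # protobuf wire bytes
assert protobug.loads(_payload, ClientInfo).client_name == ClientName.WEB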
@ -0,0 +1,7 @@
from yt_dlp.dependencies import protobug


class CompressionAlgorithm(protobug.Enum, strict=False):
    COMPRESSION_ALGORITHM_UNKNOWN = 0
    COMPRESSION_ALGORITHM_NONE = 1
    COMPRESSION_ALGORITHM_GZIP = 2
@ -0,0 +1,15 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class NextRequestPolicy:
    target_audio_readahead_ms: protobug.Int32 | None = protobug.field(1, default=None)
    target_video_readahead_ms: protobug.Int32 | None = protobug.field(2, default=None)
    max_time_since_last_request_ms: protobug.Int32 | None = protobug.field(3, default=None)
    backoff_time_ms: protobug.Int32 | None = protobug.field(4, default=None)
    min_audio_readahead_ms: protobug.Int32 | None = protobug.field(5, default=None)
    min_video_readahead_ms: protobug.Int32 | None = protobug.field(6, default=None)
    playback_cookie: protobug.Bytes | None = protobug.field(7, default=None)
    video_id: protobug.String | None = protobug.field(8, default=None)
@ -0,0 +1,9 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class Range:
    start: protobug.Int64 | None = protobug.field(1, default=None)
    end: protobug.Int64 | None = protobug.field(2, default=None)
@ -0,0 +1,14 @@
from yt_dlp.dependencies import protobug


class SeekSource(protobug.Enum, strict=False):
    SEEK_SOURCE_UNKNOWN = 0
    SEEK_SOURCE_SABR_PARTIAL_CHUNK = 9
    SEEK_SOURCE_SABR_SEEK_TO_HEAD = 10
    SEEK_SOURCE_SABR_LIVE_DVR_USER_SEEK = 11
    SEEK_SOURCE_SABR_SEEK_TO_DVR_LOWER_BOUND = 12
    SEEK_SOURCE_SABR_SEEK_TO_DVR_UPPER_BOUND = 13
    SEEK_SOURCE_SABR_ACCURATE_SEEK = 17
    SEEK_SOURCE_SABR_INGESTION_WALL_TIME_SEEK = 29
    SEEK_SOURCE_SABR_SEEK_TO_CLOSEST_KEYFRAME = 59
    SEEK_SOURCE_SABR_RELOAD_PLAYER_RESPONSE_TOKEN_SEEK = 106
@ -0,0 +1,16 @@
from .buffered_range import BufferedRange  # noqa: F401
from .client_abr_state import ClientAbrState  # noqa: F401
from .format_id import FormatId  # noqa: F401
from .format_initialization_metadata import FormatInitializationMetadata  # noqa: F401
from .live_metadata import LiveMetadata  # noqa: F401
from .media_header import MediaHeader  # noqa: F401
from .reload_player_response import ReloadPlayerResponse  # noqa: F401
from .sabr_context_sending_policy import SabrContextSendingPolicy  # noqa: F401
from .sabr_context_update import SabrContextUpdate  # noqa: F401
from .sabr_error import SabrError  # noqa: F401
from .sabr_redirect import SabrRedirect  # noqa: F401
from .sabr_seek import SabrSeek  # noqa: F401
from .stream_protection_status import StreamProtectionStatus  # noqa: F401
from .streamer_context import SabrContext, StreamerContext  # noqa: F401
from .time_range import TimeRange  # noqa: F401
from .video_playback_abr_request import VideoPlaybackAbrRequest  # noqa: F401
@ -0,0 +1,16 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from .format_id import FormatId
from .time_range import TimeRange


@protobug.message
class BufferedRange:
    format_id: FormatId | None = protobug.field(1, default=None)
    start_time_ms: protobug.Int64 | None = protobug.field(2, default=None)
    duration_ms: protobug.Int64 | None = protobug.field(3, default=None)
    start_segment_index: protobug.Int32 | None = protobug.field(4, default=None)
    end_segment_index: protobug.Int32 | None = protobug.field(5, default=None)
    time_range: TimeRange | None = protobug.field(6, default=None)
@ -0,0 +1,10 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class ClientAbrState:
    player_time_ms: protobug.Int64 | None = protobug.field(28, default=None)
    enabled_track_types_bitfield: protobug.Int32 | None = protobug.field(40, default=None)
    drc_enabled: protobug.Bool = protobug.field(46, default=False)
@ -0,0 +1,10 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class FormatId:
    itag: protobug.Int32 | None = protobug.field(1, default=None)
    lmt: protobug.UInt64 | None = protobug.field(2, default=None)
    xtags: protobug.String | None = protobug.field(3, default=None)
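# Illustrative sketch (not from the PR): the processor later keys initialized
# formats by str(FormatId), which relies on protobug messages having a
# value-based string form (an assumption here):
_fid = FormatId(itag=140)
_formats = {str(_fid): 'initialized'}
assert str(FormatId(itag=140)) in _formats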
@ -0,0 +1,19 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from .format_id import FormatId
from ..innertube import Range


@protobug.message
class FormatInitializationMetadata:
    video_id: protobug.String | None = protobug.field(1, default=None)
    format_id: FormatId | None = protobug.field(2, default=None)
    end_time_ms: protobug.Int32 | None = protobug.field(3, default=None)
    total_segments: protobug.Int32 | None = protobug.field(4, default=None)
    mime_type: protobug.String | None = protobug.field(5, default=None)
    init_range: Range | None = protobug.field(6, default=None)
    index_range: Range | None = protobug.field(7, default=None)
    duration_ticks: protobug.Int32 | None = protobug.field(9, default=None)
    duration_timescale: protobug.Int32 | None = protobug.field(10, default=None)
@ -0,0 +1,18 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class LiveMetadata:
    head_sequence_number: protobug.Int32 | None = protobug.field(3, default=None)
    head_sequence_time_ms: protobug.Int64 | None = protobug.field(4, default=None)
    wall_time_ms: protobug.Int64 | None = protobug.field(5, default=None)
    video_id: protobug.String | None = protobug.field(6, default=None)
    source: protobug.String | None = protobug.field(7, default=None)

    min_seekable_time_ticks: protobug.Int64 | None = protobug.field(12, default=None)
    min_seekable_timescale: protobug.Int32 | None = protobug.field(13, default=None)

    max_seekable_time_ticks: protobug.Int64 | None = protobug.field(14, default=None)
    max_seekable_timescale: protobug.Int32 | None = protobug.field(15, default=None)
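# Illustrative sketch (not from the PR): deriving the seekable (DVR) window in
# ms from LiveMetadata, assuming ticks_to_ms(ticks, timescale) behaves like
# ticks * 1000 // timescale:
def _dvr_window_ms(md: LiveMetadata) -> int | None:
    if None in (md.min_seekable_time_ticks, md.min_seekable_timescale,
                md.max_seekable_time_ticks, md.max_seekable_timescale):
        return None
    start_ms = md.min_seekable_time_ticks * 1000 // md.min_seekable_timescale
    end_ms = md.max_seekable_time_ticks * 1000 // md.max_seekable_timescale
    return end_ms - start_ms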
@ -0,0 +1,27 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from .format_id import FormatId
from .time_range import TimeRange
from ..innertube import CompressionAlgorithm


@protobug.message
class MediaHeader:
    header_id: protobug.UInt32 | None = protobug.field(1, default=None)
    video_id: protobug.String | None = protobug.field(2, default=None)
    itag: protobug.Int32 | None = protobug.field(3, default=None)
    last_modified: protobug.UInt64 | None = protobug.field(4, default=None)
    xtags: protobug.String | None = protobug.field(5, default=None)
    start_data_range: protobug.Int32 | None = protobug.field(6, default=None)
    compression: CompressionAlgorithm | None = protobug.field(7, default=None)
    is_init_segment: protobug.Bool | None = protobug.field(8, default=None)
    sequence_number: protobug.Int64 | None = protobug.field(9, default=None)
    bitrate_bps: protobug.Int64 | None = protobug.field(10, default=None)
    start_ms: protobug.Int32 | None = protobug.field(11, default=None)
    duration_ms: protobug.Int32 | None = protobug.field(12, default=None)
    format_id: FormatId | None = protobug.field(13, default=None)
    content_length: protobug.Int64 | None = protobug.field(14, default=None)
    time_range: TimeRange | None = protobug.field(15, default=None)
    sequence_lmt: protobug.Int32 | None = protobug.field(16, default=None)
@ -0,0 +1,13 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class ReloadPlaybackParams:
    token: protobug.String | None = protobug.field(1, default=None)


@protobug.message
class ReloadPlayerResponse:
    reload_playback_params: ReloadPlaybackParams | None = protobug.field(1, default=None)
@ -0,0 +1,13 @@
from yt_dlp.dependencies import protobug


@protobug.message
class SabrContextSendingPolicy:
    # Start sending the SabrContextUpdates of this type
    start_policy: list[protobug.Int32] = protobug.field(1, default_factory=list)

    # Stop sending the SabrContextUpdates of this type
    stop_policy: list[protobug.Int32] = protobug.field(2, default_factory=list)

    # Stop and discard the SabrContextUpdates of this type
    discard_policy: list[protobug.Int32] = protobug.field(3, default_factory=list)
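# Illustrative sketch (not from the PR): applying a SabrContextSendingPolicy to
# the processor's bookkeeping, mirroring process_sabr_context_sending_policy()
# later in this PR (assumes protobug messages accept keyword construction):
_to_send = {1, 2}
_updates = {2: 'ctx'}
_policy = SabrContextSendingPolicy(start_policy=[3], stop_policy=[1], discard_policy=[2])
_to_send |= set(_policy.start_policy)
_to_send -= set(_policy.stop_policy)
for _t in _policy.discard_policy:
    _updates.pop(_t, None)
assert _to_send == {2, 3} and _updates == {}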
@ -0,0 +1,25 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class SabrContextUpdate:

    class SabrContextScope(protobug.Enum, strict=False):
        SABR_CONTEXT_SCOPE_UNKNOWN = 0
        SABR_CONTEXT_SCOPE_PLAYBACK = 1
        SABR_CONTEXT_SCOPE_REQUEST = 2
        SABR_CONTEXT_SCOPE_WATCH_ENDPOINT = 3
        SABR_CONTEXT_SCOPE_CONTENT_ADS = 4

    class SabrContextWritePolicy(protobug.Enum, strict=False):
        SABR_CONTEXT_WRITE_POLICY_UNSPECIFIED = 0
        SABR_CONTEXT_WRITE_POLICY_OVERWRITE = 1
        SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING = 2

    type: protobug.Int32 | None = protobug.field(1, default=None)
    scope: SabrContextScope | None = protobug.field(2, default=None)
    value: protobug.Bytes | None = protobug.field(3, default=None)
    send_by_default: protobug.Bool | None = protobug.field(4, default=None)
    write_policy: SabrContextWritePolicy | None = protobug.field(5, default=None)
@ -0,0 +1,16 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class Error:
    status_code: protobug.Int32 | None = protobug.field(1, default=None)
    type: protobug.Int32 | None = protobug.field(4, default=None)


@protobug.message
class SabrError:
    type: protobug.String | None = protobug.field(1, default=None)
    action: protobug.Int32 | None = protobug.field(2, default=None)
    error: Error | None = protobug.field(3, default=None)
@ -0,0 +1,8 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class SabrRedirect:
    redirect_url: protobug.String | None = protobug.field(1, default=None)
@ -0,0 +1,12 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from ..innertube import SeekSource


@protobug.message
class SabrSeek:
    seek_time_ticks: protobug.Int32 = protobug.field(1)
    timescale: protobug.Int32 = protobug.field(2)
    seek_source: SeekSource | None = protobug.field(3, default=None)
@ -0,0 +1,15 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class StreamProtectionStatus:

    class Status(protobug.Enum, strict=False):
        OK = 1
        ATTESTATION_PENDING = 2
        ATTESTATION_REQUIRED = 3

    status: Status | None = protobug.field(1, default=None)
    max_retries: protobug.Int32 | None = protobug.field(2, default=None)
@ -0,0 +1,21 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from ..innertube import ClientInfo


@protobug.message
class SabrContext:
    # Type and Value from a SabrContextUpdate
    type: protobug.Int32 | None = protobug.field(1, default=None)
    value: protobug.Bytes | None = protobug.field(2, default=None)


@protobug.message
class StreamerContext:
    client_info: ClientInfo | None = protobug.field(1, default=None)
    po_token: protobug.Bytes | None = protobug.field(2, default=None)
    playback_cookie: protobug.Bytes | None = protobug.field(3, default=None)
    sabr_contexts: list[SabrContext] = protobug.field(5, default_factory=list)
    unsent_sabr_contexts: list[protobug.Int32] = protobug.field(6, default_factory=list)
@ -0,0 +1,10 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug


@protobug.message
class TimeRange:
    start_ticks: protobug.Int64 | None = protobug.field(1, default=None)
    duration_ticks: protobug.Int64 | None = protobug.field(2, default=None)
    timescale: protobug.Int32 | None = protobug.field(3, default=None)
@ -0,0 +1,22 @@
from __future__ import annotations

from yt_dlp.dependencies import protobug

from .buffered_range import BufferedRange
from .client_abr_state import ClientAbrState
from .format_id import FormatId
from .streamer_context import StreamerContext


@protobug.message
class VideoPlaybackAbrRequest:
    client_abr_state: ClientAbrState | None = protobug.field(1, default=None)
    initialized_format_ids: list[FormatId] = protobug.field(2, default_factory=list)
    buffered_ranges: list[BufferedRange] = protobug.field(3, default_factory=list)
    player_time_ms: protobug.Int64 | None = protobug.field(4, default=None)
    video_playback_ustreamer_config: protobug.Bytes | None = protobug.field(5, default=None)

    selected_audio_format_ids: list[FormatId] = protobug.field(16, default_factory=list)
    selected_video_format_ids: list[FormatId] = protobug.field(17, default_factory=list)
    selected_caption_format_ids: list[FormatId] = protobug.field(18, default_factory=list)
    streamer_context: StreamerContext = protobug.field(19, default_factory=StreamerContext)
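# Illustrative sketch (not from the PR): a minimal VideoPlaybackAbrRequest,
# mirroring build_vpabr_request() later in this PR. The ustreamer config bytes
# are a placeholder for the opaque value from the player response.
from yt_dlp.dependencies import protobug

from ..innertube import ClientInfo, ClientName

_request = VideoPlaybackAbrRequest(
    client_abr_state=ClientAbrState(player_time_ms=0, enabled_track_types_bitfield=0),
    video_playback_ustreamer_config=b'opaque-config-from-player-response',
    streamer_context=StreamerContext(client_info=ClientInfo(client_name=ClientName.WEB)),
)
_payload = protobug.dumps(_request)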
@ -0,0 +1,26 @@
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.utils import YoutubeDLError


class SabrStreamConsumedError(YoutubeDLError):
    pass


class SabrStreamError(YoutubeDLError):
    pass


class MediaSegmentMismatchError(SabrStreamError):
    def __init__(self, format_id: FormatId, expected_sequence_number: int, received_sequence_number: int):
        super().__init__(
            f'Segment sequence number mismatch for format {format_id}: '
            f'expected {expected_sequence_number}, received {received_sequence_number}')
        self.expected_sequence_number = expected_sequence_number
        self.received_sequence_number = received_sequence_number


class PoTokenError(SabrStreamError):
    def __init__(self, missing=False):
        super().__init__(
            'This stream requires a GVS PO Token to continue'
            f'{" and the one provided is invalid" if not missing else ""}')
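# Illustrative sketch (not from the PR): MediaSegmentMismatchError keeps both
# sequence numbers, so a caller can log where the stream diverged before
# retrying. The itag value below is a made-up example.
try:
    raise MediaSegmentMismatchError(
        FormatId(itag=140), expected_sequence_number=5, received_sequence_number=7)
except MediaSegmentMismatchError as _e:
    assert (_e.expected_sequence_number, _e.received_sequence_number) == (5, 7)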
@ -0,0 +1,90 @@
from __future__ import annotations

import dataclasses

from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.extractor.youtube.pot._provider import IEContentProviderLogger


@dataclasses.dataclass
class Segment:
    format_id: FormatId
    is_init_segment: bool = False
    duration_ms: int = 0
    start_ms: int = 0
    start_data_range: int = 0
    sequence_number: int | None = 0
    content_length: int | None = None
    content_length_estimated: bool = False
    initialized_format: InitializedFormat | None = None
    # Whether duration_ms is an estimate
    duration_estimated: bool = False
    # Whether we should discard the segment data
    discard: bool = False
    # Whether the segment has already been consumed.
    # `discard` should be set to True if this is the case.
    consumed: bool = False
    received_data_length: int = 0
    sequence_lmt: int | None = None


@dataclasses.dataclass
class ConsumedRange:
    start_sequence_number: int
    end_sequence_number: int
    start_time_ms: int
    duration_ms: int


@dataclasses.dataclass
class InitializedFormat:
    format_id: FormatId
    video_id: str
    format_selector: FormatSelector | None = None
    duration_ms: int = 0
    end_time_ms: int = 0
    mime_type: str | None = None
    # Current segment in the sequence. Set to None to break the sequence and allow a seek.
    current_segment: Segment | None = None
    init_segment: Segment | None | bool = None
    consumed_ranges: list[ConsumedRange] = dataclasses.field(default_factory=list)
    total_segments: int | None = None
    # Whether we should discard any data received for this format
    discard: bool = False
    sequence_lmt: int | None = None


SabrLogger = IEContentProviderLogger


@dataclasses.dataclass
class FormatSelector:
    display_name: str
    format_ids: list[FormatId] = dataclasses.field(default_factory=list)
    discard_media: bool = False
    mime_prefix: str | None = None

    def match(self, format_id: FormatId | None = None, mime_type: str | None = None, **kwargs) -> bool:
        return (
            format_id in self.format_ids
            or (
                not self.format_ids
                and self.mime_prefix
                and mime_type and mime_type.lower().startswith(self.mime_prefix)
            )
        )


@dataclasses.dataclass
class AudioSelector(FormatSelector):
    mime_prefix: str = dataclasses.field(default='audio')


@dataclasses.dataclass
class VideoSelector(FormatSelector):
    mime_prefix: str = dataclasses.field(default='video')


@dataclasses.dataclass
class CaptionSelector(FormatSelector):
    mime_prefix: str = dataclasses.field(default='text')
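# Illustrative sketch (not from the PR): selector matching as implemented above.
# Explicit FormatIds win; otherwise the mime prefix decides. Assumes protobug
# FormatId supports value equality for the `in` check.
_audio = AudioSelector(display_name='audio')
assert _audio.match(format_id=FormatId(itag=140), mime_type='audio/mp4')
assert not _audio.match(format_id=FormatId(itag=248), mime_type='video/webm')

_pinned = AudioSelector(display_name='pinned', format_ids=[FormatId(itag=140)])
assert _pinned.match(format_id=FormatId(itag=140), mime_type=None)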
@ -0,0 +1,94 @@
from __future__ import annotations

import dataclasses
import enum

from yt_dlp.extractor.youtube._proto.videostreaming import FormatId

from .models import FormatSelector


@dataclasses.dataclass
class SabrPart:
    pass


@dataclasses.dataclass
class MediaSegmentInitSabrPart(SabrPart):
    format_selector: FormatSelector
    format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None
    start_time_ms: int | None = None
    player_time_ms: int | None = None
    duration_ms: int | None = None
    duration_estimated: bool = False
    start_bytes: int | None = None
    content_length: int | None = None
    content_length_estimated: bool = False


@dataclasses.dataclass
class MediaSegmentDataSabrPart(SabrPart):
    format_selector: FormatSelector
    format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None
    data: bytes = b''
    content_length: int | None = None
    segment_start_bytes: int | None = None


@dataclasses.dataclass
class MediaSegmentEndSabrPart(SabrPart):
    format_selector: FormatSelector
    format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None


@dataclasses.dataclass
class FormatInitializedSabrPart(SabrPart):
    format_id: FormatId
    format_selector: FormatSelector


@dataclasses.dataclass
class PoTokenStatusSabrPart(SabrPart):
    class PoTokenStatus(enum.Enum):
        OK = enum.auto()  # PO Token is provided and valid
        MISSING = enum.auto()  # PO Token is not provided, and is required. A PO Token should be provided ASAP
        INVALID = enum.auto()  # PO Token is provided, but is invalid. A new one should be generated ASAP
        PENDING = enum.auto()  # PO Token is provided, but is probably only a cold-start token. A full PO Token should be provided ASAP
        NOT_REQUIRED = enum.auto()  # PO Token is not provided, and is not required
        PENDING_MISSING = enum.auto()  # PO Token is not provided, but is pending. A full PO Token should (probably) be provided ASAP

    status: PoTokenStatus


@dataclasses.dataclass
class RefreshPlayerResponseSabrPart(SabrPart):

    class Reason(enum.Enum):
        UNKNOWN = enum.auto()
        SABR_URL_EXPIRY = enum.auto()
        SABR_RELOAD_PLAYER_RESPONSE = enum.auto()

    reason: Reason
    reload_playback_token: str | None = None


@dataclasses.dataclass
class MediaSeekSabrPart(SabrPart):
    # Lets the consumer know the media sequence for a format may change
    class Reason(enum.Enum):
        UNKNOWN = enum.auto()
        SERVER_SEEK = enum.auto()  # SABR_SEEK from server
        CONSUMED_SEEK = enum.auto()  # Seeking as the next fragment is already buffered

    reason: Reason
    format_id: FormatId
    format_selector: FormatSelector
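# Illustrative sketch (not from the PR): a consumer loop dispatching on part
# type, roughly what the SABR downloader in this PR does with these parts:
def _handle(part: SabrPart) -> None:
    if isinstance(part, MediaSegmentInitSabrPart):
        pass  # open or resume a sequence file for the incoming segment
    elif isinstance(part, MediaSegmentDataSabrPart):
        pass  # append part.data at part.segment_start_bytes
    elif isinstance(part, MediaSegmentEndSabrPart):
        pass  # finalize the segment and persist resume state
    elif isinstance(part, RefreshPlayerResponseSabrPart):
        pass  # re-fetch the player response (e.g. expired SABR URL)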
@ -0,0 +1,668 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
|
||||
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
|
||||
from yt_dlp.extractor.youtube._proto.videostreaming import (
|
||||
BufferedRange,
|
||||
ClientAbrState,
|
||||
FormatInitializationMetadata,
|
||||
LiveMetadata,
|
||||
MediaHeader,
|
||||
SabrContext,
|
||||
SabrContextSendingPolicy,
|
||||
SabrContextUpdate,
|
||||
SabrSeek,
|
||||
StreamerContext,
|
||||
StreamProtectionStatus,
|
||||
TimeRange,
|
||||
VideoPlaybackAbrRequest,
|
||||
)
|
||||
|
||||
from .exceptions import MediaSegmentMismatchError, SabrStreamError
|
||||
from .models import (
|
||||
AudioSelector,
|
||||
CaptionSelector,
|
||||
ConsumedRange,
|
||||
InitializedFormat,
|
||||
SabrLogger,
|
||||
Segment,
|
||||
VideoSelector,
|
||||
)
|
||||
from .part import (
|
||||
FormatInitializedSabrPart,
|
||||
MediaSeekSabrPart,
|
||||
MediaSegmentDataSabrPart,
|
||||
MediaSegmentEndSabrPart,
|
||||
MediaSegmentInitSabrPart,
|
||||
PoTokenStatusSabrPart,
|
||||
)
|
||||
from .utils import ticks_to_ms
|
||||
|
||||
|
||||
class ProcessMediaEndResult:
|
||||
def __init__(self, sabr_part: MediaSegmentEndSabrPart = None, is_new_segment: bool = False):
|
||||
self.is_new_segment = is_new_segment # TODO: better name
|
||||
self.sabr_part = sabr_part
|
||||
|
||||
|
||||
class ProcessMediaResult:
|
||||
def __init__(self, sabr_part: MediaSegmentDataSabrPart = None):
|
||||
self.sabr_part = sabr_part
|
||||
|
||||
|
||||
class ProcessMediaHeaderResult:
|
||||
def __init__(self, sabr_part: MediaSegmentInitSabrPart | None = None):
|
||||
self.sabr_part = sabr_part
|
||||
|
||||
|
||||
class ProcessLiveMetadataResult:
|
||||
def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None):
|
||||
self.seek_sabr_parts = seek_sabr_parts or []
|
||||
|
||||
|
||||
class ProcessStreamProtectionStatusResult:
|
||||
def __init__(self, sabr_part: PoTokenStatusSabrPart | None = None):
|
||||
self.sabr_part = sabr_part
|
||||
|
||||
|
||||
class ProcessFormatInitializationMetadataResult:
|
||||
def __init__(self, sabr_part: FormatInitializedSabrPart | None = None):
|
||||
self.sabr_part = sabr_part
|
||||
|
||||
|
||||
class ProcessSabrSeekResult:
|
||||
def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None):
|
||||
self.seek_sabr_parts = seek_sabr_parts or []
|
||||
|
||||
|
||||
class SabrProcessor:
|
||||
"""
|
||||
SABR Processor
|
||||
|
||||
This handles core SABR protocol logic, independent of requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: SabrLogger,
|
||||
video_playback_ustreamer_config: str,
|
||||
client_info: ClientInfo,
|
||||
audio_selection: AudioSelector | None = None,
|
||||
video_selection: VideoSelector | None = None,
|
||||
caption_selection: CaptionSelector | None = None,
|
||||
live_segment_target_duration_sec: int | None = None,
|
||||
live_segment_target_duration_tolerance_ms: int | None = None,
|
||||
start_time_ms: int | None = None,
|
||||
po_token: str | None = None,
|
||||
post_live: bool = False,
|
||||
video_id: str | None = None,
|
||||
):
|
||||
|
||||
self.logger = logger
|
||||
|
||||
self.video_playback_ustreamer_config = video_playback_ustreamer_config
|
||||
self.po_token = po_token
|
||||
self.client_info = client_info
|
||||
self.live_segment_target_duration_sec = live_segment_target_duration_sec or 5
|
||||
self.live_segment_target_duration_tolerance_ms = live_segment_target_duration_tolerance_ms or 100
|
||||
if self.live_segment_target_duration_tolerance_ms >= (self.live_segment_target_duration_sec * 1000) / 2:
|
||||
raise ValueError(
|
||||
'live_segment_target_duration_tolerance_ms must be less than '
|
||||
'half of live_segment_target_duration_sec in milliseconds',
|
||||
)
|
||||
self.start_time_ms = start_time_ms or 0
|
||||
if self.start_time_ms < 0:
|
||||
raise ValueError('start_time_ms must be greater than or equal to 0')
|
||||
|
||||
self.post_live = post_live
|
||||
self._is_live = False
|
||||
self.video_id = video_id
|
||||
|
||||
self._audio_format_selector = audio_selection
|
||||
self._video_format_selector = video_selection
|
||||
self._caption_format_selector = caption_selection
|
||||
|
||||
# IMPORTANT: initialized formats is assumed to contain only ACTIVE formats
|
||||
self.initialized_formats: dict[str, InitializedFormat] = {}
|
||||
self.stream_protection_status: StreamProtectionStatus.Status | None = None
|
||||
|
||||
self.partial_segments: dict[int, Segment] = {}
|
||||
self.total_duration_ms = None
|
||||
self.selected_audio_format_ids = []
|
||||
self.selected_video_format_ids = []
|
||||
self.selected_caption_format_ids = []
|
||||
self.next_request_policy: NextRequestPolicy | None = None
|
||||
self.live_metadata: LiveMetadata | None = None
|
||||
self.client_abr_state: ClientAbrState
|
||||
self.sabr_contexts_to_send: set[int] = set()
|
||||
self.sabr_context_updates: dict[int, SabrContextUpdate] = {}
|
||||
self._initialize_cabr_state()
|
||||
|
||||
@property
|
||||
def is_live(self):
|
||||
return bool(
|
||||
self.live_metadata
|
||||
or self._is_live,
|
||||
)
|
||||
|
||||
@is_live.setter
|
||||
def is_live(self, value: bool):
|
||||
self._is_live = value
|
||||
|
||||
def _initialize_cabr_state(self):
|
||||
# SABR supports: audio+video, audio+video+captions or audio-only.
|
||||
# For the other cases, we'll mark the tracks to be discarded (and fully buffered on initialization)
|
||||
|
||||
if not self._video_format_selector:
|
||||
self._video_format_selector = VideoSelector(display_name='video_ignore', discard_media=True)
|
||||
|
||||
if not self._audio_format_selector:
|
||||
self._audio_format_selector = AudioSelector(display_name='audio_ignore', discard_media=True)
|
||||
|
||||
if not self._caption_format_selector:
|
||||
self._caption_format_selector = CaptionSelector(display_name='caption_ignore', discard_media=True)
|
||||
|
||||
enabled_track_types_bitfield = 0 # Audio+Video
|
||||
|
||||
if self._video_format_selector.discard_media:
|
||||
enabled_track_types_bitfield = 1 # Audio only
|
||||
|
||||
if not self._caption_format_selector.discard_media:
|
||||
# SABR does not support caption-only or audio+captions only - can only get audio+video with captions
|
||||
# If audio or video is not selected, the tracks will be initialized but marked as buffered.
|
||||
enabled_track_types_bitfield = 7
|
||||
|
||||
self.selected_audio_format_ids = self._audio_format_selector.format_ids
|
||||
self.selected_video_format_ids = self._video_format_selector.format_ids
|
||||
self.selected_caption_format_ids = self._caption_format_selector.format_ids
|
||||
|
||||
self.logger.debug(f'Starting playback at: {self.start_time_ms}ms')
|
||||
self.client_abr_state = ClientAbrState(
|
||||
player_time_ms=self.start_time_ms,
|
||||
enabled_track_types_bitfield=enabled_track_types_bitfield,
|
||||
drc_enabled=True, # Required to stream DRC formats
|
||||
)
|
||||
|
||||
def match_format_selector(self, format_init_metadata):
|
||||
for format_selector in (self._video_format_selector, self._audio_format_selector, self._caption_format_selector):
|
||||
if not format_selector:
|
||||
continue
|
||||
if format_selector.match(format_id=format_init_metadata.format_id, mime_type=format_init_metadata.mime_type):
|
||||
return format_selector
|
||||
return None
|
||||
|
||||
def process_media_header(self, media_header: MediaHeader) -> ProcessMediaHeaderResult:
|
||||
if media_header.video_id and self.video_id and media_header.video_id != self.video_id:
|
||||
raise SabrStreamError(
|
||||
f'Received unexpected MediaHeader for video'
|
||||
f' {media_header.video_id} (expecting {self.video_id})')
|
||||
|
||||
if not media_header.format_id:
|
||||
raise SabrStreamError(f'FormatId not found in MediaHeader (media_header={media_header})')
|
||||
|
||||
# Guard. This should not happen, except if we don't clear partial segments
|
||||
if media_header.header_id in self.partial_segments:
|
||||
raise SabrStreamError(f'Header ID {media_header.header_id} already exists')
|
||||
|
||||
result = ProcessMediaHeaderResult()
|
||||
|
||||
initialized_format = self.initialized_formats.get(str(media_header.format_id))
|
||||
if not initialized_format:
|
||||
self.logger.debug(f'Initialized format not found for {media_header.format_id}')
|
||||
raise SabrStreamError(f'Initialized format not found for {media_header.format_id}')
|
||||
|
||||
if media_header.compression:
|
||||
# Unknown when this is used, but it is not supported currently
|
||||
raise SabrStreamError(f'Compression not supported in MediaHeader (media_header={media_header})')
|
||||
|
||||
sequence_number, is_init_segment = media_header.sequence_number, media_header.is_init_segment
|
||||
if sequence_number is None and not media_header.is_init_segment:
|
||||
raise SabrStreamError(f'Sequence number not found in MediaHeader (media_header={media_header})')
|
||||
|
||||
initialized_format.sequence_lmt = media_header.sequence_lmt
|
||||
|
||||
# Need to keep track of if we discard due to be consumed or not
|
||||
# for processing down the line (MediaEnd)
|
||||
consumed = False
|
||||
discard = initialized_format.discard
|
||||
|
||||
# Guard: Check if sequence number is within any existing consumed range
|
||||
# The server should not send us any segments that are already consumed
|
||||
# However, if retrying a request, we may get the same segment again
|
||||
if not is_init_segment and any(
|
||||
cr.start_sequence_number <= sequence_number <= cr.end_sequence_number
|
||||
for cr in initialized_format.consumed_ranges
|
||||
):
|
||||
self.logger.debug(f'{initialized_format.format_id} segment {sequence_number} already consumed, marking segment as consumed')
|
||||
consumed = True
|
||||
|
||||
# Validate that the segment is in order.
|
||||
# Note: If the format is to be discarded, we do not care about the order
|
||||
# and can expect uncommanded seeks as the consumer does not know about it.
|
||||
# Note: previous segment should never be an init segment.
|
||||
previous_segment = initialized_format.current_segment
|
||||
if (
|
||||
previous_segment and not is_init_segment
|
||||
and not previous_segment.discard and not discard and not consumed
|
||||
and sequence_number != previous_segment.sequence_number + 1
|
||||
):
|
||||
# Bail out as the segment is not in order when it is expected to be
|
||||
raise MediaSegmentMismatchError(
|
||||
expected_sequence_number=previous_segment.sequence_number + 1,
|
||||
received_sequence_number=sequence_number,
|
||||
format_id=media_header.format_id)
|
||||
|
||||
if initialized_format.init_segment and is_init_segment:
|
||||
self.logger.debug(
|
||||
f'Init segment {sequence_number} already seen for format {initialized_format.format_id}, marking segment as consumed')
|
||||
consumed = True
|
||||
|
||||
time_range = media_header.time_range
|
||||
start_ms = media_header.start_ms or (time_range and ticks_to_ms(time_range.start_ticks, time_range.timescale)) or 0
|
||||
|
||||
# Calculate duration of this segment
|
||||
# For videos, either duration_ms or time_range should be present
|
||||
# For live streams, calculate segment duration based on live metadata target segment duration
|
||||
actual_duration_ms = (
|
||||
media_header.duration_ms
|
||||
or (time_range and ticks_to_ms(time_range.duration_ticks, time_range.timescale)))
|
||||
|
||||
estimated_duration_ms = None
|
||||
if self.is_live:
|
||||
# Underestimate the duration of the segment slightly as
|
||||
# the real duration may be slightly shorter than the target duration.
|
||||
estimated_duration_ms = (self.live_segment_target_duration_sec * 1000) - self.live_segment_target_duration_tolerance_ms
|
||||
elif is_init_segment:
|
||||
estimated_duration_ms = 0
|
||||
|
||||
duration_ms = actual_duration_ms or estimated_duration_ms
|
||||
|
||||
# Guard: Bail out if we cannot determine the duration, which we need to progress.
|
||||
if duration_ms is None:
|
||||
raise SabrStreamError(f'Cannot determine duration of segment {sequence_number} (media_header={media_header})')
|
||||
|
||||
estimated_content_length = None
|
||||
if self.is_live and media_header.content_length is None and media_header.bitrate_bps is not None:
|
||||
estimated_content_length = math.ceil(media_header.bitrate_bps * (duration_ms / 1000))
|
||||
|
||||
segment = Segment(
|
||||
format_id=media_header.format_id,
|
||||
is_init_segment=is_init_segment,
|
||||
duration_ms=duration_ms,
|
||||
start_data_range=media_header.start_data_range,
|
||||
sequence_number=sequence_number,
|
||||
content_length=media_header.content_length or estimated_content_length,
|
||||
content_length_estimated=estimated_content_length is not None,
|
||||
start_ms=start_ms,
|
||||
initialized_format=initialized_format,
|
||||
duration_estimated=not actual_duration_ms,
|
||||
discard=discard or consumed,
|
||||
consumed=consumed,
|
||||
sequence_lmt=media_header.sequence_lmt,
|
||||
)
|
||||
|
||||
self.partial_segments[media_header.header_id] = segment
|
||||
|
||||
if not segment.discard:
|
||||
result.sabr_part = MediaSegmentInitSabrPart(
|
||||
format_selector=segment.initialized_format.format_selector,
|
||||
format_id=segment.format_id,
|
||||
player_time_ms=self.client_abr_state.player_time_ms,
|
||||
sequence_number=segment.sequence_number,
|
||||
total_segments=segment.initialized_format.total_segments,
|
||||
duration_ms=segment.duration_ms,
|
||||
duration_estimated=segment.duration_estimated,
|
||||
start_bytes=segment.start_data_range,
|
||||
start_time_ms=segment.start_ms,
|
||||
is_init_segment=segment.is_init_segment,
|
||||
content_length=segment.content_length,
|
||||
content_length_estimated=segment.content_length_estimated,
|
||||
)
|
||||
|
||||
self.logger.trace(
|
||||
f'Initialized Media Header {media_header.header_id} for sequence {sequence_number}. Segment: {segment}')
|
||||
|
||||
return result
|
||||
|
||||
def process_media(self, header_id: int, content_length: int, data: io.BufferedIOBase) -> ProcessMediaResult:
|
||||
result = ProcessMediaResult()
|
||||
segment = self.partial_segments.get(header_id)
|
||||
if not segment:
|
||||
self.logger.debug(f'Header ID {header_id} not found')
|
||||
raise SabrStreamError(f'Header ID {header_id} not found in partial segments')
|
||||
|
||||
segment_start_bytes = segment.received_data_length
|
||||
segment.received_data_length += content_length
|
||||
|
||||
if not segment.discard:
|
||||
result.sabr_part = MediaSegmentDataSabrPart(
|
||||
format_selector=segment.initialized_format.format_selector,
|
||||
format_id=segment.format_id,
|
||||
sequence_number=segment.sequence_number,
|
||||
is_init_segment=segment.is_init_segment,
|
||||
total_segments=segment.initialized_format.total_segments,
|
||||
data=data.read(),
|
||||
content_length=content_length,
|
||||
segment_start_bytes=segment_start_bytes,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def process_media_end(self, header_id: int) -> ProcessMediaEndResult:
|
||||
result = ProcessMediaEndResult()
|
||||
segment = self.partial_segments.pop(header_id, None)
|
||||
if not segment:
|
||||
self.logger.debug(f'Header ID {header_id} not found')
|
||||
raise SabrStreamError(f'Header ID {header_id} not found in partial segments')
|
||||
|
||||
self.logger.trace(
|
||||
f'MediaEnd for {segment.format_id} (sequence {segment.sequence_number}, data length = {segment.received_data_length})')
|
||||
|
||||
if segment.content_length is not None and segment.received_data_length != segment.content_length:
|
||||
if segment.content_length_estimated:
|
||||
self.logger.trace(
|
||||
f'Content length for {segment.format_id} (sequence {segment.sequence_number}) was estimated, '
|
||||
f'estimated {segment.content_length} bytes, got {segment.received_data_length} bytes')
|
||||
else:
|
||||
raise SabrStreamError(
|
||||
f'Content length mismatch for {segment.format_id} (sequence {segment.sequence_number}): '
|
||||
f'expected {segment.content_length} bytes, got {segment.received_data_length} bytes',
|
||||
)
|
||||
|
||||
# Only count received segments as new segments if they are not consumed.
|
||||
# Discarded segments that are not consumed are considered new segments.
|
||||
if not segment.consumed:
|
||||
result.is_new_segment = True
|
||||
|
||||
# Return the segment here instead of during MEDIA part(s) because:
|
||||
# 1. We can validate that we received the correct data length
|
||||
# 2. In the case of a retry during segment media, the partial data is not sent to the consumer
|
||||
if not segment.discard:
|
||||
# This needs to be yielded AFTER we have processed the segment
|
||||
# So the consumer can see the updated consumed ranges and use them for e.g. syncing between concurrent streams
|
||||
result.sabr_part = MediaSegmentEndSabrPart(
|
||||
format_selector=segment.initialized_format.format_selector,
|
||||
format_id=segment.format_id,
|
||||
sequence_number=segment.sequence_number,
|
||||
is_init_segment=segment.is_init_segment,
|
||||
total_segments=segment.initialized_format.total_segments,
|
||||
)
|
||||
else:
|
||||
self.logger.trace(f'Discarding media for {segment.initialized_format.format_id}')
|
||||
|
||||
if segment.is_init_segment:
|
||||
segment.initialized_format.init_segment = segment
|
||||
# Do not create a consumed range for init segments
|
||||
return result
|
||||
|
||||
if segment.initialized_format.current_segment and self.is_live:
|
||||
previous_segment = segment.initialized_format.current_segment
|
||||
self.logger.trace(
|
||||
f'Previous segment {previous_segment.sequence_number} for format {segment.format_id} '
|
||||
f'estimated duration difference from this segment ({segment.sequence_number}): {segment.start_ms - (previous_segment.start_ms + previous_segment.duration_ms)}ms')
|
||||
|
||||
segment.initialized_format.current_segment = segment
|
||||
|
||||
if segment.consumed:
|
||||
# Segment is already consumed, do not create a new consumed range. It was probably discarded.
|
||||
# This can be expected to happen in the case of video-only, where we discard the audio track (and mark it as entirely buffered)
|
||||
# We still want to create/update consumed range for discarded media IF it is not already consumed
|
||||
self.logger.debug(f'{segment.format_id} segment {segment.sequence_number} already consumed, not creating or updating consumed range (discard={segment.discard})')
|
||||
return result
|
||||
|
||||
# Try to find a consumed range for this segment in sequence
|
||||
consumed_range = next(
|
||||
(cr for cr in segment.initialized_format.consumed_ranges if cr.end_sequence_number == segment.sequence_number - 1),
|
||||
None,
|
||||
)
|
||||
|
||||
if not consumed_range:
|
||||
# Create a new consumed range starting from this segment
|
||||
segment.initialized_format.consumed_ranges.append(ConsumedRange(
|
||||
start_time_ms=segment.start_ms,
|
||||
duration_ms=segment.duration_ms,
|
||||
start_sequence_number=segment.sequence_number,
|
||||
end_sequence_number=segment.sequence_number,
|
||||
))
|
||||
self.logger.debug(f'Created new consumed range for {segment.initialized_format.format_id} {segment.initialized_format.consumed_ranges[-1]}')
|
||||
return result
|
||||
|
||||
# Update the existing consumed range to include this segment
|
||||
consumed_range.end_sequence_number = segment.sequence_number
|
||||
consumed_range.duration_ms = (segment.start_ms - consumed_range.start_time_ms) + segment.duration_ms
|
||||
|
||||
# TODO: Conduct a seek on consumed ranges
|
||||
|
||||
return result
|
||||
|
||||
def process_live_metadata(self, live_metadata: LiveMetadata) -> ProcessLiveMetadataResult:
|
||||
self.live_metadata = live_metadata
|
||||
if self.live_metadata.head_sequence_time_ms:
|
||||
self.total_duration_ms = self.live_metadata.head_sequence_time_ms
|
||||
|
||||
# If we have a head sequence number, we need to update the total sequences for each initialized format
|
||||
# For livestreams, it is not available in the format initialization metadata
|
||||
if self.live_metadata.head_sequence_number:
|
||||
for izf in self.initialized_formats.values():
|
||||
izf.total_segments = self.live_metadata.head_sequence_number
|
||||
|
||||
result = ProcessLiveMetadataResult()
|
||||
|
||||
# If the current player time is less than the min dvr time, simulate a server seek to the min dvr time.
|
||||
# The server SHOULD send us a SABR_SEEK part in this case, but it does not always happen (e.g. ANDROID_VR)
|
||||
# The server SHOULD NOT send us segments before the min dvr time, so we should assume that the player time is correct.
|
||||
min_seekable_time_ms = ticks_to_ms(self.live_metadata.min_seekable_time_ticks, self.live_metadata.min_seekable_timescale)
|
||||
if min_seekable_time_ms is not None and self.client_abr_state.player_time_ms < min_seekable_time_ms:
|
||||
self.logger.debug(f'Player time {self.client_abr_state.player_time_ms} is less than min seekable time {min_seekable_time_ms}, simulating server seek')
|
||||
self.client_abr_state.player_time_ms = min_seekable_time_ms
|
||||
|
||||
for izf in self.initialized_formats.values():
|
||||
izf.current_segment = None # Clear the current segment as we expect segments to no longer be in order.
|
||||
result.seek_sabr_parts.append(MediaSeekSabrPart(
|
||||
reason=MediaSeekSabrPart.Reason.SERVER_SEEK,
|
||||
format_id=izf.format_id,
|
||||
format_selector=izf.format_selector,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def process_stream_protection_status(self, stream_protection_status: StreamProtectionStatus) -> ProcessStreamProtectionStatusResult:
|
||||
self.stream_protection_status = stream_protection_status.status
|
||||
status = stream_protection_status.status
|
||||
po_token = self.po_token
|
||||
|
||||
if status == StreamProtectionStatus.Status.OK:
|
||||
result_status = (
|
||||
PoTokenStatusSabrPart.PoTokenStatus.OK if po_token
|
||||
else PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED
|
||||
)
|
||||
elif status == StreamProtectionStatus.Status.ATTESTATION_PENDING:
|
||||
result_status = (
|
||||
PoTokenStatusSabrPart.PoTokenStatus.PENDING if po_token
|
||||
else PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING
|
||||
)
|
||||
elif status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
|
||||
result_status = (
|
||||
PoTokenStatusSabrPart.PoTokenStatus.INVALID if po_token
|
||||
else PoTokenStatusSabrPart.PoTokenStatus.MISSING
|
||||
)
|
||||
else:
|
||||
self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
|
||||
result_status = None
|
||||
|
||||
        sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
        return ProcessStreamProtectionStatusResult(sabr_part)

    def process_format_initialization_metadata(self, format_init_metadata: FormatInitializationMetadata) -> ProcessFormatInitializationMetadataResult:
        result = ProcessFormatInitializationMetadataResult()
        if str(format_init_metadata.format_id) in self.initialized_formats:
            self.logger.trace(f'Format {format_init_metadata.format_id} already initialized')
            return result

        if format_init_metadata.video_id and self.video_id and format_init_metadata.video_id != self.video_id:
            raise SabrStreamError(
                f'Received unexpected Format Initialization Metadata for video'
                f' {format_init_metadata.video_id} (expecting {self.video_id})')

        format_selector = self.match_format_selector(format_init_metadata)
        if not format_selector:
            # Should not happen. If we ignored the format the server may refuse to send us any more data
            raise SabrStreamError(f'Received format {format_init_metadata.format_id} but it does not match any format selector')

        # Guard: Check if the format selector is already in use by another initialized format.
        # This can happen when the server changes the format to use (e.g. changing quality).
        #
        # Changing a format will require adding some logic to handle inactive formats.
        # Given we only provide one FormatId currently, and this should not occur in this case,
        # we will mark this as not currently supported and bail.
        for izf in self.initialized_formats.values():
            if izf.format_selector is format_selector:
                raise SabrStreamError('Server changed format. Changing formats is not currently supported')

        duration_ms = ticks_to_ms(format_init_metadata.duration_ticks, format_init_metadata.duration_timescale)

        total_segments = format_init_metadata.total_segments
        if not total_segments and self.live_metadata and self.live_metadata.head_sequence_number:
            total_segments = self.live_metadata.head_sequence_number

        initialized_format = InitializedFormat(
            format_id=format_init_metadata.format_id,
            duration_ms=duration_ms,
            end_time_ms=format_init_metadata.end_time_ms,
            mime_type=format_init_metadata.mime_type,
            video_id=format_init_metadata.video_id,
            format_selector=format_selector,
            total_segments=total_segments,
            discard=format_selector.discard_media,
        )
        self.total_duration_ms = max(self.total_duration_ms or 0, format_init_metadata.end_time_ms or 0, duration_ms or 0)

        if initialized_format.discard:
            # Mark the entire format as buffered into oblivion if we plan to discard all media.
            # This stops the server sending us any more data for this format.
            # Note: Using JS_MAX_SAFE_INTEGER but could use any maximum value as long as the server accepts it.
            initialized_format.consumed_ranges = [ConsumedRange(
                start_time_ms=0,
                duration_ms=(2**53) - 1,
                start_sequence_number=0,
                end_sequence_number=(2**53) - 1,
            )]

        self.initialized_formats[str(format_init_metadata.format_id)] = initialized_format
        self.logger.debug(f'Initialized Format: {initialized_format}')

        if not initialized_format.discard:
            result.sabr_part = FormatInitializedSabrPart(
                format_id=format_init_metadata.format_id,
                format_selector=format_selector,
            )

        return result

    def process_next_request_policy(self, next_request_policy: NextRequestPolicy):
        self.next_request_policy = next_request_policy
        self.logger.trace(f'Registered new NextRequestPolicy: {self.next_request_policy}')

    def process_sabr_seek(self, sabr_seek: SabrSeek) -> ProcessSabrSeekResult:
        seek_to = ticks_to_ms(sabr_seek.seek_time_ticks, sabr_seek.timescale)
        if seek_to is None:
            raise SabrStreamError(f'Server sent a SabrSeek part that is missing required seek data: {sabr_seek}')
        self.logger.debug(f'Seeking to {seek_to}ms')
        self.client_abr_state.player_time_ms = seek_to

        result = ProcessSabrSeekResult()

        # Clear the latest segment of each initialized format,
        # as we expect them to no longer be in order.
        for initialized_format in self.initialized_formats.values():
            initialized_format.current_segment = None
            result.seek_sabr_parts.append(MediaSeekSabrPart(
                reason=MediaSeekSabrPart.Reason.SERVER_SEEK,
                format_id=initialized_format.format_id,
                format_selector=initialized_format.format_selector,
            ))
        return result

    def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate):
        if not (sabr_ctx_update.type and sabr_ctx_update.value and sabr_ctx_update.write_policy):
            self.logger.warning('Received an invalid SabrContextUpdate, ignoring')
            return

        if (
            sabr_ctx_update.write_policy == SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING
            and sabr_ctx_update.type in self.sabr_context_updates
        ):
            self.logger.debug(
                'Received a SABR Context Update with write_policy=KEEP_EXISTING '
                'matching an existing SABR Context Update. Ignoring update')
            return

        self.logger.warning(
            'Received a SABR Context Update. YouTube is likely trying to force ads on the client. '
            'This may cause issues with playback.')

        self.sabr_context_updates[sabr_ctx_update.type] = sabr_ctx_update
        if sabr_ctx_update.send_by_default:
            self.sabr_contexts_to_send.add(sabr_ctx_update.type)
        self.logger.debug(f'Registered SabrContextUpdate {sabr_ctx_update}')

    def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy):
        for start_type in sabr_ctx_sending_policy.start_policy:
            if start_type not in self.sabr_contexts_to_send:
                self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}')
                self.sabr_contexts_to_send.add(start_type)

        for stop_type in sabr_ctx_sending_policy.stop_policy:
            if stop_type in self.sabr_contexts_to_send:
                self.logger.debug(f'Server requested to disable SABR Context Update for type {stop_type}')
                self.sabr_contexts_to_send.remove(stop_type)

        for discard_type in sabr_ctx_sending_policy.discard_policy:
            if discard_type in self.sabr_context_updates:
                self.logger.debug(f'Server requested to discard SABR Context Update for type {discard_type}')
                self.sabr_context_updates.pop(discard_type, None)


def build_vpabr_request(processor: SabrProcessor):
    return VideoPlaybackAbrRequest(
        client_abr_state=processor.client_abr_state,
        selected_video_format_ids=processor.selected_video_format_ids,
        selected_audio_format_ids=processor.selected_audio_format_ids,
        selected_caption_format_ids=processor.selected_caption_format_ids,
        initialized_format_ids=[
            initialized_format.format_id for initialized_format in processor.initialized_formats.values()
        ],
        video_playback_ustreamer_config=base64.urlsafe_b64decode(processor.video_playback_ustreamer_config),
        streamer_context=StreamerContext(
            po_token=base64.urlsafe_b64decode(processor.po_token) if processor.po_token is not None else None,
            playback_cookie=processor.next_request_policy.playback_cookie if processor.next_request_policy is not None else None,
            client_info=processor.client_info,
            sabr_contexts=[
                SabrContext(context.type, context.value)
                for context in processor.sabr_context_updates.values()
                if context.type in processor.sabr_contexts_to_send
            ],
            unsent_sabr_contexts=[
                context_type for context_type in processor.sabr_contexts_to_send
                if context_type not in processor.sabr_context_updates
            ],
        ),
        buffered_ranges=[
            BufferedRange(
                format_id=initialized_format.format_id,
                start_segment_index=cr.start_sequence_number,
                end_segment_index=cr.end_sequence_number,
                start_time_ms=cr.start_time_ms,
                duration_ms=cr.duration_ms,
                time_range=TimeRange(
                    start_ticks=cr.start_time_ms,
                    duration_ticks=cr.duration_ms,
                    timescale=1000,
                ),
            ) for initialized_format in processor.initialized_formats.values()
            for cr in initialized_format.consumed_ranges
        ],
    )
@ -0,0 +1,813 @@
from __future__ import annotations

import base64
import dataclasses
import datetime as dt
import math
import time
import typing
import urllib.parse

from yt_dlp.dependencies import protobug
from yt_dlp.extractor.youtube._proto import unknown_fields
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
from yt_dlp.extractor.youtube._proto.videostreaming import (
    FormatInitializationMetadata,
    LiveMetadata,
    MediaHeader,
    ReloadPlayerResponse,
    SabrContextSendingPolicy,
    SabrContextUpdate,
    SabrError,
    SabrRedirect,
    SabrSeek,
    StreamProtectionStatus,
)
from yt_dlp.networking import Request, Response
from yt_dlp.networking.exceptions import HTTPError, TransportError
from yt_dlp.utils import RetryManager, int_or_none, parse_qs, str_or_none, traverse_obj

from .exceptions import MediaSegmentMismatchError, PoTokenError, SabrStreamConsumedError, SabrStreamError
from .models import AudioSelector, CaptionSelector, SabrLogger, VideoSelector
from .part import (
    MediaSeekSabrPart,
    RefreshPlayerResponseSabrPart,
)
from .processor import SabrProcessor, build_vpabr_request
from .utils import broadcast_id_from_url, get_cr_chain, next_gvs_fallback_url
from ..ump import UMPDecoder, UMPPart, UMPPartId, read_varint


class SabrStream:
    """
    A YouTube SABR (Server Adaptive Bit Rate) client implementation designed for downloading streams and videos.

    It presents an iterator (iter_parts) that yields the next available segments and other metadata.

    Parameters:
    @param urlopen: A callable that takes a Request and returns a Response. Raises TransportError or HTTPError on failure.
    @param logger: The logger.
    @param server_abr_streaming_url: SABR streaming URL.
    @param video_playback_ustreamer_config: The base64url encoded ustreamer config.
    @param client_info: The Innertube client info.
    @param audio_selection: The audio format selector to use for audio.
    @param video_selection: The video format selector to use for video.
    @param caption_selection: The caption format selector to use for captions.
    @param live_segment_target_duration_sec: The target duration of live segments in seconds.
    @param live_segment_target_duration_tolerance_ms: The tolerance to accept for the estimated duration of a live segment, in milliseconds.
    @param start_time_ms: The time in milliseconds to start playback from.
    @param po_token: Initial GVS PO Token.
    @param http_retries: The maximum number of times to retry a request before failing.
    @param pot_retries: The maximum number of times to retry PO Token errors before failing.
    @param host_fallback_threshold: The number of consecutive retries before falling back to the next GVS server.
    @param max_empty_requests: The maximum number of consecutive requests with no new segments before giving up.
    @param live_end_wait_sec: The number of seconds to wait after the last received segment before considering the live stream ended.
    @param live_end_segment_tolerance: The number of segments before the live head segment at which the livestream is allowed to end. Defaults to 10.
    @param post_live: Whether the live stream is in post-live mode. Used to determine how to handle the end of the stream.
    @param video_id: The video ID of the YouTube video. Used for validating that received data is for the correct video.
    @param retry_sleep_func: A function to sleep between retries. If None, will not sleep between retries.
    @param expiry_threshold_sec: The number of seconds before the GVS expiry to consider it expired. Defaults to 1 minute.
    """

    # Used for debugging
    _IGNORED_PARTS = (
        UMPPartId.REQUEST_IDENTIFIER,
        UMPPartId.REQUEST_CANCELLATION_POLICY,
        UMPPartId.PLAYBACK_START_POLICY,
        UMPPartId.ALLOWED_CACHED_FORMATS,
        UMPPartId.PAUSE_BW_SAMPLING_HINT,
        UMPPartId.START_BW_SAMPLING_HINT,
        UMPPartId.REQUEST_PIPELINING,
        UMPPartId.SELECTABLE_FORMATS,
        UMPPartId.PREWARM_CONNECTION,
    )

    @dataclasses.dataclass
    class _NoSegmentsTracker:
        consecutive_requests: int = 0
        timestamp_started: float | None = None
        live_head_segment_started: int | None = None

        def reset(self):
            self.consecutive_requests = 0
            self.timestamp_started = None
            self.live_head_segment_started = None

        def increment(self, live_head_segment=None):
            if self.consecutive_requests == 0:
                self.timestamp_started = time.time()
                self.live_head_segment_started = live_head_segment
            self.consecutive_requests += 1

    def __init__(
        self,
        urlopen: typing.Callable[[Request], Response],
        logger: SabrLogger,
        server_abr_streaming_url: str,
        video_playback_ustreamer_config: str,
        client_info: ClientInfo,
        audio_selection: AudioSelector | None = None,
        video_selection: VideoSelector | None = None,
        caption_selection: CaptionSelector | None = None,
        live_segment_target_duration_sec: int | None = None,
        live_segment_target_duration_tolerance_ms: int | None = None,
        start_time_ms: int | None = None,
        po_token: str | None = None,
        http_retries: int | None = None,
        pot_retries: int | None = None,
        host_fallback_threshold: int | None = None,
        max_empty_requests: int | None = None,
        live_end_wait_sec: int | None = None,
        live_end_segment_tolerance: int | None = None,
        post_live: bool = False,
        video_id: str | None = None,
        retry_sleep_func: typing.Callable | None = None,
        expiry_threshold_sec: int | None = None,
    ):
        self.logger = logger
        self._urlopen = urlopen

        self.processor = SabrProcessor(
            logger=logger,
            video_playback_ustreamer_config=video_playback_ustreamer_config,
            client_info=client_info,
            audio_selection=audio_selection,
            video_selection=video_selection,
            caption_selection=caption_selection,
            live_segment_target_duration_sec=live_segment_target_duration_sec,
            live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
            start_time_ms=start_time_ms,
            po_token=po_token,
            post_live=post_live,
            video_id=video_id,
        )
        self.url = server_abr_streaming_url
        self.http_retries = http_retries or 10
        self.pot_retries = pot_retries or 5
        self.host_fallback_threshold = host_fallback_threshold or 8
        self.max_empty_requests = max_empty_requests or 3
        self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec)
        self.live_end_segment_tolerance = live_end_segment_tolerance or 10
        self.expiry_threshold_sec = expiry_threshold_sec or 60
        if self.expiry_threshold_sec <= 0:
            raise ValueError('expiry_threshold_sec must be greater than 0')
        if self.max_empty_requests <= 0:
            raise ValueError('max_empty_requests must be greater than 0')
        self.retry_sleep_func = retry_sleep_func
        self._request_number = 0

        # Whether we got any new (not consumed) segments in the request.
        self._received_new_segments = False
        self._no_new_segments_tracker = SabrStream._NoSegmentsTracker()
        self._sps_retry_manager: typing.Generator | None = None
        self._current_sps_retry = None
        self._http_retry_manager: typing.Generator | None = None
        self._current_http_retry = None
        self._unknown_part_types = set()

        # Whether the current request is a result of a retry
        self._is_retry = False

        self._consumed = False
        self._sq_mismatch_backtrack_count = 0
        self._sq_mismatch_forward_count = 0

    def close(self):
        self._consumed = True

    def __iter__(self):
        return self.iter_parts()

    @property
    def url(self):
        return self._url

    @url.setter
    def url(self, url):
        self.logger.debug(f'New URL: {url}')
        if self.processor.is_live and hasattr(self, '_url') and ((bn := broadcast_id_from_url(url)) != (bc := broadcast_id_from_url(self.url))):
            raise SabrStreamError(f'Broadcast ID changed from {bc} to {bn}. The download will need to be restarted.')
        self._url = url
        if str_or_none(parse_qs(url).get('source', [None])[0]) == 'yt_live_broadcast':
            self.processor.is_live = True

    def iter_parts(self):
        if self._consumed:
            raise SabrStreamConsumedError('SABR stream has already been consumed')

        self._http_retry_manager = None
        self._sps_retry_manager = None

        def report_retry(err, count, retries, fatal=True):
            if count >= self.host_fallback_threshold:
                self._process_fallback_server()
            RetryManager.report_retry(
                err, count, retries, info=self.logger.info,
                warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
                error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
                sleep_func=self.retry_sleep_func,
            )

        def report_sps_retry(err, count, retries, fatal=True):
            RetryManager.report_retry(
                err, count, retries, info=self.logger.info,
                warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
                error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
                sleep_func=self.retry_sleep_func,
            )

        while not self._consumed:
            if self._http_retry_manager is None:
                self._http_retry_manager = iter(RetryManager(self.http_retries, report_retry))

            if self._sps_retry_manager is None:
                self._sps_retry_manager = iter(RetryManager(self.pot_retries, report_sps_retry))

            self._current_http_retry = next(self._http_retry_manager)
            self._current_sps_retry = next(self._sps_retry_manager)

            self._log_state()

            yield from self._process_expiry()
            vpabr = build_vpabr_request(self.processor)
            payload = protobug.dumps(vpabr)
            self.logger.trace(f'Ustreamer Config: {self.processor.video_playback_ustreamer_config}')
            self.logger.trace(f'Sending SABR request: {vpabr}')

            response = None
            try:
                response = self._urlopen(
                    Request(
                        url=self.url,
                        method='POST',
                        data=payload,
                        query={'rn': self._request_number},
                        headers={
                            'content-type': 'application/x-protobuf',
                            'accept-encoding': 'identity',
                            'accept': 'application/vnd.yt-ump',
                        },
                    ),
                )
                self._request_number += 1
            except TransportError as e:
                self._current_http_retry.error = e
            except HTTPError as e:
                # Retry on 5xx errors only
                if 500 <= e.status < 600:
                    self._current_http_retry.error = e
                else:
                    raise SabrStreamError(f'HTTP Error: {e.status} - {e.reason}')

            if response:
                try:
                    yield from self._parse_ump_response(response)
                except TransportError as e:
                    self._current_http_retry.error = e

                if not response.closed:
                    response.close()

            self._validate_response_integrity()
            self._process_sps_retry()

            if not self._current_http_retry.error:
                self._http_retry_manager = None

            if not self._current_sps_retry.error:
                self._sps_retry_manager = None

            retry_next_request = bool(self._current_http_retry.error or self._current_sps_retry.error)

            # We are expecting to stay in the same place for a retry
            if not retry_next_request:
                # Only increment the no-new-segments count if we are not retrying
                self._process_request_had_segments()

                # Calculate and apply the next playback time to skip to
                yield from self._prepare_next_playback_time()

                # Request successfully processed; the next request is not a retry
                self._is_retry = False
            else:
                self._is_retry = True

            self._received_new_segments = False

        self._consumed = True

    def _process_sps_retry(self):
        error = PoTokenError(missing=not self.processor.po_token)

        if self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
            # Always start retrying immediately on ATTESTATION_REQUIRED
            self._current_sps_retry.error = error
            return

        elif (
            self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_PENDING
            and self._no_new_segments_tracker.consecutive_requests >= self.max_empty_requests
            and (not self.processor.is_live or self.processor.stream_protection_status or (
                self.processor.live_metadata is not None
                and self._no_new_segments_tracker.live_head_segment_started != self.processor.live_metadata.head_sequence_number)
            )
        ):
            # Sometimes YouTube sends no data on ATTESTATION_PENDING, so in this case we need to count retries to fail on a PO Token error.
            # We only start counting retries after max_empty_requests in case of intermittent no data that we need to increase the player time on.
            # For livestreams, when we receive no new segments this could be due to the stream ending or ATTESTATION_PENDING.
            # To differentiate the two, we check if the live head segment has changed since we started getting no new segments.
            # xxx: not perfect detection; sometimes we get a new segment we can never fetch (partial)
            self._current_sps_retry.error = error
            return

    def _process_request_had_segments(self):
        if not self._received_new_segments:
            self._no_new_segments_tracker.increment(
                live_head_segment=self.processor.live_metadata.head_sequence_number if self.processor.live_metadata else None)
            self.logger.trace(f'No new segments received in request {self._request_number}, count: {self._no_new_segments_tracker.consecutive_requests}')
        else:
            self._no_new_segments_tracker.reset()

    def _validate_response_integrity(self):
        if not len(self.processor.partial_segments):
            return

        msg = 'Received partial segments: ' + ', '.join(
            f'{seg.format_id}: {seg.sequence_number}'
            for seg in self.processor.partial_segments.values()
        )
        if self.processor.is_live:
            # In post-live, sometimes we get a partial segment at the end of the stream that should be ignored.
            # If this occurs mid-stream, other guards should prevent corruption.
            if (
                self.processor.live_metadata
                # TODO: generalize
                and self.processor.client_abr_state.player_time_ms >= (
                    self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance))
            ):
                # Only log a warning if we are not near the head of the stream
                self.logger.debug(msg)
            else:
                self.logger.warning(msg)
        else:
            # Should not happen for videos
            self._current_http_retry.error = SabrStreamError(msg)

        self.processor.partial_segments.clear()

    def _prepare_next_playback_time(self):
        # TODO: refactor and clean up this massive function
        wait_seconds = 0
        for izf in self.processor.initialized_formats.values():
            if not izf.current_segment:
                continue

            # Guard: Check that the segment is not in multiple consumed ranges.
            # This should not happen, but if it does, we should bail.
            count = sum(
                1 for cr in izf.consumed_ranges
                if cr.start_sequence_number <= izf.current_segment.sequence_number <= cr.end_sequence_number
            )

            if count > 1:
                raise SabrStreamError(f'Segment {izf.current_segment.sequence_number} for format {izf.format_id} in {count} consumed ranges')

            # Check if there are two or more consumed ranges where the end of one lines up with the start of another.
            # This could happen in the case of concurrent playback.
            # In this case, we should consider a seek to the end of the other as acceptable.
            # Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (by process media header)
            prev_consumed_range = next(
                (cr for cr in izf.consumed_ranges if cr.end_sequence_number == izf.current_segment.sequence_number),
                None,
            )
            # TODO: move to processor MEDIA_END
            if prev_consumed_range and len(get_cr_chain(prev_consumed_range, izf.consumed_ranges)) >= 2:
                self.logger.debug(f'Found two or more consumed ranges that line up, allowing a seek for format {izf.format_id}')
                izf.current_segment = None
                yield MediaSeekSabrPart(
                    reason=MediaSeekSabrPart.Reason.CONSUMED_SEEK,
                    format_id=izf.format_id,
                    format_selector=izf.format_selector)

        enabled_initialized_formats = [izf for izf in self.processor.initialized_formats.values() if not izf.discard]

        # For each initialized format:
        # 1. find the consumed range that matches player_time_ms.
        # 2. find the current consumed range in sequence (in case multiple are joined together)
        # For livestreams, we allow a tolerance for the segment duration as it is estimated. This tolerance should be less than the segment duration / 2.

        cr_tolerance_ms = 0
        if self.processor.is_live:
            cr_tolerance_ms = self.processor.live_segment_target_duration_tolerance_ms

        current_consumed_ranges = []
        for izf in enabled_initialized_formats:
            for cr in sorted(izf.consumed_ranges, key=lambda cr: cr.start_sequence_number):
                if (cr.start_time_ms - cr_tolerance_ms) <= self.processor.client_abr_state.player_time_ms <= cr.start_time_ms + cr.duration_ms + (cr_tolerance_ms * 2):
                    chain = get_cr_chain(cr, izf.consumed_ranges)
                    current_consumed_ranges.append(chain[-1])
                    # There should only be one chain for the current player_time_ms (including the tolerance)
                    break

        min_consumed_duration_ms = None

        # Get the lowest consumed range end time out of all current consumed ranges.
        if current_consumed_ranges:
            lowest_izf_consumed_range = min(current_consumed_ranges, key=lambda cr: cr.start_time_ms + cr.duration_ms)
            min_consumed_duration_ms = lowest_izf_consumed_range.start_time_ms + lowest_izf_consumed_range.duration_ms

        if len(current_consumed_ranges) != len(enabled_initialized_formats) or min_consumed_duration_ms is None:
            # Missing a consumed range for a format.
            # In this case, consider player_time_ms to be our correct next time.
            # May happen when:
            # 1. A format has not been initialized yet (can happen if a response read fails), or
            # 2. A SABR_SEEK moved us to a time outside the consumed ranges of both formats,
            #    and only one of the formats returned data after the SABR_SEEK in that request.
            if min_consumed_duration_ms is None:
                min_consumed_duration_ms = self.processor.client_abr_state.player_time_ms
            else:
                min_consumed_duration_ms = min(min_consumed_duration_ms, self.processor.client_abr_state.player_time_ms)

        # Usually provided by the server if there were no segments returned.
        # We'll use this to calculate the time to wait for the next request (for live streams).
        next_request_backoff_ms = (self.processor.next_request_policy and self.processor.next_request_policy.backoff_time_ms) or 0

        request_player_time = self.processor.client_abr_state.player_time_ms
        self.logger.trace(f'min consumed duration ms: {min_consumed_duration_ms}')
        self.processor.client_abr_state.player_time_ms = min_consumed_duration_ms

        # Check if the latest segment is the last one of each format (if data is available)
        if (
            not (self.processor.is_live and not self.processor.post_live)
            and enabled_initialized_formats
            and len(current_consumed_ranges) == len(enabled_initialized_formats)
            and all(
                (
                    initialized_format.total_segments is not None
                    # consumed ranges are not guaranteed to be in order
                    and sorted(
                        initialized_format.consumed_ranges,
                        key=lambda cr: cr.end_sequence_number,
                    )[-1].end_sequence_number >= initialized_format.total_segments
                )
                for initialized_format in enabled_initialized_formats
            )
        ):
            self.logger.debug('Reached last expected segment for all enabled formats, assuming end of media')
            self._consumed = True

        # Check if we have exceeded the total duration of the media (if not live),
        # or wait for the next segment (if live)
        elif self.processor.total_duration_ms and (self.processor.client_abr_state.player_time_ms >= self.processor.total_duration_ms):
            if self.processor.is_live:
                self.logger.trace(f'setting player time ms ({self.processor.client_abr_state.player_time_ms}) to total duration ms ({self.processor.total_duration_ms})')
                self.processor.client_abr_state.player_time_ms = self.processor.total_duration_ms
                if (
                    self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
                    and not self._is_retry
                    and self._no_new_segments_tracker.timestamp_started < time.time() - self.live_end_wait_sec
                ):
                    self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream')
                    self._consumed = True
                else:
                    wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
            else:
                self.logger.debug(f'End of media (player time ms {self.processor.client_abr_state.player_time_ms} >= total duration ms {self.processor.total_duration_ms})')
                self._consumed = True

        # Handle receiving no new segments before the end of the video/stream.
        # For videos, exceeding max_empty_requests should not happen, so we raise an error.
        # For livestreams, if we exceed max_empty_requests, and we don't have live_metadata,
        # and have not received any data for a while, we can assume the stream has ended (as we don't know the head sequence number)
        elif (
            # Determine if we are receiving no segments as the live stream has ended.
            # Allow some tolerance, as the head segment may not be able to be received.
            self.processor.is_live and not self.processor.post_live
            and (
                getattr(self.processor.live_metadata, 'head_sequence_number', None) is None
                or (
                    enabled_initialized_formats
                    and len(current_consumed_ranges) == len(enabled_initialized_formats)
                    and all(
                        (
                            initialized_format.total_segments is not None
                            and sorted(
                                initialized_format.consumed_ranges,
                                key=lambda cr: cr.end_sequence_number,
                            )[-1].end_sequence_number
                            >= initialized_format.total_segments - self.live_end_segment_tolerance
                        )
                        for initialized_format in enabled_initialized_formats
                    )
                )
                or self.processor.live_metadata.head_sequence_time_ms is None
                or (
                    # Sometimes we receive a partial segment at the end of the stream
                    # and the server seeks us to the end of the current segment.
                    # As our consumed range for this segment has an estimated end time,
                    # this may be slightly off what the server seeks to.
                    # This can cause the player time to be slightly outside the consumed range.
                    #
                    # Because of this, we should also check the player time against
                    # the head segment time using the estimated segment duration.
                    # xxx: consider also taking into account the max seekable timestamp
                    request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)
                )
            )
        ):
            if (
                not self._is_retry  # allow us to sleep on a retry
                and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
                and self._no_new_segments_tracker.timestamp_started < time.time() - self.live_end_wait_sec
            ):
                self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream')
                self._consumed = True
            elif self._no_new_segments_tracker.consecutive_requests >= 1:
                # Sometimes we can't get the head segment - rather, we tend to sit behind the head segment for the duration of the livestream
                wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
            elif (
                self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
                and not self._is_retry
            ):
                raise SabrStreamError(f'No new segments received in {self.max_empty_requests} consecutive requests')

        elif (
            not self.processor.is_live and next_request_backoff_ms
            and self._no_new_segments_tracker.consecutive_requests >= 1
            and any(t in self.processor.sabr_contexts_to_send for t in self.processor.sabr_context_updates)
        ):
            wait_seconds = math.ceil(next_request_backoff_ms / 1000)
            self.logger.info(f'The server is requiring yt-dlp to wait {wait_seconds} seconds before continuing due to ad enforcement')

        if wait_seconds:
            self.logger.debug(f'Waiting {wait_seconds} seconds for next segment(s)')
            time.sleep(wait_seconds)

    def _parse_ump_response(self, response):
        self._unknown_part_types.clear()
        ump = UMPDecoder(response)
        for part in ump.iter_parts():
            if part.part_id == UMPPartId.MEDIA_HEADER:
                yield from self._process_media_header(part)
            elif part.part_id == UMPPartId.MEDIA:
                yield from self._process_media(part)
            elif part.part_id == UMPPartId.MEDIA_END:
                yield from self._process_media_end(part)
            elif part.part_id == UMPPartId.STREAM_PROTECTION_STATUS:
                yield from self._process_stream_protection_status(part)
            elif part.part_id == UMPPartId.SABR_REDIRECT:
                self._process_sabr_redirect(part)
            elif part.part_id == UMPPartId.FORMAT_INITIALIZATION_METADATA:
                yield from self._process_format_initialization_metadata(part)
            elif part.part_id == UMPPartId.NEXT_REQUEST_POLICY:
                self._process_next_request_policy(part)
            elif part.part_id == UMPPartId.LIVE_METADATA:
                yield from self._process_live_metadata(part)
            elif part.part_id == UMPPartId.SABR_SEEK:
                yield from self._process_sabr_seek(part)
            elif part.part_id == UMPPartId.SABR_ERROR:
                self._process_sabr_error(part)
            elif part.part_id == UMPPartId.SABR_CONTEXT_UPDATE:
                self._process_sabr_context_update(part)
            elif part.part_id == UMPPartId.SABR_CONTEXT_SENDING_POLICY:
                self._process_sabr_context_sending_policy(part)
            elif part.part_id == UMPPartId.RELOAD_PLAYER_RESPONSE:
                yield from self._process_reload_player_response(part)
            else:
                if part.part_id not in self._IGNORED_PARTS:
                    self._unknown_part_types.add(part.part_id)
                    self._log_part(part, msg='Unhandled part type', data=part.data.read())

            # Cancel request processing if we are going to retry
            if self._current_sps_retry.error or self._current_http_retry.error:
                self.logger.debug('Request processing cancelled')
                return

    def _process_media_header(self, part: UMPPart):
        media_header = protobug.load(part.data, MediaHeader)
        self._log_part(part=part, protobug_obj=media_header)

        try:
            result = self.processor.process_media_header(media_header)
            if result.sabr_part:
                yield result.sabr_part
        except MediaSegmentMismatchError as e:
            # For livestreams, the server may not know the exact segment for a given player time.
            # For segments near the stream head, it estimates using segment duration, which can cause off-by-one segment mismatches.
            # If a segment is much longer or shorter than expected, the server may return a segment ahead or behind.
            # In such cases, retry with an adjusted player time to resync.
            if self.processor.is_live and e.received_sequence_number == e.expected_sequence_number - 1:
                # The segment before the previous segment was possibly longer than expected.
                # Move the player time forward to try to adjust for this.
                self.processor.client_abr_state.player_time_ms += self.processor.live_segment_target_duration_tolerance_ms
                self._sq_mismatch_forward_count += 1
                self._current_http_retry.error = e
                return
            elif self.processor.is_live and e.received_sequence_number == e.expected_sequence_number + 2:
                # The previous segment was possibly shorter than expected.
                # Move the player time backwards to try to adjust for this.
                self.processor.client_abr_state.player_time_ms = max(0, self.processor.client_abr_state.player_time_ms - self.processor.live_segment_target_duration_tolerance_ms)
                self._sq_mismatch_backtrack_count += 1
                self._current_http_retry.error = e
                return
            raise e

    def _process_media(self, part: UMPPart):
        header_id = read_varint(part.data)
        content_length = part.size - part.data.tell()
        result = self.processor.process_media(header_id, content_length, part.data)
        if result.sabr_part:
            yield result.sabr_part

    def _process_media_end(self, part: UMPPart):
        header_id = read_varint(part.data)
        self._log_part(part, msg=f'Header ID: {header_id}')

        result = self.processor.process_media_end(header_id)
        if result.is_new_segment:
            self._received_new_segments = True

        if result.sabr_part:
            yield result.sabr_part

    def _process_live_metadata(self, part: UMPPart):
        live_metadata = protobug.load(part.data, LiveMetadata)
        self._log_part(part, protobug_obj=live_metadata)
        yield from self.processor.process_live_metadata(live_metadata).seek_sabr_parts

    def _process_stream_protection_status(self, part: UMPPart):
        sps = protobug.load(part.data, StreamProtectionStatus)
        self._log_part(part, msg=f'Status: {StreamProtectionStatus.Status(sps.status).name}', protobug_obj=sps)
        result = self.processor.process_stream_protection_status(sps)
        if result.sabr_part:
            yield result.sabr_part

    def _process_sabr_redirect(self, part: UMPPart):
        sabr_redirect = protobug.load(part.data, SabrRedirect)
        self._log_part(part, protobug_obj=sabr_redirect)
        if not sabr_redirect.redirect_url:
            self.logger.warning('Server requested to redirect to an invalid URL')
            return
        self.url = sabr_redirect.redirect_url

    def _process_format_initialization_metadata(self, part: UMPPart):
        fmt_init_metadata = protobug.load(part.data, FormatInitializationMetadata)
        self._log_part(part, protobug_obj=fmt_init_metadata)
        result = self.processor.process_format_initialization_metadata(fmt_init_metadata)
        if result.sabr_part:
            yield result.sabr_part

    def _process_next_request_policy(self, part: UMPPart):
        next_request_policy = protobug.load(part.data, NextRequestPolicy)
        self._log_part(part, protobug_obj=next_request_policy)
        self.processor.process_next_request_policy(next_request_policy)

    def _process_sabr_seek(self, part: UMPPart):
        sabr_seek = protobug.load(part.data, SabrSeek)
        self._log_part(part, protobug_obj=sabr_seek)
        yield from self.processor.process_sabr_seek(sabr_seek).seek_sabr_parts

    def _process_sabr_error(self, part: UMPPart):
        sabr_error = protobug.load(part.data, SabrError)
        self._log_part(part, protobug_obj=sabr_error)
        self._current_http_retry.error = SabrStreamError(f'SABR Protocol Error: {sabr_error}')

    def _process_sabr_context_update(self, part: UMPPart):
        sabr_ctx_update = protobug.load(part.data, SabrContextUpdate)
        self._log_part(part, protobug_obj=sabr_ctx_update)
        self.processor.process_sabr_context_update(sabr_ctx_update)

    def _process_sabr_context_sending_policy(self, part: UMPPart):
        sabr_ctx_sending_policy = protobug.load(part.data, SabrContextSendingPolicy)
        self._log_part(part, protobug_obj=sabr_ctx_sending_policy)
        self.processor.process_sabr_context_sending_policy(sabr_ctx_sending_policy)

    def _process_reload_player_response(self, part: UMPPart):
        reload_player_response = protobug.load(part.data, ReloadPlayerResponse)
        self._log_part(part, protobug_obj=reload_player_response)
        yield RefreshPlayerResponseSabrPart(
            reason=RefreshPlayerResponseSabrPart.Reason.SABR_RELOAD_PLAYER_RESPONSE,
            reload_playback_token=reload_player_response.reload_playback_params.token,
        )

    def _process_fallback_server(self):
        # Attempt to fall back to another GVS host in the case the current one fails
        new_url = next_gvs_fallback_url(self.url)
        if not new_url:
            self.logger.debug('No more fallback hosts available')
            return

        self.logger.warning(f'Falling back to host {urllib.parse.urlparse(new_url).netloc}')
        self.url = new_url

    def _gvs_expiry(self):
        return int_or_none(traverse_obj(parse_qs(self.url), ('expire', 0), get_all=False))

    def _process_expiry(self):
        expires_at = self._gvs_expiry()

        if not expires_at:
            self.logger.warning(
                'No expiry timestamp found in SABR URL. Will not be able to refresh.', once=True)
            return

        if expires_at - self.expiry_threshold_sec >= time.time():
            self.logger.trace(f'SABR url expires in {int(expires_at - time.time())} seconds')
            return

        self.logger.debug(
            f'Requesting player response refresh as SABR URL is due to expire within {self.expiry_threshold_sec} seconds')
        yield RefreshPlayerResponseSabrPart(reason=RefreshPlayerResponseSabrPart.Reason.SABR_URL_EXPIRY)

    def _log_part(self, part: UMPPart, msg=None, protobug_obj=None, data=None):
        if self.logger.log_level > self.logger.LogLevel.TRACE:
            return
        message = f'[{part.part_id.name}]: (Size {part.size})'
        if protobug_obj:
            message += f' Parsed: {protobug_obj}'
            uf = list(unknown_fields(protobug_obj))
            if uf:
                message += f' (Unknown fields: {uf})'
        if msg:
            message += f' {msg}'
        if data:
            message += f' Data: {base64.b64encode(data).decode("utf-8")}'
        self.logger.trace(message.strip())

    def _log_state(self):
        # TODO: refactor
        if self.logger.log_level > self.logger.LogLevel.DEBUG:
            return

        if self.processor.is_live and self.processor.post_live:
            live_message = f'post_live ({self.processor.live_segment_target_duration_sec}s)'
        elif self.processor.is_live:
            live_message = f'live ({self.processor.live_segment_target_duration_sec}s)'
        else:
            live_message = 'not_live'

        if self.processor.is_live:
            live_message += ' bid:' + str_or_none(broadcast_id_from_url(self.url))

        consumed_ranges_message = (
            ', '.join(
                f'{izf.format_id.itag}:'
                + ', '.join(
                    f'{cf.start_sequence_number}-{cf.end_sequence_number} '
                    f'({cf.start_time_ms}-'
                    f'{cf.start_time_ms + cf.duration_ms})'
                    for cf in izf.consumed_ranges
                )
                for izf in self.processor.initialized_formats.values()
            ) or 'none'
        )

        izf_parts = []
        for izf in self.processor.initialized_formats.values():
            s = f'{izf.format_id.itag}'
            if izf.discard:
                s += 'd'
            p = []
            if izf.total_segments:
                p.append(f'{izf.total_segments}')
            if izf.sequence_lmt is not None:
                p.append(f'lmt={izf.sequence_lmt}')
            if p:
                s += ('(' + ','.join(p) + ')')
            izf_parts.append(s)

        initialized_formats_message = ', '.join(izf_parts) or 'none'

        unknown_part_message = ''
        if self._unknown_part_types:
            unknown_part_message = 'unkpt:' + ', '.join(part_type.name for part_type in self._unknown_part_types)

        sabr_context_update_msg = ''
        if self.processor.sabr_context_updates:
            sabr_context_update_msg += 'cu:[' + ','.join(
                f'{k}{"(n)" if k not in self.processor.sabr_contexts_to_send else ""}'
                for k in self.processor.sabr_context_updates
            ) + ']'

        self.logger.debug(
            "[SABR State] "
            f"v:{self.processor.video_id or 'unknown'} "
            f"c:{self.processor.client_info.client_name.name} "
            f"t:{self.processor.client_abr_state.player_time_ms} "
            f"td:{self.processor.total_duration_ms if self.processor.total_duration_ms else 'n/a'} "
            f"h:{urllib.parse.urlparse(self.url).netloc} "
            f"exp:{dt.timedelta(seconds=int(self._gvs_expiry() - time.time())) if self._gvs_expiry() else 'n/a'} "
            f"rn:{self._request_number} rnns:{self._no_new_segments_tracker.consecutive_requests} "
            f"lnns:{self._no_new_segments_tracker.live_head_segment_started or 'n/a'} "
            f"mmb:{self._sq_mismatch_backtrack_count} mmf:{self._sq_mismatch_forward_count} "
            f"pot:{'Y' if self.processor.po_token else 'N'} "
            f"sps:{self.processor.stream_protection_status.name if self.processor.stream_protection_status else 'n/a'} "
            f"{live_message} "
            f"if:[{initialized_formats_message}] "
            f"cr:[{consumed_ranges_message}] "
            f"{sabr_context_update_msg} "
            f"{unknown_part_message}",
        )
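

# A minimal consumption sketch (illustrative only: `my_urlopen`, `my_logger` and the
# selector values are placeholders; in yt-dlp they come from the downloader and the
# player response):
#
#   stream = SabrStream(
#       urlopen=my_urlopen,                      # any Callable[[Request], Response]
#       logger=my_logger,                        # a SabrLogger implementation
#       server_abr_streaming_url=sabr_url,       # from the player response
#       video_playback_ustreamer_config=config,  # base64url-encoded string
#       client_info=client_info,                 # e.g. ClientInfo(client_name=ClientName.WEB)
#       audio_selection=audio_selector,
#       video_selection=video_selector,
#   )
#   for sabr_part in stream.iter_parts():
#       if isinstance(sabr_part, RefreshPlayerResponseSabrPart):
#           ...  # refresh the player response, then build a new SabrStream
#       ...      # dispatch on the other SabrPart types (media, seeks, etc.)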
@ -0,0 +1,83 @@
from __future__ import annotations

import contextlib
import math
import urllib.parse

from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange
from yt_dlp.utils import int_or_none, orderedSet, parse_qs, str_or_none, update_url_query


def get_cr_chain(start_consumed_range: ConsumedRange, consumed_ranges: list[ConsumedRange]) -> list[ConsumedRange]:
    # TODO: unit test
    # Return the continuous consumed range chain starting from the given consumed range.
    # Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (by process media header)
    chain = [start_consumed_range]
    for cr in sorted(consumed_ranges, key=lambda r: r.start_sequence_number):
        if cr.start_sequence_number == chain[-1].end_sequence_number + 1:
            chain.append(cr)
        elif cr.start_sequence_number > chain[-1].end_sequence_number + 1:
            break
    return chain
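
# A unit-test sketch for the chain walk above (hypothetical test; ConsumedRange fields
# are as used by the SABR processor):
#
#   def test_get_cr_chain_stops_at_gap():
#       crs = [
#           ConsumedRange(start_time_ms=0, duration_ms=2000, start_sequence_number=0, end_sequence_number=1),
#           ConsumedRange(start_time_ms=2000, duration_ms=2000, start_sequence_number=2, end_sequence_number=3),
#           ConsumedRange(start_time_ms=9000, duration_ms=1000, start_sequence_number=9, end_sequence_number=9),
#       ]
#       # 0-1 and 2-3 are contiguous in sequence numbers; 9-9 sits after a gap, so the chain stops.
#       assert get_cr_chain(crs[0], crs) == [crs[0], crs[1]]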


def next_gvs_fallback_url(gvs_url):
    # TODO: unit test
    qs = parse_qs(gvs_url)
    gvs_url_parsed = urllib.parse.urlparse(gvs_url)
    fvip = int_or_none(qs.get('fvip', [None])[0])
    mvi = int_or_none(qs.get('mvi', [None])[0])
    mn = str_or_none(qs.get('mn', [None])[0], default='').split(',')
    fallback_count = int_or_none(qs.get('fallback_count', ['0'])[0], default=0)

    hosts = []

    def build_host(current_host, f, m):
        rr = current_host.startswith('rr')
        if f is None or m is None:
            return None
        return ('rr' if rr else 'r') + str(f) + '---' + m + '.googlevideo.com'

    original_host = build_host(gvs_url_parsed.netloc, mvi, mn[0])

    # Order of fallback hosts:
    # 1. Fallback host in the URL (mn[1] + fvip)
    # 2. Fallback hosts brute-forced (this usually contains the original host)
    for mn_entry in reversed(mn):
        for fvip_entry in orderedSet([fvip, 1, 2, 3, 4, 5]):
            fallback_host = build_host(gvs_url_parsed.netloc, fvip_entry, mn_entry)
            if fallback_host and fallback_host not in hosts:
                hosts.append(fallback_host)

    if not hosts or len(hosts) == 1:
        return None

    # If this is the first fallback, anchor to the start of the list so we start with the known fallback hosts.
    # Sometimes we may get a SABR_REDIRECT after a fallback, which gives a new host with new fallbacks.
    # In this case, the original host indicated by the URL params would match the current host.
    current_host_index = -1
    if fallback_count > 0 and gvs_url_parsed.netloc != original_host:
        with contextlib.suppress(ValueError):
            current_host_index = hosts.index(gvs_url_parsed.netloc)

    def next_host(idx, h):
        return h[(idx + 1) % len(h)]

    new_host = next_host(current_host_index + 1, hosts)
    # If the current URL only has one fallback host, then the first fallback host is the same as the current host.
    if new_host == gvs_url_parsed.netloc:
        new_host = next_host(current_host_index + 2, hosts)

    # TODO: do not return new_host if it still matches the original host
    return update_url_query(
        gvs_url_parsed._replace(netloc=new_host).geturl(), {'fallback_count': fallback_count + 1})
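
# Illustrative call (hypothetical URL; the exact host chosen depends on the fvip/mvi/mn
# parameters and on how many fallbacks have been tried already):
#
#   url = ('https://rr3---sn-abc123.googlevideo.com/videoplayback'
#          '?fvip=5&mvi=3&mn=sn-abc123,sn-def456&id=example')
#   next_gvs_fallback_url(url)
#   # -> the same path and query on the next candidate host
#   #    (e.g. an rrN---sn-def456.googlevideo.com variant), with fallback_count=1 appended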


def ticks_to_ms(time_ticks: int, timescale: int):
    if time_ticks is None or timescale is None:
        return None
    return math.ceil((time_ticks / timescale) * 1000)
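
# A quick sanity check of the rounding behaviour (illustrative):
#
#   >>> ticks_to_ms(90_000, 90_000)  # one second of 90 kHz ticks
#   1000
#   >>> ticks_to_ms(3003, 30_000)    # 100.1 ms rounds up
#   101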


def broadcast_id_from_url(url: str) -> str | None:
    return str_or_none(parse_qs(url).get('id', [None])[0])
@ -0,0 +1,161 @@
import dataclasses
import enum
import io


class UMPPartId(enum.IntEnum):
    UNKNOWN = -1
    ONESIE_HEADER = 10
    ONESIE_DATA = 11
    ONESIE_ENCRYPTED_MEDIA = 12
    MEDIA_HEADER = 20
    MEDIA = 21
    MEDIA_END = 22
    LIVE_METADATA = 31
    HOSTNAME_CHANGE_HINT = 32
    LIVE_METADATA_PROMISE = 33
    LIVE_METADATA_PROMISE_CANCELLATION = 34
    NEXT_REQUEST_POLICY = 35
    USTREAMER_VIDEO_AND_FORMAT_DATA = 36
    FORMAT_SELECTION_CONFIG = 37
    USTREAMER_SELECTED_MEDIA_STREAM = 38
    FORMAT_INITIALIZATION_METADATA = 42
    SABR_REDIRECT = 43
    SABR_ERROR = 44
    SABR_SEEK = 45
    RELOAD_PLAYER_RESPONSE = 46
    PLAYBACK_START_POLICY = 47
    ALLOWED_CACHED_FORMATS = 48
    START_BW_SAMPLING_HINT = 49
    PAUSE_BW_SAMPLING_HINT = 50
    SELECTABLE_FORMATS = 51
    REQUEST_IDENTIFIER = 52
    REQUEST_CANCELLATION_POLICY = 53
    ONESIE_PREFETCH_REJECTION = 54
    TIMELINE_CONTEXT = 55
    REQUEST_PIPELINING = 56
    SABR_CONTEXT_UPDATE = 57
    STREAM_PROTECTION_STATUS = 58
    SABR_CONTEXT_SENDING_POLICY = 59
    LAWNMOWER_POLICY = 60
    SABR_ACK = 61
    END_OF_TRACK = 62
    CACHE_LOAD_POLICY = 63
    LAWNMOWER_MESSAGING_POLICY = 64
    PREWARM_CONNECTION = 65
    PLAYBACK_DEBUG_INFO = 66
    SNACKBAR_MESSAGE = 67

    @classmethod
    def _missing_(cls, value):
        return cls.UNKNOWN
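
# Unrecognised part ids decode to UNKNOWN rather than raising, e.g.:
#
#   >>> UMPPartId(999) is UMPPartId.UNKNOWN
#   True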


@dataclasses.dataclass
class UMPPart:
    part_id: UMPPartId
    size: int
    data: io.BufferedIOBase


class UMPDecoder:
    def __init__(self, fp: io.BufferedIOBase):
        self.fp = fp

    def iter_parts(self):
        while not self.fp.closed:
            part_type = read_varint(self.fp)
            if part_type == -1 and not self.fp.closed:
                self.fp.close()

            if self.fp.closed:
                break
            part_size = read_varint(self.fp)
            if part_size == -1 and not self.fp.closed:
                self.fp.close()

            if self.fp.closed:
                raise EOFError('Unexpected EOF while reading part size')

            part_data = self.fp.read(part_size)
            # In the future, we could allow streaming the part data,
            # but we would need to ensure that each part is completely read before continuing.
            yield UMPPart(UMPPartId(part_type), part_size, io.BytesIO(part_data))


class UMPEncoder:
    def __init__(self, fp: io.BufferedIOBase):
        self.fp = fp

    def write_part(self, part: UMPPart) -> None:
        write_varint(self.fp, part.part_id.value)
        write_varint(self.fp, part.size)
        self.fp.write(part.data.read())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return None
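
# A minimal encode/decode round-trip sketch (illustrative; the part id and payload
# are arbitrary):
#
#   buf = io.BytesIO()
#   with UMPEncoder(buf) as encoder:
#       encoder.write_part(UMPPart(UMPPartId.SABR_ACK, 7, io.BytesIO(b'example')))
#   buf.seek(0)
#   [(p.part_id, p.data.read()) for p in UMPDecoder(buf).iter_parts()]
#   # -> [(<UMPPartId.SABR_ACK: 61>, b'example')]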


def read_varint(fp: io.BufferedIOBase) -> int:
    # https://web.archive.org/web/20250430054327/https://github.com/gsuberland/UMP_Format/blob/main/UMP_Format.md
    # https://web.archive.org/web/20250429151021/https://github.com/davidzeng0/innertube/blob/main/googlevideo/ump.md
    byte = fp.read(1)
    if not byte:
        # Expected EOF
        return -1

    prefix = byte[0]
    size = varint_size(prefix)
    result = 0
    shift = 0

    if size != 5:
        shift = 8 - size
        mask = (1 << shift) - 1
        result |= prefix & mask

    for _ in range(1, size):
        next_byte = fp.read(1)
        if not next_byte:
            return -1
        byte_int = next_byte[0]
        result |= byte_int << shift
        shift += 8

    return result


def varint_size(byte: int) -> int:
    return 1 if byte < 128 else 2 if byte < 192 else 3 if byte < 224 else 4 if byte < 240 else 5


def write_varint(fp: io.BufferedIOBase, value: int) -> None:
    # ref: https://github.com/LuanRT/googlevideo/blob/main/src/core/UmpWriter.ts
    if value < 0:
        raise ValueError('Value must be a non-negative integer')

    if value < 128:
        fp.write(bytes([value]))
    elif value < 16384:
        fp.write(bytes([
            (value & 0x3F) | 0x80,
            value >> 6,
        ]))
    elif value < 2097152:
        fp.write(bytes([
            (value & 0x1F) | 0xC0,
            (value >> 5) & 0xFF,
            value >> 13,
        ]))
    elif value < 268435456:
        fp.write(bytes([
            (value & 0x0F) | 0xE0,
            (value >> 4) & 0xFF,
            (value >> 12) & 0xFF,
            value >> 20,
        ]))
    else:
        data = bytearray(5)
        data[0] = 0xF0
        data[1:5] = value.to_bytes(4, 'little')
        fp.write(data)
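
# The reader and writer are mutually consistent across all five encoded widths;
# a round-trip sanity check (illustrative):
#
#   for value in (0, 127, 128, 16_383, 16_384, 2**21 - 1, 2**28, 2**32 - 1):
#       buf = io.BytesIO()
#       write_varint(buf, value)
#       buf.seek(0)
#       assert read_varint(buf) == value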