|
|
@ -1,14 +1,16 @@
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
|
|
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
|
|
|
AudioSelector,
|
|
|
|
AudioSelector,
|
|
|
|
VideoSelector,
|
|
|
|
VideoSelector,
|
|
|
|
CaptionSelector,
|
|
|
|
CaptionSelector,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
|
|
|
|
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
|
|
|
|
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo
|
|
|
|
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
@pytest.fixture
|
|
|
@ -233,7 +235,6 @@ class TestSabrProcessorInitialization:
|
|
|
|
assert processor.live_segment_target_duration_sec == 5
|
|
|
|
assert processor.live_segment_target_duration_sec == 5
|
|
|
|
assert processor.live_segment_target_duration_tolerance_ms == 100
|
|
|
|
assert processor.live_segment_target_duration_tolerance_ms == 100
|
|
|
|
assert processor.start_time_ms == 0
|
|
|
|
assert processor.start_time_ms == 0
|
|
|
|
assert processor.live_end_segment_tolerance == 10
|
|
|
|
|
|
|
|
assert processor.post_live is False
|
|
|
|
assert processor.post_live is False
|
|
|
|
|
|
|
|
|
|
|
|
def test_override_defaults(self, base_args):
|
|
|
|
def test_override_defaults(self, base_args):
|
|
|
@ -242,11 +243,65 @@ class TestSabrProcessorInitialization:
|
|
|
|
live_segment_target_duration_sec=8,
|
|
|
|
live_segment_target_duration_sec=8,
|
|
|
|
live_segment_target_duration_tolerance_ms=42,
|
|
|
|
live_segment_target_duration_tolerance_ms=42,
|
|
|
|
start_time_ms=123,
|
|
|
|
start_time_ms=123,
|
|
|
|
live_end_segment_tolerance=3,
|
|
|
|
|
|
|
|
post_live=True,
|
|
|
|
post_live=True,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert processor.live_segment_target_duration_sec == 8
|
|
|
|
assert processor.live_segment_target_duration_sec == 8
|
|
|
|
assert processor.live_segment_target_duration_tolerance_ms == 42
|
|
|
|
assert processor.live_segment_target_duration_tolerance_ms == 42
|
|
|
|
assert processor.start_time_ms == 123
|
|
|
|
assert processor.start_time_ms == 123
|
|
|
|
assert processor.live_end_segment_tolerance == 3
|
|
|
|
|
|
|
|
assert processor.post_live is True
|
|
|
|
assert processor.post_live is True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestStreamProtectionStatusPart:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
|
|
|
|
'sps,po_token,expected_status',
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
|
|
|
|
|
|
|
|
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
|
|
|
|
|
|
|
|
processor = SabrProcessor(**base_args, po_token=po_token)
|
|
|
|
|
|
|
|
part = StreamProtectionStatus(status=sps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = processor.process_stream_protection_status(part)
|
|
|
|
|
|
|
|
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
|
|
|
|
|
|
|
assert isinstance(result.sabr_part, PoTokenStatusSabrPart)
|
|
|
|
|
|
|
|
assert result.sabr_part.status == expected_status
|
|
|
|
|
|
|
|
assert processor.stream_protection_status == sps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_no_stream_protection_status(self, logger, base_args):
|
|
|
|
|
|
|
|
processor = SabrProcessor(**base_args, po_token='valid_token')
|
|
|
|
|
|
|
|
part = StreamProtectionStatus(status=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = processor.process_stream_protection_status(part)
|
|
|
|
|
|
|
|
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
|
|
|
|
|
|
|
assert result.sabr_part is None
|
|
|
|
|
|
|
|
assert processor.stream_protection_status is None
|
|
|
|
|
|
|
|
assert logger.warning.call_count == 1
|
|
|
|
|
|
|
|
logger.warning.assert_called_with(
|
|
|
|
|
|
|
|
'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestNextRequestPolicyPart:
|
|
|
|
|
|
|
|
def test_next_request_policy_part(self, logger, base_args):
|
|
|
|
|
|
|
|
processor = SabrProcessor(**base_args)
|
|
|
|
|
|
|
|
next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = processor.process_next_request_policy(next_request_policy)
|
|
|
|
|
|
|
|
assert result is None
|
|
|
|
|
|
|
|
assert processor.next_request_policy is next_request_policy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Verify it is overridden in the processor on another call
|
|
|
|
|
|
|
|
next_request_policy = NextRequestPolicy(target_video_readahead_ms=456)
|
|
|
|
|
|
|
|
result = processor.process_next_request_policy(next_request_policy)
|
|
|
|
|
|
|
|
assert result is None
|
|
|
|
|
|
|
|
assert processor.next_request_policy is next_request_policy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check logger trace was called
|
|
|
|
|
|
|
|
assert logger.trace.call_count == 2
|
|
|
|