diff --git a/downloader/downloader/main.py b/downloader/downloader/main.py index 50989ac..3189fcb 100644 --- a/downloader/downloader/main.py +++ b/downloader/downloader/main.py @@ -265,6 +265,12 @@ class StreamWorker(object): self.stopping = gevent.event.Event() # set to stop main loop self.getters = {} # map from url to SegmentGetter self.done = gevent.event.Event() # set when stopped and all getters are done + # Set up a Session for connection pooling. Note that if we have an issue, + # a new worker is created and so it gets a new session, just in case there's a problem + # with our connection pool. + # This worker's SegmentGetters will use its session by default for performance, + # but will fall back to a new one if something goes wrong. + self.session = requests.Session() def __repr__(self): return "<{} at 0x{:x} for stream {!r}>".format(type(self).__name__, id(self), self.stream) @@ -307,7 +313,7 @@ class StreamWorker(object): self.logger.debug("Getting media playlist {}".format(self.url)) try: with soft_hard_timeout(self.logger, "getting media playlist", self.FETCH_TIMEOUTS, self.trigger_new_worker): - playlist = twitch.get_media_playlist(self.url) + playlist = twitch.get_media_playlist(self.url, session=self.session) except Exception as e: self.logger.warning("Failed to fetch media playlist {}".format(self.url), exc_info=True) self.trigger_new_worker() @@ -332,7 +338,15 @@ class StreamWorker(object): if segment.uri not in self.getters: if date is None: raise ValueError("Cannot determine date of segment") - self.getters[segment.uri] = SegmentGetter(self.logger, self.manager.base_dir, self.manager.channel, self.stream, segment, date) + self.getters[segment.uri] = SegmentGetter( + self.logger, + self.session, + self.manager.base_dir, + self.manager.channel, + self.stream, + segment, + date, + ) gevent.spawn(self.getters[segment.uri].run) if date is not None: date += datetime.timedelta(seconds=segment.duration) @@ -376,7 +390,7 @@ class SegmentGetter(object): 
FETCH_HEADERS_TIMEOUTS = 5, 30 FETCH_FULL_TIMEOUTS = 15, 240 - def __init__(self, parent_logger, base_dir, channel, stream, segment, date): + def __init__(self, parent_logger, session, base_dir, channel, stream, segment, date): self.logger = parent_logger.getChild("SegmentGetter@{:x}".format(id(self))) self.base_dir = base_dir self.channel = channel @@ -386,6 +400,8 @@ class SegmentGetter(object): self.prefix = self.make_path_prefix() self.retry = None # Event, set to begin retrying self.done = gevent.event.Event() # set when file exists or we give up + # Our parent's connection pool, but we'll replace it if there are any issues + self.session = session def run(self): try: @@ -410,6 +426,9 @@ class SegmentGetter(object): # If worker has returned, and return value is true, we're done if worker.ready() and worker.value: break + # Create a new session, so we don't reuse a connection from the old session + # which had an error / some other issue. This is mostly just out of paranoia. + self.session = requests.Session() # if retry not set, wait for FETCH_RETRY first self.retry.wait(common.jitter(self.FETCH_RETRY)) self.logger.debug("Getter is done") @@ -470,7 +489,7 @@ class SegmentGetter(object): self.logger.debug("Downloading segment {} to {}".format(self.segment, temp_path)) with soft_hard_timeout(self.logger, "getting and writing segment", self.FETCH_FULL_TIMEOUTS, retry.set): with soft_hard_timeout(self.logger, "getting segment headers", self.FETCH_HEADERS_TIMEOUTS, retry.set): - resp = requests.get(self.segment.uri, stream=True) + resp = self.session.get(self.segment.uri, stream=True) # twitch returns 403 for expired segment urls, and 404 for very old urls where the original segment is gone. # the latter can happen if we have a network issue that cuts us off from twitch for some time. 
if resp.status_code in (403, 404): diff --git a/downloader/downloader/twitch.py b/downloader/downloader/twitch.py index ba37d6c..5e7fcfe 100644 --- a/downloader/downloader/twitch.py +++ b/downloader/downloader/twitch.py @@ -10,9 +10,9 @@ import hls_playlist logger = logging.getLogger(__name__) -def get_master_playlist(channel): +def get_master_playlist(channel, session=requests): """Get the master playlist for given channel from twitch""" - resp = requests.get( + resp = session.get( "https://api.twitch.tv/api/channels/{}/access_token.json".format(channel), params={'as3': 't'}, headers={ @@ -22,7 +22,7 @@ def get_master_playlist(channel): ) resp.raise_for_status() # getting access token token = resp.json() - resp = requests.get( + resp = session.get( "https://usher.ttvnw.net/api/channel/hls/{}.m3u8".format(channel), params={ # Taken from streamlink. Unsure what's needed and what changing things can do. @@ -93,7 +93,7 @@ def get_media_playlist_uris(master_playlist, target_qualities): return {name: variant.uri for name, variant in variants.items()} -def get_media_playlist(uri): - resp = requests.get(uri) +def get_media_playlist(uri, session=requests): + resp = session.get(uri) resp.raise_for_status() return hls_playlist.load(resp.text, base_uri=resp.url)